move pi-mono into companion-cloud as apps/companion-os

- Copy all pi-mono source into apps/companion-os/
- Update Dockerfile to COPY pre-built binary instead of downloading from GitHub Releases
- Update deploy-staging.yml to build pi from source (bun compile) before Docker build
- Add apps/companion-os/** to path triggers
- No more cross-repo dispatch needed

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Harivansh Rathi 2026-03-07 09:22:50 -08:00
commit 0250f72976
579 changed files with 206942 additions and 0 deletions

262
packages/agent/CHANGELOG.md Normal file
View file

@ -0,0 +1,262 @@
# Changelog
## [Unreleased]
## [0.56.2] - 2026-03-05
## [0.56.1] - 2026-03-05
## [0.56.0] - 2026-03-04
## [0.55.4] - 2026-03-02
## [0.55.3] - 2026-02-27
## [0.55.2] - 2026-02-27
## [0.55.1] - 2026-02-26
## [0.55.0] - 2026-02-24
## [0.54.2] - 2026-02-23
## [0.54.1] - 2026-02-22
## [0.54.0] - 2026-02-19
## [0.53.1] - 2026-02-19
## [0.53.0] - 2026-02-17
## [0.52.12] - 2026-02-13
### Added
- Added `transport` to `AgentOptions` and `AgentLoopConfig` forwarding, allowing stream transport preference (`"sse"`, `"websocket"`, `"auto"`) to flow into provider calls.
## [0.52.11] - 2026-02-13
## [0.52.10] - 2026-02-12
## [0.52.9] - 2026-02-08
## [0.52.8] - 2026-02-07
## [0.52.7] - 2026-02-06
### Fixed
- Fixed `continue()` to resume queued steering/follow-up messages when context currently ends in an assistant message, and preserved one-at-a-time steering ordering during assistant-tail resumes ([#1312](https://github.com/badlogic/pi-mono/pull/1312) by [@ferologics](https://github.com/ferologics))
## [0.52.6] - 2026-02-05
## [0.52.5] - 2026-02-05
## [0.52.4] - 2026-02-05
## [0.52.3] - 2026-02-05
## [0.52.2] - 2026-02-05
## [0.52.1] - 2026-02-05
## [0.52.0] - 2026-02-05
## [0.51.6] - 2026-02-04
## [0.51.5] - 2026-02-04
## [0.51.4] - 2026-02-03
## [0.51.3] - 2026-02-03
## [0.51.2] - 2026-02-03
## [0.51.1] - 2026-02-02
## [0.51.0] - 2026-02-01
## [0.50.9] - 2026-02-01
## [0.50.8] - 2026-02-01
### Added
- Added `maxRetryDelayMs` option to `AgentOptions` to cap server-requested retry delays. Passed through to the underlying stream function. ([#1123](https://github.com/badlogic/pi-mono/issues/1123))
## [0.50.7] - 2026-01-31
## [0.50.6] - 2026-01-30
## [0.50.5] - 2026-01-30
## [0.50.3] - 2026-01-29
## [0.50.2] - 2026-01-29
## [0.50.1] - 2026-01-26
## [0.50.0] - 2026-01-26
## [0.49.3] - 2026-01-22
## [0.49.2] - 2026-01-19
## [0.49.1] - 2026-01-18
## [0.49.0] - 2026-01-17
## [0.48.0] - 2026-01-16
## [0.47.0] - 2026-01-16
## [0.46.0] - 2026-01-15
## [0.45.7] - 2026-01-13
## [0.45.6] - 2026-01-13
## [0.45.5] - 2026-01-13
## [0.45.4] - 2026-01-13
## [0.45.3] - 2026-01-13
## [0.45.2] - 2026-01-13
## [0.45.1] - 2026-01-13
## [0.45.0] - 2026-01-13
## [0.44.0] - 2026-01-12
## [0.43.0] - 2026-01-11
## [0.42.5] - 2026-01-11
## [0.42.4] - 2026-01-10
## [0.42.3] - 2026-01-10
## [0.42.2] - 2026-01-10
## [0.42.1] - 2026-01-09
## [0.42.0] - 2026-01-09
## [0.41.0] - 2026-01-09
## [0.40.1] - 2026-01-09
## [0.40.0] - 2026-01-08
## [0.39.1] - 2026-01-08
## [0.39.0] - 2026-01-08
## [0.38.0] - 2026-01-08
### Added
- `thinkingBudgets` option on `Agent` and `AgentOptions` to customize token budgets per thinking level ([#529](https://github.com/badlogic/pi-mono/pull/529) by [@melihmucuk](https://github.com/melihmucuk))
## [0.37.8] - 2026-01-07
## [0.37.7] - 2026-01-07
## [0.37.6] - 2026-01-06
## [0.37.5] - 2026-01-06
## [0.37.4] - 2026-01-06
## [0.37.3] - 2026-01-06
### Added
- `sessionId` option on `Agent` to forward session identifiers to LLM providers for session-based caching.
## [0.37.2] - 2026-01-05
## [0.37.1] - 2026-01-05
## [0.37.0] - 2026-01-05
### Fixed
- `minimal` thinking level now maps to `minimal` reasoning effort instead of being treated as `low`.
## [0.36.0] - 2026-01-05
## [0.35.0] - 2026-01-05
## [0.34.2] - 2026-01-04
## [0.34.1] - 2026-01-04
## [0.34.0] - 2026-01-04
## [0.33.0] - 2026-01-04
## [0.32.3] - 2026-01-03
## [0.32.2] - 2026-01-03
## [0.32.1] - 2026-01-03
## [0.32.0] - 2026-01-03
### Breaking Changes
- **Queue API replaced with steer/followUp**: The `queueMessage()` method has been split into two methods with different delivery semantics ([#403](https://github.com/badlogic/pi-mono/issues/403)):
- `steer(msg)`: Interrupts the agent mid-run. Delivered after current tool execution, skips remaining tools.
- `followUp(msg)`: Waits until the agent finishes. Delivered only when there are no more tool calls or steering messages.
- **Queue mode renamed**: `queueMode` option renamed to `steeringMode`. Added new `followUpMode` option. Both control whether messages are delivered one-at-a-time or all at once.
- **AgentLoopConfig callbacks renamed**: `getQueuedMessages` split into `getSteeringMessages` and `getFollowUpMessages`.
- **Agent methods renamed**:
- `queueMessage()``steer()` and `followUp()`
- `clearMessageQueue()``clearSteeringQueue()`, `clearFollowUpQueue()`, `clearAllQueues()`
- `setQueueMode()`/`getQueueMode()``setSteeringMode()`/`getSteeringMode()` and `setFollowUpMode()`/`getFollowUpMode()`
### Fixed
- `prompt()` and `continue()` now throw if called while the agent is already streaming, preventing race conditions and corrupted state. Use `steer()` or `followUp()` to queue messages during streaming, or `await` the previous call.
## [0.31.1] - 2026-01-02
## [0.31.0] - 2026-01-02
### Breaking Changes
- **Transport abstraction removed**: `ProviderTransport`, `AppTransport`, and `AgentTransport` interface have been removed. Use the `streamFn` option directly for custom streaming implementations.
- **Agent options renamed**:
- `transport` → removed (use `streamFn` instead)
- `messageTransformer``convertToLlm`
- `preprocessor``transformContext`
- **`AppMessage` renamed to `AgentMessage`**: All references to `AppMessage` have been renamed to `AgentMessage` for consistency.
- **`CustomMessages` renamed to `CustomAgentMessages`**: The declaration merging interface has been renamed.
- **`UserMessageWithAttachments` and `Attachment` types removed**: Attachment handling is now the responsibility of the `convertToLlm` function.
- **Agent loop moved from `@mariozechner/pi-ai`**: The `agentLoop`, `agentLoopContinue`, and related types have moved to this package. Import from `@mariozechner/pi-agent-core` instead.
### Added
- `streamFn` option on `Agent` for custom stream implementations. Default uses `streamSimple` from pi-ai.
- `streamProxy()` utility function for browser apps that need to proxy LLM calls through a backend server. Replaces the removed `AppTransport`.
- `getApiKey` option for dynamic API key resolution (useful for expiring OAuth tokens like GitHub Copilot).
- `agentLoop()` and `agentLoopContinue()` low-level functions for running the agent loop without the `Agent` class wrapper.
- New exported types: `AgentLoopConfig`, `AgentContext`, `AgentTool`, `AgentToolResult`, `AgentToolUpdateCallback`, `StreamFn`.
### Changed
- `Agent` constructor now has all options optional (empty options use defaults).
- `queueMessage()` is now synchronous (no longer returns a Promise).

426
packages/agent/README.md Normal file
View file

@ -0,0 +1,426 @@
# @mariozechner/pi-agent-core
Stateful agent with tool execution and event streaming. Built on `@mariozechner/pi-ai`.
## Installation
```bash
npm install @mariozechner/pi-agent-core
```
## Quick Start
```typescript
import { Agent } from "@mariozechner/pi-agent-core";
import { getModel } from "@mariozechner/pi-ai";
const agent = new Agent({
initialState: {
systemPrompt: "You are a helpful assistant.",
model: getModel("anthropic", "claude-sonnet-4-20250514"),
},
});
agent.subscribe((event) => {
if (
event.type === "message_update" &&
event.assistantMessageEvent.type === "text_delta"
) {
// Stream just the new text chunk
process.stdout.write(event.assistantMessageEvent.delta);
}
});
await agent.prompt("Hello!");
```
## Core Concepts
### AgentMessage vs LLM Message
The agent works with `AgentMessage`, a flexible type that can include:
- Standard LLM messages (`user`, `assistant`, `toolResult`)
- Custom app-specific message types via declaration merging
LLMs only understand `user`, `assistant`, and `toolResult`. The `convertToLlm` function bridges this gap by filtering and transforming messages before each LLM call.
### Message Flow
```
AgentMessage[] → transformContext() → AgentMessage[] → convertToLlm() → Message[] → LLM
(optional) (required)
```
1. **transformContext**: Prune old messages, inject external context
2. **convertToLlm**: Filter out UI-only messages, convert custom types to LLM format
## Event Flow
The agent emits events for UI updates. Understanding the event sequence helps build responsive interfaces.
### prompt() Event Sequence
When you call `prompt("Hello")`:
```
prompt("Hello")
├─ agent_start
├─ turn_start
├─ message_start { message: userMessage } // Your prompt
├─ message_end { message: userMessage }
├─ message_start { message: assistantMessage } // LLM starts responding
├─ message_update { message: partial... } // Streaming chunks
├─ message_update { message: partial... }
├─ message_end { message: assistantMessage } // Complete response
├─ turn_end { message, toolResults: [] }
└─ agent_end { messages: [...] }
```
### With Tool Calls
If the assistant calls tools, the loop continues:
```
prompt("Read config.json")
├─ agent_start
├─ turn_start
├─ message_start/end { userMessage }
├─ message_start { assistantMessage with toolCall }
├─ message_update...
├─ message_end { assistantMessage }
├─ tool_execution_start { toolCallId, toolName, args }
├─ tool_execution_update { partialResult } // If tool streams
├─ tool_execution_end { toolCallId, result }
├─ message_start/end { toolResultMessage }
├─ turn_end { message, toolResults: [toolResult] }
├─ turn_start // Next turn
├─ message_start { assistantMessage } // LLM responds to tool result
├─ message_update...
├─ message_end
├─ turn_end
└─ agent_end
```
### continue() Event Sequence
`continue()` resumes from existing context without adding a new message. Use it for retries after errors.
```typescript
// After an error, retry from current state
await agent.continue();
```
The last message in context must be `user` or `toolResult` (not `assistant`).
### Event Types
| Event | Description |
| ----------------------- | --------------------------------------------------------------- |
| `agent_start` | Agent begins processing |
| `agent_end` | Agent completes with all new messages |
| `turn_start` | New turn begins (one LLM call + tool executions) |
| `turn_end` | Turn completes with assistant message and tool results |
| `message_start` | Any message begins (user, assistant, toolResult) |
| `message_update` | **Assistant only.** Includes `assistantMessageEvent` with delta |
| `message_end` | Message completes |
| `tool_execution_start` | Tool begins |
| `tool_execution_update` | Tool streams progress |
| `tool_execution_end` | Tool completes |
## Agent Options
```typescript
const agent = new Agent({
// Initial state
initialState: {
systemPrompt: string,
model: Model<any>,
thinkingLevel: "off" | "minimal" | "low" | "medium" | "high" | "xhigh",
tools: AgentTool<any>[],
messages: AgentMessage[],
},
// Convert AgentMessage[] to LLM Message[] (required for custom message types)
convertToLlm: (messages) => messages.filter(...),
// Transform context before convertToLlm (for pruning, compaction)
transformContext: async (messages, signal) => pruneOldMessages(messages),
// Steering mode: "one-at-a-time" (default) or "all"
steeringMode: "one-at-a-time",
// Follow-up mode: "one-at-a-time" (default) or "all"
followUpMode: "one-at-a-time",
// Custom stream function (for proxy backends)
streamFn: streamProxy,
// Session ID for provider caching
sessionId: "session-123",
// Dynamic API key resolution (for expiring OAuth tokens)
getApiKey: async (provider) => refreshToken(),
// Custom thinking budgets for token-based providers
thinkingBudgets: {
minimal: 128,
low: 512,
medium: 1024,
high: 2048,
},
});
```
## Agent State
```typescript
interface AgentState {
systemPrompt: string;
model: Model<any>;
thinkingLevel: ThinkingLevel;
tools: AgentTool<any>[];
messages: AgentMessage[];
isStreaming: boolean;
streamMessage: AgentMessage | null; // Current partial during streaming
pendingToolCalls: Set<string>;
error?: string;
}
```
Access via `agent.state`. During streaming, `streamMessage` contains the partial assistant message.
## Methods
### Prompting
```typescript
// Text prompt
await agent.prompt("Hello");
// With images
await agent.prompt("What's in this image?", [
{ type: "image", data: base64Data, mimeType: "image/jpeg" },
]);
// AgentMessage directly
await agent.prompt({ role: "user", content: "Hello", timestamp: Date.now() });
// Continue from current context (last message must be user or toolResult)
await agent.continue();
```
### State Management
```typescript
agent.setSystemPrompt("New prompt");
agent.setModel(getModel("openai", "gpt-4o"));
agent.setThinkingLevel("medium");
agent.setTools([myTool]);
agent.replaceMessages(newMessages);
agent.appendMessage(message);
agent.clearMessages();
agent.reset(); // Clear everything
```
### Session and Thinking Budgets
```typescript
agent.sessionId = "session-123";
agent.thinkingBudgets = {
minimal: 128,
low: 512,
medium: 1024,
high: 2048,
};
```
### Control
```typescript
agent.abort(); // Cancel current operation
await agent.waitForIdle(); // Wait for completion
```
### Events
```typescript
const unsubscribe = agent.subscribe((event) => {
console.log(event.type);
});
unsubscribe();
```
## Steering and Follow-up
Steering messages let you interrupt the agent while tools are running. Follow-up messages let you queue work after the agent would otherwise stop.
```typescript
agent.setSteeringMode("one-at-a-time");
agent.setFollowUpMode("one-at-a-time");
// While agent is running tools
agent.steer({
role: "user",
content: "Stop! Do this instead.",
timestamp: Date.now(),
});
// After the agent finishes its current work
agent.followUp({
role: "user",
content: "Also summarize the result.",
timestamp: Date.now(),
});
const steeringMode = agent.getSteeringMode();
const followUpMode = agent.getFollowUpMode();
agent.clearSteeringQueue();
agent.clearFollowUpQueue();
agent.clearAllQueues();
```
Use clearSteeringQueue, clearFollowUpQueue, or clearAllQueues to drop queued messages.
When steering messages are detected after a tool completes:
1. Remaining tools are skipped with error results
2. Steering messages are injected
3. LLM responds to the interruption
Follow-up messages are checked only when there are no more tool calls and no steering messages. If any are queued, they are injected and another turn runs.
## Custom Message Types
Extend `AgentMessage` via declaration merging:
```typescript
declare module "@mariozechner/pi-agent-core" {
interface CustomAgentMessages {
notification: { role: "notification"; text: string; timestamp: number };
}
}
// Now valid
const msg: AgentMessage = {
role: "notification",
text: "Info",
timestamp: Date.now(),
};
```
Handle custom types in `convertToLlm`:
```typescript
const agent = new Agent({
convertToLlm: (messages) =>
messages.flatMap((m) => {
if (m.role === "notification") return []; // Filter out
return [m];
}),
});
```
## Tools
Define tools using `AgentTool`:
```typescript
import { Type } from "@sinclair/typebox";
const readFileTool: AgentTool = {
name: "read_file",
label: "Read File", // For UI display
description: "Read a file's contents",
parameters: Type.Object({
path: Type.String({ description: "File path" }),
}),
execute: async (toolCallId, params, signal, onUpdate) => {
const content = await fs.readFile(params.path, "utf-8");
// Optional: stream progress
onUpdate?.({
content: [{ type: "text", text: "Reading..." }],
details: {},
});
return {
content: [{ type: "text", text: content }],
details: { path: params.path, size: content.length },
};
},
};
agent.setTools([readFileTool]);
```
### Error Handling
**Throw an error** when a tool fails. Do not return error messages as content.
```typescript
execute: async (toolCallId, params, signal, onUpdate) => {
if (!fs.existsSync(params.path)) {
throw new Error(`File not found: ${params.path}`);
}
// Return content only on success
return { content: [{ type: "text", text: "..." }] };
};
```
Thrown errors are caught by the agent and reported to the LLM as tool errors with `isError: true`.
## Proxy Usage
For browser apps that proxy through a backend:
```typescript
import { Agent, streamProxy } from "@mariozechner/pi-agent-core";
const agent = new Agent({
streamFn: (model, context, options) =>
streamProxy(model, context, {
...options,
authToken: "...",
proxyUrl: "https://your-server.com",
}),
});
```
## Low-Level API
For direct control without the Agent class:
```typescript
import { agentLoop, agentLoopContinue } from "@mariozechner/pi-agent-core";
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [],
tools: [],
};
const config: AgentLoopConfig = {
model: getModel("openai", "gpt-4o"),
convertToLlm: (msgs) =>
msgs.filter((m) => ["user", "assistant", "toolResult"].includes(m.role)),
};
const userMessage = { role: "user", content: "Hello", timestamp: Date.now() };
for await (const event of agentLoop([userMessage], context, config)) {
console.log(event.type);
}
// Continue from existing context
for await (const event of agentLoopContinue(context, config)) {
console.log(event.type);
}
```
## License
MIT

View file

@ -0,0 +1,44 @@
{
"name": "@mariozechner/pi-agent-core",
"version": "0.56.2",
"description": "General-purpose agent with transport abstraction, state management, and attachment support",
"type": "module",
"main": "./dist/index.js",
"types": "./dist/index.d.ts",
"files": [
"dist",
"README.md"
],
"scripts": {
"clean": "shx rm -rf dist",
"build": "tsgo -p tsconfig.build.json",
"dev": "tsgo -p tsconfig.build.json --watch --preserveWatchOutput",
"test": "vitest --run",
"prepublishOnly": "npm run clean && npm run build"
},
"dependencies": {
"@mariozechner/pi-ai": "^0.56.2"
},
"keywords": [
"ai",
"agent",
"llm",
"transport",
"state-management"
],
"author": "Mario Zechner",
"license": "MIT",
"repository": {
"type": "git",
"url": "git+https://github.com/getcompanion-ai/co-mono.git",
"directory": "packages/agent"
},
"engines": {
"node": ">=20.0.0"
},
"devDependencies": {
"@types/node": "^24.3.0",
"typescript": "^5.7.3",
"vitest": "^3.2.4"
}
}

View file

@ -0,0 +1,452 @@
/**
* Agent loop that works with AgentMessage throughout.
* Transforms to Message[] only at the LLM call boundary.
*/
import {
type AssistantMessage,
type Context,
EventStream,
streamSimple,
type ToolResultMessage,
validateToolArguments,
} from "@mariozechner/pi-ai";
import type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentMessage,
AgentTool,
AgentToolResult,
StreamFn,
} from "./types.js";
/**
* Start an agent loop with a new prompt message.
* The prompt is added to the context and events are emitted for it.
*/
export function agentLoop(
prompts: AgentMessage[],
context: AgentContext,
config: AgentLoopConfig,
signal?: AbortSignal,
streamFn?: StreamFn,
): EventStream<AgentEvent, AgentMessage[]> {
const stream = createAgentStream();
(async () => {
const newMessages: AgentMessage[] = [...prompts];
const currentContext: AgentContext = {
...context,
messages: [...context.messages, ...prompts],
};
stream.push({ type: "agent_start" });
stream.push({ type: "turn_start" });
for (const prompt of prompts) {
stream.push({ type: "message_start", message: prompt });
stream.push({ type: "message_end", message: prompt });
}
await runLoop(
currentContext,
newMessages,
config,
signal,
stream,
streamFn,
);
})();
return stream;
}
/**
* Continue an agent loop from the current context without adding a new message.
* Used for retries - context already has user message or tool results.
*
* **Important:** The last message in context must convert to a `user` or `toolResult` message
* via `convertToLlm`. If it doesn't, the LLM provider will reject the request.
* This cannot be validated here since `convertToLlm` is only called once per turn.
*/
export function agentLoopContinue(
context: AgentContext,
config: AgentLoopConfig,
signal?: AbortSignal,
streamFn?: StreamFn,
): EventStream<AgentEvent, AgentMessage[]> {
if (context.messages.length === 0) {
throw new Error("Cannot continue: no messages in context");
}
if (context.messages[context.messages.length - 1].role === "assistant") {
throw new Error("Cannot continue from message role: assistant");
}
const stream = createAgentStream();
(async () => {
const newMessages: AgentMessage[] = [];
const currentContext: AgentContext = { ...context };
stream.push({ type: "agent_start" });
stream.push({ type: "turn_start" });
await runLoop(
currentContext,
newMessages,
config,
signal,
stream,
streamFn,
);
})();
return stream;
}
function createAgentStream(): EventStream<AgentEvent, AgentMessage[]> {
return new EventStream<AgentEvent, AgentMessage[]>(
(event: AgentEvent) => event.type === "agent_end",
(event: AgentEvent) => (event.type === "agent_end" ? event.messages : []),
);
}
/**
* Main loop logic shared by agentLoop and agentLoopContinue.
*/
async function runLoop(
currentContext: AgentContext,
newMessages: AgentMessage[],
config: AgentLoopConfig,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentMessage[]>,
streamFn?: StreamFn,
): Promise<void> {
let firstTurn = true;
// Check for steering messages at start (user may have typed while waiting)
let pendingMessages: AgentMessage[] =
(await config.getSteeringMessages?.()) || [];
// Outer loop: continues when queued follow-up messages arrive after agent would stop
while (true) {
let hasMoreToolCalls = true;
let steeringAfterTools: AgentMessage[] | null = null;
// Inner loop: process tool calls and steering messages
while (hasMoreToolCalls || pendingMessages.length > 0) {
if (!firstTurn) {
stream.push({ type: "turn_start" });
} else {
firstTurn = false;
}
// Process pending messages (inject before next assistant response)
if (pendingMessages.length > 0) {
for (const message of pendingMessages) {
stream.push({ type: "message_start", message });
stream.push({ type: "message_end", message });
currentContext.messages.push(message);
newMessages.push(message);
}
pendingMessages = [];
}
// Stream assistant response
const message = await streamAssistantResponse(
currentContext,
config,
signal,
stream,
streamFn,
);
newMessages.push(message);
if (message.stopReason === "error" || message.stopReason === "aborted") {
stream.push({ type: "turn_end", message, toolResults: [] });
stream.push({ type: "agent_end", messages: newMessages });
stream.end(newMessages);
return;
}
// Check for tool calls
const toolCalls = message.content.filter((c) => c.type === "toolCall");
hasMoreToolCalls = toolCalls.length > 0;
const toolResults: ToolResultMessage[] = [];
if (hasMoreToolCalls) {
const toolExecution = await executeToolCalls(
currentContext.tools,
message,
signal,
stream,
config.getSteeringMessages,
);
toolResults.push(...toolExecution.toolResults);
steeringAfterTools = toolExecution.steeringMessages ?? null;
for (const result of toolResults) {
currentContext.messages.push(result);
newMessages.push(result);
}
}
stream.push({ type: "turn_end", message, toolResults });
// Get steering messages after turn completes
if (steeringAfterTools && steeringAfterTools.length > 0) {
pendingMessages = steeringAfterTools;
steeringAfterTools = null;
} else {
pendingMessages = (await config.getSteeringMessages?.()) || [];
}
}
// Agent would stop here. Check for follow-up messages.
const followUpMessages = (await config.getFollowUpMessages?.()) || [];
if (followUpMessages.length > 0) {
// Set as pending so inner loop processes them
pendingMessages = followUpMessages;
continue;
}
// No more messages, exit
break;
}
stream.push({ type: "agent_end", messages: newMessages });
stream.end(newMessages);
}
/**
* Stream an assistant response from the LLM.
* This is where AgentMessage[] gets transformed to Message[] for the LLM.
*/
async function streamAssistantResponse(
context: AgentContext,
config: AgentLoopConfig,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentMessage[]>,
streamFn?: StreamFn,
): Promise<AssistantMessage> {
// Apply context transform if configured (AgentMessage[] → AgentMessage[])
let messages = context.messages;
if (config.transformContext) {
messages = await config.transformContext(messages, signal);
}
// Convert to LLM-compatible messages (AgentMessage[] → Message[])
const llmMessages = await config.convertToLlm(messages);
// Build LLM context
const llmContext: Context = {
systemPrompt: context.systemPrompt,
messages: llmMessages,
tools: context.tools,
};
const streamFunction = streamFn || streamSimple;
// Resolve API key (important for expiring tokens)
const resolvedApiKey =
(config.getApiKey
? await config.getApiKey(config.model.provider)
: undefined) || config.apiKey;
const response = await streamFunction(config.model, llmContext, {
...config,
apiKey: resolvedApiKey,
signal,
});
let partialMessage: AssistantMessage | null = null;
let addedPartial = false;
for await (const event of response) {
switch (event.type) {
case "start":
partialMessage = event.partial;
context.messages.push(partialMessage);
addedPartial = true;
stream.push({ type: "message_start", message: { ...partialMessage } });
break;
case "text_start":
case "text_delta":
case "text_end":
case "thinking_start":
case "thinking_delta":
case "thinking_end":
case "toolcall_start":
case "toolcall_delta":
case "toolcall_end":
if (partialMessage) {
partialMessage = event.partial;
context.messages[context.messages.length - 1] = partialMessage;
stream.push({
type: "message_update",
assistantMessageEvent: event,
message: { ...partialMessage },
});
}
break;
case "done":
case "error": {
const finalMessage = await response.result();
if (addedPartial) {
context.messages[context.messages.length - 1] = finalMessage;
} else {
context.messages.push(finalMessage);
}
if (!addedPartial) {
stream.push({ type: "message_start", message: { ...finalMessage } });
}
stream.push({ type: "message_end", message: finalMessage });
return finalMessage;
}
}
}
return await response.result();
}
/**
* Execute tool calls from an assistant message.
*/
async function executeToolCalls(
tools: AgentTool<any>[] | undefined,
assistantMessage: AssistantMessage,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentMessage[]>,
getSteeringMessages?: AgentLoopConfig["getSteeringMessages"],
): Promise<{
toolResults: ToolResultMessage[];
steeringMessages?: AgentMessage[];
}> {
const toolCalls = assistantMessage.content.filter(
(c) => c.type === "toolCall",
);
const results: ToolResultMessage[] = [];
let steeringMessages: AgentMessage[] | undefined;
for (let index = 0; index < toolCalls.length; index++) {
const toolCall = toolCalls[index];
const tool = tools?.find((t) => t.name === toolCall.name);
stream.push({
type: "tool_execution_start",
toolCallId: toolCall.id,
toolName: toolCall.name,
args: toolCall.arguments,
});
let result: AgentToolResult<any>;
let isError = false;
try {
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
const validatedArgs = validateToolArguments(tool, toolCall);
result = await tool.execute(
toolCall.id,
validatedArgs,
signal,
(partialResult) => {
stream.push({
type: "tool_execution_update",
toolCallId: toolCall.id,
toolName: toolCall.name,
args: toolCall.arguments,
partialResult,
});
},
);
} catch (e) {
result = {
content: [
{ type: "text", text: e instanceof Error ? e.message : String(e) },
],
details: {},
};
isError = true;
}
stream.push({
type: "tool_execution_end",
toolCallId: toolCall.id,
toolName: toolCall.name,
result,
isError,
});
const toolResultMessage: ToolResultMessage = {
role: "toolResult",
toolCallId: toolCall.id,
toolName: toolCall.name,
content: result.content,
details: result.details,
isError,
timestamp: Date.now(),
};
results.push(toolResultMessage);
stream.push({ type: "message_start", message: toolResultMessage });
stream.push({ type: "message_end", message: toolResultMessage });
// Check for steering messages - skip remaining tools if user interrupted
if (getSteeringMessages) {
const steering = await getSteeringMessages();
if (steering.length > 0) {
steeringMessages = steering;
const remainingCalls = toolCalls.slice(index + 1);
for (const skipped of remainingCalls) {
results.push(skipToolCall(skipped, stream));
}
break;
}
}
}
return { toolResults: results, steeringMessages };
}
function skipToolCall(
toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
stream: EventStream<AgentEvent, AgentMessage[]>,
): ToolResultMessage {
const result: AgentToolResult<any> = {
content: [{ type: "text", text: "Skipped due to queued user message." }],
details: {},
};
stream.push({
type: "tool_execution_start",
toolCallId: toolCall.id,
toolName: toolCall.name,
args: toolCall.arguments,
});
stream.push({
type: "tool_execution_end",
toolCallId: toolCall.id,
toolName: toolCall.name,
result,
isError: true,
});
const toolResultMessage: ToolResultMessage = {
role: "toolResult",
toolCallId: toolCall.id,
toolName: toolCall.name,
content: result.content,
details: {},
isError: true,
timestamp: Date.now(),
};
stream.push({ type: "message_start", message: toolResultMessage });
stream.push({ type: "message_end", message: toolResultMessage });
return toolResultMessage;
}

605
packages/agent/src/agent.ts Normal file
View file

@ -0,0 +1,605 @@
/**
* Agent class that uses the agent-loop directly.
* No transport abstraction - calls streamSimple via the loop.
*/
import {
getModel,
type ImageContent,
type Message,
type Model,
streamSimple,
type TextContent,
type ThinkingBudgets,
type Transport,
} from "@mariozechner/pi-ai";
import { agentLoop, agentLoopContinue } from "./agent-loop.js";
import type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentMessage,
AgentState,
AgentTool,
StreamFn,
ThinkingLevel,
} from "./types.js";
/**
* Default convertToLlm: Keep only LLM-compatible messages, convert attachments.
*/
function defaultConvertToLlm(messages: AgentMessage[]): Message[] {
return messages.filter(
(m) =>
m.role === "user" || m.role === "assistant" || m.role === "toolResult",
);
}
export interface AgentOptions {
initialState?: Partial<AgentState>;
/**
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
* Default filters to user/assistant/toolResult and converts attachments.
*/
convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
/**
* Optional transform applied to context before convertToLlm.
* Use for context pruning, injecting external context, etc.
*/
transformContext?: (
messages: AgentMessage[],
signal?: AbortSignal,
) => Promise<AgentMessage[]>;
/**
* Steering mode: "all" = send all steering messages at once, "one-at-a-time" = one per turn
*/
steeringMode?: "all" | "one-at-a-time";
/**
* Follow-up mode: "all" = send all follow-up messages at once, "one-at-a-time" = one per turn
*/
followUpMode?: "all" | "one-at-a-time";
/**
* Custom stream function (for proxy backends, etc.). Default uses streamSimple.
*/
streamFn?: StreamFn;
/**
* Optional session identifier forwarded to LLM providers.
* Used by providers that support session-based caching (e.g., OpenAI Codex).
*/
sessionId?: string;
/**
* Resolves an API key dynamically for each LLM call.
* Useful for expiring tokens (e.g., GitHub Copilot OAuth).
*/
getApiKey?: (
provider: string,
) => Promise<string | undefined> | string | undefined;
/**
* Custom token budgets for thinking levels (token-based providers only).
*/
thinkingBudgets?: ThinkingBudgets;
/**
* Preferred transport for providers that support multiple transports.
*/
transport?: Transport;
/**
* Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
* If the server's requested delay exceeds this value, the request fails immediately,
* allowing higher-level retry logic to handle it with user visibility.
* Default: 60000 (60 seconds). Set to 0 to disable the cap.
*/
maxRetryDelayMs?: number;
}
export class Agent {
private _state: AgentState = {
systemPrompt: "",
model: getModel("google", "gemini-2.5-flash-lite-preview-06-17"),
thinkingLevel: "off",
tools: [],
messages: [],
isStreaming: false,
streamMessage: null,
pendingToolCalls: new Set<string>(),
error: undefined,
};
private listeners = new Set<(e: AgentEvent) => void>();
private abortController?: AbortController;
private convertToLlm: (
messages: AgentMessage[],
) => Message[] | Promise<Message[]>;
private transformContext?: (
messages: AgentMessage[],
signal?: AbortSignal,
) => Promise<AgentMessage[]>;
private steeringQueue: AgentMessage[] = [];
private followUpQueue: AgentMessage[] = [];
private steeringMode: "all" | "one-at-a-time";
private followUpMode: "all" | "one-at-a-time";
public streamFn: StreamFn;
private _sessionId?: string;
public getApiKey?: (
provider: string,
) => Promise<string | undefined> | string | undefined;
private runningPrompt?: Promise<void>;
private resolveRunningPrompt?: () => void;
private _thinkingBudgets?: ThinkingBudgets;
private _transport: Transport;
private _maxRetryDelayMs?: number;
constructor(opts: AgentOptions = {}) {
this._state = { ...this._state, ...opts.initialState };
this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
this.transformContext = opts.transformContext;
this.steeringMode = opts.steeringMode || "one-at-a-time";
this.followUpMode = opts.followUpMode || "one-at-a-time";
this.streamFn = opts.streamFn || streamSimple;
this._sessionId = opts.sessionId;
this.getApiKey = opts.getApiKey;
this._thinkingBudgets = opts.thinkingBudgets;
this._transport = opts.transport ?? "sse";
this._maxRetryDelayMs = opts.maxRetryDelayMs;
}
/**
* Get the current session ID used for provider caching.
*/
get sessionId(): string | undefined {
return this._sessionId;
}
/**
* Set the session ID for provider caching.
* Call this when switching sessions (new session, branch, resume).
*/
set sessionId(value: string | undefined) {
this._sessionId = value;
}
/**
* Get the current thinking budgets.
*/
get thinkingBudgets(): ThinkingBudgets | undefined {
return this._thinkingBudgets;
}
/**
* Set custom thinking budgets for token-based providers.
*/
set thinkingBudgets(value: ThinkingBudgets | undefined) {
this._thinkingBudgets = value;
}
/**
* Get the current preferred transport.
*/
get transport(): Transport {
return this._transport;
}
/**
* Set the preferred transport.
*/
setTransport(value: Transport) {
this._transport = value;
}
/**
* Get the current max retry delay in milliseconds.
*/
get maxRetryDelayMs(): number | undefined {
return this._maxRetryDelayMs;
}
/**
* Set the maximum delay to wait for server-requested retries.
* Set to 0 to disable the cap.
*/
set maxRetryDelayMs(value: number | undefined) {
this._maxRetryDelayMs = value;
}
get state(): AgentState {
return this._state;
}
subscribe(fn: (e: AgentEvent) => void): () => void {
this.listeners.add(fn);
return () => this.listeners.delete(fn);
}
// State mutators
setSystemPrompt(v: string) {
this._state.systemPrompt = v;
}
setModel(m: Model<any>) {
this._state.model = m;
}
setThinkingLevel(l: ThinkingLevel) {
this._state.thinkingLevel = l;
}
setSteeringMode(mode: "all" | "one-at-a-time") {
this.steeringMode = mode;
}
getSteeringMode(): "all" | "one-at-a-time" {
return this.steeringMode;
}
setFollowUpMode(mode: "all" | "one-at-a-time") {
this.followUpMode = mode;
}
getFollowUpMode(): "all" | "one-at-a-time" {
return this.followUpMode;
}
setTools(t: AgentTool<any>[]) {
this._state.tools = t;
}
replaceMessages(ms: AgentMessage[]) {
this._state.messages = ms.slice();
}
appendMessage(m: AgentMessage) {
this._state.messages = [...this._state.messages, m];
}
/**
* Queue a steering message to interrupt the agent mid-run.
* Delivered after current tool execution, skips remaining tools.
*/
steer(m: AgentMessage) {
this.steeringQueue.push(m);
}
/**
* Queue a follow-up message to be processed after the agent finishes.
* Delivered only when agent has no more tool calls or steering messages.
*/
followUp(m: AgentMessage) {
this.followUpQueue.push(m);
}
clearSteeringQueue() {
this.steeringQueue = [];
}
clearFollowUpQueue() {
this.followUpQueue = [];
}
clearAllQueues() {
this.steeringQueue = [];
this.followUpQueue = [];
}
hasQueuedMessages(): boolean {
return this.steeringQueue.length > 0 || this.followUpQueue.length > 0;
}
private dequeueSteeringMessages(): AgentMessage[] {
if (this.steeringMode === "one-at-a-time") {
if (this.steeringQueue.length > 0) {
const first = this.steeringQueue[0];
this.steeringQueue = this.steeringQueue.slice(1);
return [first];
}
return [];
}
const steering = this.steeringQueue.slice();
this.steeringQueue = [];
return steering;
}
private dequeueFollowUpMessages(): AgentMessage[] {
if (this.followUpMode === "one-at-a-time") {
if (this.followUpQueue.length > 0) {
const first = this.followUpQueue[0];
this.followUpQueue = this.followUpQueue.slice(1);
return [first];
}
return [];
}
const followUp = this.followUpQueue.slice();
this.followUpQueue = [];
return followUp;
}
clearMessages() {
this._state.messages = [];
}
abort() {
this.abortController?.abort();
}
waitForIdle(): Promise<void> {
return this.runningPrompt ?? Promise.resolve();
}
reset() {
this._state.messages = [];
this._state.isStreaming = false;
this._state.streamMessage = null;
this._state.pendingToolCalls = new Set<string>();
this._state.error = undefined;
this.steeringQueue = [];
this.followUpQueue = [];
}
/** Send a prompt with an AgentMessage */
async prompt(message: AgentMessage | AgentMessage[]): Promise<void>;
async prompt(input: string, images?: ImageContent[]): Promise<void>;
async prompt(
input: string | AgentMessage | AgentMessage[],
images?: ImageContent[],
) {
if (this._state.isStreaming) {
throw new Error(
"Agent is already processing a prompt. Use steer() or followUp() to queue messages, or wait for completion.",
);
}
const model = this._state.model;
if (!model) throw new Error("No model configured");
let msgs: AgentMessage[];
if (Array.isArray(input)) {
msgs = input;
} else if (typeof input === "string") {
const content: Array<TextContent | ImageContent> = [
{ type: "text", text: input },
];
if (images && images.length > 0) {
content.push(...images);
}
msgs = [
{
role: "user",
content,
timestamp: Date.now(),
},
];
} else {
msgs = [input];
}
await this._runLoop(msgs);
}
/**
* Continue from current context (used for retries and resuming queued messages).
*/
async continue() {
if (this._state.isStreaming) {
throw new Error(
"Agent is already processing. Wait for completion before continuing.",
);
}
const messages = this._state.messages;
if (messages.length === 0) {
throw new Error("No messages to continue from");
}
if (messages[messages.length - 1].role === "assistant") {
const queuedSteering = this.dequeueSteeringMessages();
if (queuedSteering.length > 0) {
await this._runLoop(queuedSteering, { skipInitialSteeringPoll: true });
return;
}
const queuedFollowUp = this.dequeueFollowUpMessages();
if (queuedFollowUp.length > 0) {
await this._runLoop(queuedFollowUp);
return;
}
throw new Error("Cannot continue from message role: assistant");
}
await this._runLoop(undefined);
}
/**
* Run the agent loop.
* If messages are provided, starts a new conversation turn with those messages.
* Otherwise, continues from existing context.
*/
private async _runLoop(
messages?: AgentMessage[],
options?: { skipInitialSteeringPoll?: boolean },
) {
const model = this._state.model;
if (!model) throw new Error("No model configured");
this.runningPrompt = new Promise<void>((resolve) => {
this.resolveRunningPrompt = resolve;
});
this.abortController = new AbortController();
this._state.isStreaming = true;
this._state.streamMessage = null;
this._state.error = undefined;
const reasoning =
this._state.thinkingLevel === "off"
? undefined
: this._state.thinkingLevel;
const context: AgentContext = {
systemPrompt: this._state.systemPrompt,
messages: this._state.messages.slice(),
tools: this._state.tools,
};
let skipInitialSteeringPoll = options?.skipInitialSteeringPoll === true;
const config: AgentLoopConfig = {
model,
reasoning,
sessionId: this._sessionId,
transport: this._transport,
thinkingBudgets: this._thinkingBudgets,
maxRetryDelayMs: this._maxRetryDelayMs,
convertToLlm: this.convertToLlm,
transformContext: this.transformContext,
getApiKey: this.getApiKey,
getSteeringMessages: async () => {
if (skipInitialSteeringPoll) {
skipInitialSteeringPoll = false;
return [];
}
return this.dequeueSteeringMessages();
},
getFollowUpMessages: async () => this.dequeueFollowUpMessages(),
};
let partial: AgentMessage | null = null;
try {
const stream = messages
? agentLoop(
messages,
context,
config,
this.abortController.signal,
this.streamFn,
)
: agentLoopContinue(
context,
config,
this.abortController.signal,
this.streamFn,
);
for await (const event of stream) {
// Update internal state based on events
switch (event.type) {
case "message_start":
partial = event.message;
this._state.streamMessage = event.message;
break;
case "message_update":
partial = event.message;
this._state.streamMessage = event.message;
break;
case "message_end":
partial = null;
this._state.streamMessage = null;
this.appendMessage(event.message);
break;
case "tool_execution_start": {
const s = new Set(this._state.pendingToolCalls);
s.add(event.toolCallId);
this._state.pendingToolCalls = s;
break;
}
case "tool_execution_end": {
const s = new Set(this._state.pendingToolCalls);
s.delete(event.toolCallId);
this._state.pendingToolCalls = s;
break;
}
case "turn_end":
if (
event.message.role === "assistant" &&
(event.message as any).errorMessage
) {
this._state.error = (event.message as any).errorMessage;
}
break;
case "agent_end":
this._state.isStreaming = false;
this._state.streamMessage = null;
break;
}
// Emit to listeners
this.emit(event);
}
// Handle any remaining partial message
if (
partial &&
partial.role === "assistant" &&
partial.content.length > 0
) {
const onlyEmpty = !partial.content.some(
(c) =>
(c.type === "thinking" && c.thinking.trim().length > 0) ||
(c.type === "text" && c.text.trim().length > 0) ||
(c.type === "toolCall" && c.name.trim().length > 0),
);
if (!onlyEmpty) {
this.appendMessage(partial);
} else {
if (this.abortController?.signal.aborted) {
throw new Error("Request was aborted");
}
}
}
} catch (err: any) {
const errorMsg: AgentMessage = {
role: "assistant",
content: [{ type: "text", text: "" }],
api: model.api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: this.abortController?.signal.aborted ? "aborted" : "error",
errorMessage: err?.message || String(err),
timestamp: Date.now(),
} as AgentMessage;
this.appendMessage(errorMsg);
this._state.error = err?.message || String(err);
this.emit({ type: "agent_end", messages: [errorMsg] });
} finally {
this._state.isStreaming = false;
this._state.streamMessage = null;
this._state.pendingToolCalls = new Set<string>();
this.abortController = undefined;
this.resolveRunningPrompt?.();
this.runningPrompt = undefined;
this.resolveRunningPrompt = undefined;
}
}
private emit(e: AgentEvent) {
for (const listener of this.listeners) {
listener(e);
}
}
}

View file

@ -0,0 +1,8 @@
// Core Agent
export * from "./agent.js";
// Loop functions
export * from "./agent-loop.js";
// Proxy utilities
export * from "./proxy.js";
// Types
export * from "./types.js";

369
packages/agent/src/proxy.ts Normal file
View file

@ -0,0 +1,369 @@
/**
* Proxy stream function for apps that route LLM calls through a server.
* The server manages auth and proxies requests to LLM providers.
*/
// Internal import for JSON parsing utility
import {
type AssistantMessage,
type AssistantMessageEvent,
type Context,
EventStream,
type Model,
parseStreamingJson,
type SimpleStreamOptions,
type StopReason,
type ToolCall,
} from "@mariozechner/pi-ai";
// Create stream class matching ProxyMessageEventStream
class ProxyMessageEventStream extends EventStream<
AssistantMessageEvent,
AssistantMessage
> {
constructor() {
super(
(event) => event.type === "done" || event.type === "error",
(event) => {
if (event.type === "done") return event.message;
if (event.type === "error") return event.error;
throw new Error("Unexpected event type");
},
);
}
}
/**
* Proxy event types - server sends these with partial field stripped to reduce bandwidth.
*/
export type ProxyAssistantMessageEvent =
| { type: "start" }
| { type: "text_start"; contentIndex: number }
| { type: "text_delta"; contentIndex: number; delta: string }
| { type: "text_end"; contentIndex: number; contentSignature?: string }
| { type: "thinking_start"; contentIndex: number }
| { type: "thinking_delta"; contentIndex: number; delta: string }
| { type: "thinking_end"; contentIndex: number; contentSignature?: string }
| {
type: "toolcall_start";
contentIndex: number;
id: string;
toolName: string;
}
| { type: "toolcall_delta"; contentIndex: number; delta: string }
| { type: "toolcall_end"; contentIndex: number }
| {
type: "done";
reason: Extract<StopReason, "stop" | "length" | "toolUse">;
usage: AssistantMessage["usage"];
}
| {
type: "error";
reason: Extract<StopReason, "aborted" | "error">;
errorMessage?: string;
usage: AssistantMessage["usage"];
};
export interface ProxyStreamOptions extends SimpleStreamOptions {
/** Auth token for the proxy server */
authToken: string;
/** Proxy server URL (e.g., "https://genai.example.com") */
proxyUrl: string;
}
/**
* Stream function that proxies through a server instead of calling LLM providers directly.
* The server strips the partial field from delta events to reduce bandwidth.
* We reconstruct the partial message client-side.
*
* Use this as the `streamFn` option when creating an Agent that needs to go through a proxy.
*
* @example
* ```typescript
* const agent = new Agent({
* streamFn: (model, context, options) =>
* streamProxy(model, context, {
* ...options,
* authToken: await getAuthToken(),
* proxyUrl: "https://genai.example.com",
* }),
* });
* ```
*/
export function streamProxy(
model: Model<any>,
context: Context,
options: ProxyStreamOptions,
): ProxyMessageEventStream {
const stream = new ProxyMessageEventStream();
(async () => {
// Initialize the partial message that we'll build up from events
const partial: AssistantMessage = {
role: "assistant",
stopReason: "stop",
content: [],
api: model.api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
timestamp: Date.now(),
};
let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
const abortHandler = () => {
if (reader) {
reader.cancel("Request aborted by user").catch(() => {});
}
};
if (options.signal) {
options.signal.addEventListener("abort", abortHandler);
}
try {
const response = await fetch(`${options.proxyUrl}/api/stream`, {
method: "POST",
headers: {
Authorization: `Bearer ${options.authToken}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
model,
context,
options: {
temperature: options.temperature,
maxTokens: options.maxTokens,
reasoning: options.reasoning,
},
}),
signal: options.signal,
});
if (!response.ok) {
let errorMessage = `Proxy error: ${response.status} ${response.statusText}`;
try {
const errorData = (await response.json()) as { error?: string };
if (errorData.error) {
errorMessage = `Proxy error: ${errorData.error}`;
}
} catch {
// Couldn't parse error response
}
throw new Error(errorMessage);
}
reader = response.body!.getReader();
const decoder = new TextDecoder();
let buffer = "";
while (true) {
const { done, value } = await reader.read();
if (done) break;
if (options.signal?.aborted) {
throw new Error("Request aborted by user");
}
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n");
buffer = lines.pop() || "";
for (const line of lines) {
if (line.startsWith("data: ")) {
const data = line.slice(6).trim();
if (data) {
const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent;
const event = processProxyEvent(proxyEvent, partial);
if (event) {
stream.push(event);
}
}
}
}
}
if (options.signal?.aborted) {
throw new Error("Request aborted by user");
}
stream.end();
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
const reason = options.signal?.aborted ? "aborted" : "error";
partial.stopReason = reason;
partial.errorMessage = errorMessage;
stream.push({
type: "error",
reason,
error: partial,
});
stream.end();
} finally {
if (options.signal) {
options.signal.removeEventListener("abort", abortHandler);
}
}
})();
return stream;
}
/**
* Process a proxy event and update the partial message.
*/
function processProxyEvent(
proxyEvent: ProxyAssistantMessageEvent,
partial: AssistantMessage,
): AssistantMessageEvent | undefined {
switch (proxyEvent.type) {
case "start":
return { type: "start", partial };
case "text_start":
partial.content[proxyEvent.contentIndex] = { type: "text", text: "" };
return {
type: "text_start",
contentIndex: proxyEvent.contentIndex,
partial,
};
case "text_delta": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "text") {
content.text += proxyEvent.delta;
return {
type: "text_delta",
contentIndex: proxyEvent.contentIndex,
delta: proxyEvent.delta,
partial,
};
}
throw new Error("Received text_delta for non-text content");
}
case "text_end": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "text") {
content.textSignature = proxyEvent.contentSignature;
return {
type: "text_end",
contentIndex: proxyEvent.contentIndex,
content: content.text,
partial,
};
}
throw new Error("Received text_end for non-text content");
}
case "thinking_start":
partial.content[proxyEvent.contentIndex] = {
type: "thinking",
thinking: "",
};
return {
type: "thinking_start",
contentIndex: proxyEvent.contentIndex,
partial,
};
case "thinking_delta": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "thinking") {
content.thinking += proxyEvent.delta;
return {
type: "thinking_delta",
contentIndex: proxyEvent.contentIndex,
delta: proxyEvent.delta,
partial,
};
}
throw new Error("Received thinking_delta for non-thinking content");
}
case "thinking_end": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "thinking") {
content.thinkingSignature = proxyEvent.contentSignature;
return {
type: "thinking_end",
contentIndex: proxyEvent.contentIndex,
content: content.thinking,
partial,
};
}
throw new Error("Received thinking_end for non-thinking content");
}
case "toolcall_start":
partial.content[proxyEvent.contentIndex] = {
type: "toolCall",
id: proxyEvent.id,
name: proxyEvent.toolName,
arguments: {},
partialJson: "",
} satisfies ToolCall & { partialJson: string } as ToolCall;
return {
type: "toolcall_start",
contentIndex: proxyEvent.contentIndex,
partial,
};
case "toolcall_delta": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "toolCall") {
(content as any).partialJson += proxyEvent.delta;
content.arguments =
parseStreamingJson((content as any).partialJson) || {};
partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity
return {
type: "toolcall_delta",
contentIndex: proxyEvent.contentIndex,
delta: proxyEvent.delta,
partial,
};
}
throw new Error("Received toolcall_delta for non-toolCall content");
}
case "toolcall_end": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "toolCall") {
delete (content as any).partialJson;
return {
type: "toolcall_end",
contentIndex: proxyEvent.contentIndex,
toolCall: content,
partial,
};
}
return undefined;
}
case "done":
partial.stopReason = proxyEvent.reason;
partial.usage = proxyEvent.usage;
return { type: "done", reason: proxyEvent.reason, message: partial };
case "error":
partial.stopReason = proxyEvent.reason;
partial.errorMessage = proxyEvent.errorMessage;
partial.usage = proxyEvent.usage;
return { type: "error", reason: proxyEvent.reason, error: partial };
default: {
const _exhaustiveCheck: never = proxyEvent;
console.warn(`Unhandled proxy event type: ${(proxyEvent as any).type}`);
return undefined;
}
}
}

237
packages/agent/src/types.ts Normal file
View file

@ -0,0 +1,237 @@
import type {
AssistantMessageEvent,
ImageContent,
Message,
Model,
SimpleStreamOptions,
streamSimple,
TextContent,
Tool,
ToolResultMessage,
} from "@mariozechner/pi-ai";
import type { Static, TSchema } from "@sinclair/typebox";
/** Stream function - can return sync or Promise for async config lookup */
export type StreamFn = (
...args: Parameters<typeof streamSimple>
) => ReturnType<typeof streamSimple> | Promise<ReturnType<typeof streamSimple>>;
/**
* Configuration for the agent loop.
*/
export interface AgentLoopConfig extends SimpleStreamOptions {
model: Model<any>;
/**
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
*
* Each AgentMessage must be converted to a UserMessage, AssistantMessage, or ToolResultMessage
* that the LLM can understand. AgentMessages that cannot be converted (e.g., UI-only notifications,
* status messages) should be filtered out.
*
* @example
* ```typescript
* convertToLlm: (messages) => messages.flatMap(m => {
* if (m.role === "custom") {
* // Convert custom message to user message
* return [{ role: "user", content: m.content, timestamp: m.timestamp }];
* }
* if (m.role === "notification") {
* // Filter out UI-only messages
* return [];
* }
* // Pass through standard LLM messages
* return [m];
* })
* ```
*/
convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
/**
* Optional transform applied to the context before `convertToLlm`.
*
* Use this for operations that work at the AgentMessage level:
* - Context window management (pruning old messages)
* - Injecting context from external sources
*
* @example
* ```typescript
* transformContext: async (messages) => {
* if (estimateTokens(messages) > MAX_TOKENS) {
* return pruneOldMessages(messages);
* }
* return messages;
* }
* ```
*/
transformContext?: (
messages: AgentMessage[],
signal?: AbortSignal,
) => Promise<AgentMessage[]>;
/**
* Resolves an API key dynamically for each LLM call.
*
* Useful for short-lived OAuth tokens (e.g., GitHub Copilot) that may expire
* during long-running tool execution phases.
*/
getApiKey?: (
provider: string,
) => Promise<string | undefined> | string | undefined;
/**
* Returns steering messages to inject into the conversation mid-run.
*
* Called after each tool execution to check for user interruptions.
* If messages are returned, remaining tool calls are skipped and
* these messages are added to the context before the next LLM call.
*
* Use this for "steering" the agent while it's working.
*/
getSteeringMessages?: () => Promise<AgentMessage[]>;
/**
* Returns follow-up messages to process after the agent would otherwise stop.
*
* Called when the agent has no more tool calls and no steering messages.
* If messages are returned, they're added to the context and the agent
* continues with another turn.
*
* Use this for follow-up messages that should wait until the agent finishes.
*/
getFollowUpMessages?: () => Promise<AgentMessage[]>;
}
/**
* Thinking/reasoning level for models that support it.
* Note: "xhigh" is only supported by OpenAI gpt-5.1-codex-max, gpt-5.2, gpt-5.2-codex, gpt-5.3, and gpt-5.3-codex models.
*/
export type ThinkingLevel =
| "off"
| "minimal"
| "low"
| "medium"
| "high"
| "xhigh";
/**
* Extensible interface for custom app messages.
* Apps can extend via declaration merging:
*
* @example
* ```typescript
* declare module "@mariozechner/agent" {
* interface CustomAgentMessages {
* artifact: ArtifactMessage;
* notification: NotificationMessage;
* }
* }
* ```
*/
export interface CustomAgentMessages {
// Empty by default - apps extend via declaration merging
}
/**
* AgentMessage: Union of LLM messages + custom messages.
* This abstraction allows apps to add custom message types while maintaining
* type safety and compatibility with the base LLM messages.
*/
export type AgentMessage =
| Message
| CustomAgentMessages[keyof CustomAgentMessages];
/**
* Agent state containing all configuration and conversation data.
*/
export interface AgentState {
systemPrompt: string;
model: Model<any>;
thinkingLevel: ThinkingLevel;
tools: AgentTool<any>[];
messages: AgentMessage[]; // Can include attachments + custom message types
isStreaming: boolean;
streamMessage: AgentMessage | null;
pendingToolCalls: Set<string>;
error?: string;
}
export interface AgentToolResult<T> {
// Content blocks supporting text and images
content: (TextContent | ImageContent)[];
// Details to be displayed in a UI or logged
details: T;
}
// Callback for streaming tool execution updates
export type AgentToolUpdateCallback<T = any> = (
partialResult: AgentToolResult<T>,
) => void;
// AgentTool extends Tool but adds the execute function
export interface AgentTool<
TParameters extends TSchema = TSchema,
TDetails = any,
> extends Tool<TParameters> {
// A human-readable label for the tool to be displayed in UI
label: string;
execute: (
toolCallId: string,
params: Static<TParameters>,
signal?: AbortSignal,
onUpdate?: AgentToolUpdateCallback<TDetails>,
) => Promise<AgentToolResult<TDetails>>;
}
// AgentContext is like Context but uses AgentTool
export interface AgentContext {
systemPrompt: string;
messages: AgentMessage[];
tools?: AgentTool<any>[];
}
/**
* Events emitted by the Agent for UI updates.
* These events provide fine-grained lifecycle information for messages, turns, and tool executions.
*/
export type AgentEvent =
// Agent lifecycle
| { type: "agent_start" }
| { type: "agent_end"; messages: AgentMessage[] }
// Turn lifecycle - a turn is one assistant response + any tool calls/results
| { type: "turn_start" }
| {
type: "turn_end";
message: AgentMessage;
toolResults: ToolResultMessage[];
}
// Message lifecycle - emitted for user, assistant, and toolResult messages
| { type: "message_start"; message: AgentMessage }
// Only emitted for assistant messages during streaming
| {
type: "message_update";
message: AgentMessage;
assistantMessageEvent: AssistantMessageEvent;
}
| { type: "message_end"; message: AgentMessage }
// Tool execution lifecycle
| {
type: "tool_execution_start";
toolCallId: string;
toolName: string;
args: any;
}
| {
type: "tool_execution_update";
toolCallId: string;
toolName: string;
args: any;
partialResult: any;
}
| {
type: "tool_execution_end";
toolCallId: string;
toolName: string;
result: any;
isError: boolean;
};

View file

@ -0,0 +1,629 @@
import {
type AssistantMessage,
type AssistantMessageEvent,
EventStream,
type Message,
type Model,
type UserMessage,
} from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import { agentLoop, agentLoopContinue } from "../src/agent-loop.js";
import type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentMessage,
AgentTool,
} from "../src/types.js";
// Mock stream for testing - mimics MockAssistantStream
class MockAssistantStream extends EventStream<
AssistantMessageEvent,
AssistantMessage
> {
constructor() {
super(
(event) => event.type === "done" || event.type === "error",
(event) => {
if (event.type === "done") return event.message;
if (event.type === "error") return event.error;
throw new Error("Unexpected event type");
},
);
}
}
function createUsage() {
return {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
};
}
function createModel(): Model<"openai-responses"> {
return {
id: "mock",
name: "mock",
api: "openai-responses",
provider: "openai",
baseUrl: "https://example.invalid",
reasoning: false,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 8192,
maxTokens: 2048,
};
}
function createAssistantMessage(
content: AssistantMessage["content"],
stopReason: AssistantMessage["stopReason"] = "stop",
): AssistantMessage {
return {
role: "assistant",
content,
api: "openai-responses",
provider: "openai",
model: "mock",
usage: createUsage(),
stopReason,
timestamp: Date.now(),
};
}
function createUserMessage(text: string): UserMessage {
return {
role: "user",
content: text,
timestamp: Date.now(),
};
}
// Simple identity converter for tests - just passes through standard messages
function identityConverter(messages: AgentMessage[]): Message[] {
return messages.filter(
(m) =>
m.role === "user" || m.role === "assistant" || m.role === "toolResult",
) as Message[];
}
describe("agentLoop with AgentMessage", () => {
it("should emit events with AgentMessage types", async () => {
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [],
tools: [],
};
const userPrompt: AgentMessage = createUserMessage("Hello");
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: identityConverter,
};
const streamFn = () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
const message = createAssistantMessage([
{ type: "text", text: "Hi there!" },
]);
stream.push({ type: "done", reason: "stop", message });
});
return stream;
};
const events: AgentEvent[] = [];
const stream = agentLoop(
[userPrompt],
context,
config,
undefined,
streamFn,
);
for await (const event of stream) {
events.push(event);
}
const messages = await stream.result();
// Should have user message and assistant message
expect(messages.length).toBe(2);
expect(messages[0].role).toBe("user");
expect(messages[1].role).toBe("assistant");
// Verify event sequence
const eventTypes = events.map((e) => e.type);
expect(eventTypes).toContain("agent_start");
expect(eventTypes).toContain("turn_start");
expect(eventTypes).toContain("message_start");
expect(eventTypes).toContain("message_end");
expect(eventTypes).toContain("turn_end");
expect(eventTypes).toContain("agent_end");
});
it("should handle custom message types via convertToLlm", async () => {
// Create a custom message type
interface CustomNotification {
role: "notification";
text: string;
timestamp: number;
}
const notification: CustomNotification = {
role: "notification",
text: "This is a notification",
timestamp: Date.now(),
};
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [notification as unknown as AgentMessage], // Custom message in context
tools: [],
};
const userPrompt: AgentMessage = createUserMessage("Hello");
let convertedMessages: Message[] = [];
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: (messages) => {
// Filter out notifications, convert rest
convertedMessages = messages
.filter((m) => (m as { role: string }).role !== "notification")
.filter(
(m) =>
m.role === "user" ||
m.role === "assistant" ||
m.role === "toolResult",
) as Message[];
return convertedMessages;
},
};
const streamFn = () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
const message = createAssistantMessage([
{ type: "text", text: "Response" },
]);
stream.push({ type: "done", reason: "stop", message });
});
return stream;
};
const events: AgentEvent[] = [];
const stream = agentLoop(
[userPrompt],
context,
config,
undefined,
streamFn,
);
for await (const event of stream) {
events.push(event);
}
// The notification should have been filtered out in convertToLlm
expect(convertedMessages.length).toBe(1); // Only user message
expect(convertedMessages[0].role).toBe("user");
});
it("should apply transformContext before convertToLlm", async () => {
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [
createUserMessage("old message 1"),
createAssistantMessage([{ type: "text", text: "old response 1" }]),
createUserMessage("old message 2"),
createAssistantMessage([{ type: "text", text: "old response 2" }]),
],
tools: [],
};
const userPrompt: AgentMessage = createUserMessage("new message");
let transformedMessages: AgentMessage[] = [];
let convertedMessages: Message[] = [];
const config: AgentLoopConfig = {
model: createModel(),
transformContext: async (messages) => {
// Keep only last 2 messages (prune old ones)
transformedMessages = messages.slice(-2);
return transformedMessages;
},
convertToLlm: (messages) => {
convertedMessages = messages.filter(
(m) =>
m.role === "user" ||
m.role === "assistant" ||
m.role === "toolResult",
) as Message[];
return convertedMessages;
},
};
const streamFn = () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
const message = createAssistantMessage([
{ type: "text", text: "Response" },
]);
stream.push({ type: "done", reason: "stop", message });
});
return stream;
};
const stream = agentLoop(
[userPrompt],
context,
config,
undefined,
streamFn,
);
for await (const _ of stream) {
// consume
}
// transformContext should have been called first, keeping only last 2
expect(transformedMessages.length).toBe(2);
// Then convertToLlm receives the pruned messages
expect(convertedMessages.length).toBe(2);
});
it("should handle tool calls and results", async () => {
const toolSchema = Type.Object({ value: Type.String() });
const executed: string[] = [];
const tool: AgentTool<typeof toolSchema, { value: string }> = {
name: "echo",
label: "Echo",
description: "Echo tool",
parameters: toolSchema,
async execute(_toolCallId, params) {
executed.push(params.value);
return {
content: [{ type: "text", text: `echoed: ${params.value}` }],
details: { value: params.value },
};
},
};
const context: AgentContext = {
systemPrompt: "",
messages: [],
tools: [tool],
};
const userPrompt: AgentMessage = createUserMessage("echo something");
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: identityConverter,
};
let callIndex = 0;
const streamFn = () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
if (callIndex === 0) {
// First call: return tool call
const message = createAssistantMessage(
[
{
type: "toolCall",
id: "tool-1",
name: "echo",
arguments: { value: "hello" },
},
],
"toolUse",
);
stream.push({ type: "done", reason: "toolUse", message });
} else {
// Second call: return final response
const message = createAssistantMessage([
{ type: "text", text: "done" },
]);
stream.push({ type: "done", reason: "stop", message });
}
callIndex++;
});
return stream;
};
const events: AgentEvent[] = [];
const stream = agentLoop(
[userPrompt],
context,
config,
undefined,
streamFn,
);
for await (const event of stream) {
events.push(event);
}
// Tool should have been executed
expect(executed).toEqual(["hello"]);
// Should have tool execution events
const toolStart = events.find((e) => e.type === "tool_execution_start");
const toolEnd = events.find((e) => e.type === "tool_execution_end");
expect(toolStart).toBeDefined();
expect(toolEnd).toBeDefined();
if (toolEnd?.type === "tool_execution_end") {
expect(toolEnd.isError).toBe(false);
}
});
it("should inject queued messages and skip remaining tool calls", async () => {
const toolSchema = Type.Object({ value: Type.String() });
const executed: string[] = [];
const tool: AgentTool<typeof toolSchema, { value: string }> = {
name: "echo",
label: "Echo",
description: "Echo tool",
parameters: toolSchema,
async execute(_toolCallId, params) {
executed.push(params.value);
return {
content: [{ type: "text", text: `ok:${params.value}` }],
details: { value: params.value },
};
},
};
const context: AgentContext = {
systemPrompt: "",
messages: [],
tools: [tool],
};
const userPrompt: AgentMessage = createUserMessage("start");
const queuedUserMessage: AgentMessage = createUserMessage("interrupt");
let queuedDelivered = false;
let callIndex = 0;
let sawInterruptInContext = false;
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: identityConverter,
getSteeringMessages: async () => {
// Return steering message after first tool executes
if (executed.length === 1 && !queuedDelivered) {
queuedDelivered = true;
return [queuedUserMessage];
}
return [];
},
};
const events: AgentEvent[] = [];
const stream = agentLoop(
[userPrompt],
context,
config,
undefined,
(_model, ctx, _options) => {
// Check if interrupt message is in context on second call
if (callIndex === 1) {
sawInterruptInContext = ctx.messages.some(
(m) =>
m.role === "user" &&
typeof m.content === "string" &&
m.content === "interrupt",
);
}
const mockStream = new MockAssistantStream();
queueMicrotask(() => {
if (callIndex === 0) {
// First call: return two tool calls
const message = createAssistantMessage(
[
{
type: "toolCall",
id: "tool-1",
name: "echo",
arguments: { value: "first" },
},
{
type: "toolCall",
id: "tool-2",
name: "echo",
arguments: { value: "second" },
},
],
"toolUse",
);
mockStream.push({ type: "done", reason: "toolUse", message });
} else {
// Second call: return final response
const message = createAssistantMessage([
{ type: "text", text: "done" },
]);
mockStream.push({ type: "done", reason: "stop", message });
}
callIndex++;
});
return mockStream;
},
);
for await (const event of stream) {
events.push(event);
}
// Only first tool should have executed
expect(executed).toEqual(["first"]);
// Second tool should be skipped
const toolEnds = events.filter(
(e): e is Extract<AgentEvent, { type: "tool_execution_end" }> =>
e.type === "tool_execution_end",
);
expect(toolEnds.length).toBe(2);
expect(toolEnds[0].isError).toBe(false);
expect(toolEnds[1].isError).toBe(true);
if (toolEnds[1].result.content[0]?.type === "text") {
expect(toolEnds[1].result.content[0].text).toContain(
"Skipped due to queued user message",
);
}
// Queued message should appear in events
const queuedMessageEvent = events.find(
(e) =>
e.type === "message_start" &&
e.message.role === "user" &&
typeof e.message.content === "string" &&
e.message.content === "interrupt",
);
expect(queuedMessageEvent).toBeDefined();
// Interrupt message should be in context when second LLM call is made
expect(sawInterruptInContext).toBe(true);
});
});
describe("agentLoopContinue with AgentMessage", () => {
it("should throw when context has no messages", () => {
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [],
tools: [],
};
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: identityConverter,
};
expect(() => agentLoopContinue(context, config)).toThrow(
"Cannot continue: no messages in context",
);
});
it("should continue from existing context without emitting user message events", async () => {
const userMessage: AgentMessage = createUserMessage("Hello");
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [userMessage],
tools: [],
};
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: identityConverter,
};
const streamFn = () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
const message = createAssistantMessage([
{ type: "text", text: "Response" },
]);
stream.push({ type: "done", reason: "stop", message });
});
return stream;
};
const events: AgentEvent[] = [];
const stream = agentLoopContinue(context, config, undefined, streamFn);
for await (const event of stream) {
events.push(event);
}
const messages = await stream.result();
// Should only return the new assistant message (not the existing user message)
expect(messages.length).toBe(1);
expect(messages[0].role).toBe("assistant");
// Should NOT have user message events (that's the key difference from agentLoop)
const messageEndEvents = events.filter((e) => e.type === "message_end");
expect(messageEndEvents.length).toBe(1);
expect((messageEndEvents[0] as any).message.role).toBe("assistant");
});
it("should allow custom message types as last message (caller responsibility)", async () => {
// Custom message that will be converted to user message by convertToLlm
interface CustomMessage {
role: "custom";
text: string;
timestamp: number;
}
const customMessage: CustomMessage = {
role: "custom",
text: "Hook content",
timestamp: Date.now(),
};
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [customMessage as unknown as AgentMessage],
tools: [],
};
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: (messages) => {
// Convert custom to user message
return messages
.map((m) => {
if ((m as any).role === "custom") {
return {
role: "user" as const,
content: (m as any).text,
timestamp: m.timestamp,
};
}
return m;
})
.filter(
(m) =>
m.role === "user" ||
m.role === "assistant" ||
m.role === "toolResult",
) as Message[];
},
};
const streamFn = () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
const message = createAssistantMessage([
{ type: "text", text: "Response to custom message" },
]);
stream.push({ type: "done", reason: "stop", message });
});
return stream;
};
// Should not throw - the custom message will be converted to user message
const stream = agentLoopContinue(context, config, undefined, streamFn);
const events: AgentEvent[] = [];
for await (const event of stream) {
events.push(event);
}
const messages = await stream.result();
expect(messages.length).toBe(1);
expect(messages[0].role).toBe("assistant");
});
});

View file

@ -0,0 +1,383 @@
import {
type AssistantMessage,
type AssistantMessageEvent,
EventStream,
getModel,
} from "@mariozechner/pi-ai";
import { describe, expect, it } from "vitest";
import { Agent } from "../src/index.js";
// Mock stream that mimics AssistantMessageEventStream
class MockAssistantStream extends EventStream<
AssistantMessageEvent,
AssistantMessage
> {
constructor() {
super(
(event) => event.type === "done" || event.type === "error",
(event) => {
if (event.type === "done") return event.message;
if (event.type === "error") return event.error;
throw new Error("Unexpected event type");
},
);
}
}
function createAssistantMessage(text: string): AssistantMessage {
return {
role: "assistant",
content: [{ type: "text", text }],
api: "openai-responses",
provider: "openai",
model: "mock",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
}
describe("Agent", () => {
it("should create an agent instance with default state", () => {
const agent = new Agent();
expect(agent.state).toBeDefined();
expect(agent.state.systemPrompt).toBe("");
expect(agent.state.model).toBeDefined();
expect(agent.state.thinkingLevel).toBe("off");
expect(agent.state.tools).toEqual([]);
expect(agent.state.messages).toEqual([]);
expect(agent.state.isStreaming).toBe(false);
expect(agent.state.streamMessage).toBe(null);
expect(agent.state.pendingToolCalls).toEqual(new Set());
expect(agent.state.error).toBeUndefined();
});
it("should create an agent instance with custom initial state", () => {
const customModel = getModel("openai", "gpt-4o-mini");
const agent = new Agent({
initialState: {
systemPrompt: "You are a helpful assistant.",
model: customModel,
thinkingLevel: "low",
},
});
expect(agent.state.systemPrompt).toBe("You are a helpful assistant.");
expect(agent.state.model).toBe(customModel);
expect(agent.state.thinkingLevel).toBe("low");
});
it("should subscribe to events", () => {
const agent = new Agent();
let eventCount = 0;
const unsubscribe = agent.subscribe((_event) => {
eventCount++;
});
// No initial event on subscribe
expect(eventCount).toBe(0);
// State mutators don't emit events
agent.setSystemPrompt("Test prompt");
expect(eventCount).toBe(0);
expect(agent.state.systemPrompt).toBe("Test prompt");
// Unsubscribe should work
unsubscribe();
agent.setSystemPrompt("Another prompt");
expect(eventCount).toBe(0); // Should not increase
});
it("should update state with mutators", () => {
const agent = new Agent();
// Test setSystemPrompt
agent.setSystemPrompt("Custom prompt");
expect(agent.state.systemPrompt).toBe("Custom prompt");
// Test setModel
const newModel = getModel("google", "gemini-2.5-flash");
agent.setModel(newModel);
expect(agent.state.model).toBe(newModel);
// Test setThinkingLevel
agent.setThinkingLevel("high");
expect(agent.state.thinkingLevel).toBe("high");
// Test setTools
const tools = [{ name: "test", description: "test tool" } as any];
agent.setTools(tools);
expect(agent.state.tools).toBe(tools);
// Test replaceMessages
const messages = [
{ role: "user" as const, content: "Hello", timestamp: Date.now() },
];
agent.replaceMessages(messages);
expect(agent.state.messages).toEqual(messages);
expect(agent.state.messages).not.toBe(messages); // Should be a copy
// Test appendMessage
const newMessage = {
role: "assistant" as const,
content: [{ type: "text" as const, text: "Hi" }],
};
agent.appendMessage(newMessage as any);
expect(agent.state.messages).toHaveLength(2);
expect(agent.state.messages[1]).toBe(newMessage);
// Test clearMessages
agent.clearMessages();
expect(agent.state.messages).toEqual([]);
});
it("should support steering message queue", async () => {
const agent = new Agent();
const message = {
role: "user" as const,
content: "Steering message",
timestamp: Date.now(),
};
agent.steer(message);
// The message is queued but not yet in state.messages
expect(agent.state.messages).not.toContainEqual(message);
});
it("should support follow-up message queue", async () => {
const agent = new Agent();
const message = {
role: "user" as const,
content: "Follow-up message",
timestamp: Date.now(),
};
agent.followUp(message);
// The message is queued but not yet in state.messages
expect(agent.state.messages).not.toContainEqual(message);
});
it("should handle abort controller", () => {
const agent = new Agent();
// Should not throw even if nothing is running
expect(() => agent.abort()).not.toThrow();
});
it("should throw when prompt() called while streaming", async () => {
let abortSignal: AbortSignal | undefined;
const agent = new Agent({
// Use a stream function that responds to abort
streamFn: (_model, _context, options) => {
abortSignal = options?.signal;
const stream = new MockAssistantStream();
queueMicrotask(() => {
stream.push({ type: "start", partial: createAssistantMessage("") });
// Check abort signal periodically
const checkAbort = () => {
if (abortSignal?.aborted) {
stream.push({
type: "error",
reason: "aborted",
error: createAssistantMessage("Aborted"),
});
} else {
setTimeout(checkAbort, 5);
}
};
checkAbort();
});
return stream;
},
});
// Start first prompt (don't await, it will block until abort)
const firstPrompt = agent.prompt("First message");
// Wait a tick for isStreaming to be set
await new Promise((resolve) => setTimeout(resolve, 10));
expect(agent.state.isStreaming).toBe(true);
// Second prompt should reject
await expect(agent.prompt("Second message")).rejects.toThrow(
"Agent is already processing a prompt. Use steer() or followUp() to queue messages, or wait for completion.",
);
// Cleanup - abort to stop the stream
agent.abort();
await firstPrompt.catch(() => {}); // Ignore abort error
});
it("should throw when continue() called while streaming", async () => {
let abortSignal: AbortSignal | undefined;
const agent = new Agent({
streamFn: (_model, _context, options) => {
abortSignal = options?.signal;
const stream = new MockAssistantStream();
queueMicrotask(() => {
stream.push({ type: "start", partial: createAssistantMessage("") });
const checkAbort = () => {
if (abortSignal?.aborted) {
stream.push({
type: "error",
reason: "aborted",
error: createAssistantMessage("Aborted"),
});
} else {
setTimeout(checkAbort, 5);
}
};
checkAbort();
});
return stream;
},
});
// Start first prompt
const firstPrompt = agent.prompt("First message");
await new Promise((resolve) => setTimeout(resolve, 10));
expect(agent.state.isStreaming).toBe(true);
// continue() should reject
await expect(agent.continue()).rejects.toThrow(
"Agent is already processing. Wait for completion before continuing.",
);
// Cleanup
agent.abort();
await firstPrompt.catch(() => {});
});
it("continue() should process queued follow-up messages after an assistant turn", async () => {
const agent = new Agent({
streamFn: () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
stream.push({
type: "done",
reason: "stop",
message: createAssistantMessage("Processed"),
});
});
return stream;
},
});
agent.replaceMessages([
{
role: "user",
content: [{ type: "text", text: "Initial" }],
timestamp: Date.now() - 10,
},
createAssistantMessage("Initial response"),
]);
agent.followUp({
role: "user",
content: [{ type: "text", text: "Queued follow-up" }],
timestamp: Date.now(),
});
await expect(agent.continue()).resolves.toBeUndefined();
const hasQueuedFollowUp = agent.state.messages.some((message) => {
if (message.role !== "user") return false;
if (typeof message.content === "string")
return message.content === "Queued follow-up";
return message.content.some(
(part) => part.type === "text" && part.text === "Queued follow-up",
);
});
expect(hasQueuedFollowUp).toBe(true);
expect(agent.state.messages[agent.state.messages.length - 1].role).toBe(
"assistant",
);
});
it("continue() should keep one-at-a-time steering semantics from assistant tail", async () => {
let responseCount = 0;
const agent = new Agent({
streamFn: () => {
const stream = new MockAssistantStream();
responseCount++;
queueMicrotask(() => {
stream.push({
type: "done",
reason: "stop",
message: createAssistantMessage(`Processed ${responseCount}`),
});
});
return stream;
},
});
agent.replaceMessages([
{
role: "user",
content: [{ type: "text", text: "Initial" }],
timestamp: Date.now() - 10,
},
createAssistantMessage("Initial response"),
]);
agent.steer({
role: "user",
content: [{ type: "text", text: "Steering 1" }],
timestamp: Date.now(),
});
agent.steer({
role: "user",
content: [{ type: "text", text: "Steering 2" }],
timestamp: Date.now() + 1,
});
await expect(agent.continue()).resolves.toBeUndefined();
const recentMessages = agent.state.messages.slice(-4);
expect(recentMessages.map((m) => m.role)).toEqual([
"user",
"assistant",
"user",
"assistant",
]);
expect(responseCount).toBe(2);
});
it("forwards sessionId to streamFn options", async () => {
let receivedSessionId: string | undefined;
const agent = new Agent({
sessionId: "session-abc",
streamFn: (_model, _context, options) => {
receivedSessionId = options?.sessionId;
const stream = new MockAssistantStream();
queueMicrotask(() => {
const message = createAssistantMessage("ok");
stream.push({ type: "done", reason: "stop", message });
});
return stream;
},
});
await agent.prompt("hello");
expect(receivedSessionId).toBe("session-abc");
// Test setter
agent.sessionId = "session-def";
expect(agent.sessionId).toBe("session-def");
await agent.prompt("hello again");
expect(receivedSessionId).toBe("session-def");
});
});

View file

@ -0,0 +1,316 @@
/**
* A test suite to ensure Amazon Bedrock models work correctly with the agent loop.
*
* Some Bedrock models don't support all features (e.g., reasoning signatures).
* This test suite verifies that the agent loop works with various Bedrock models.
*
* This test suite is not enabled by default unless AWS credentials and
* `BEDROCK_EXTENSIVE_MODEL_TEST` environment variables are set.
*
* You can run this test suite with:
* ```bash
* $ AWS_REGION=us-east-1 BEDROCK_EXTENSIVE_MODEL_TEST=1 AWS_PROFILE=pi npm test -- ./test/bedrock-models.test.ts
* ```
*
* ## Known Issues by Category
*
* 1. **Inference Profile Required**: Some models require an inference profile ARN instead of on-demand.
* 2. **Invalid Model ID**: Model identifiers that don't exist in the current region.
* 3. **Max Tokens Exceeded**: Model's maxTokens in our config exceeds the actual limit.
* 4. **No Reasoning in User Messages**: Model rejects reasoning content when replayed in conversation.
* 5. **Invalid Signature Format**: Model validates signature format (Anthropic newer models).
*/
import type { AssistantMessage } from "@mariozechner/pi-ai";
import { getModels } from "@mariozechner/pi-ai";
import { describe, expect, it } from "vitest";
import { Agent } from "../src/index.js";
import { hasBedrockCredentials } from "./bedrock-utils.js";
// =============================================================================
// Known Issue Categories
// =============================================================================
/** Models that require inference profile ARN (not available on-demand in us-east-1) */
const REQUIRES_INFERENCE_PROFILE = new Set([
"anthropic.claude-3-5-haiku-20241022-v1:0",
"anthropic.claude-3-5-sonnet-20241022-v2:0",
"anthropic.claude-3-opus-20240229-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
]);
/** Models with invalid identifiers (not available in us-east-1 or don't exist) */
const INVALID_MODEL_ID = new Set([
"deepseek.v3-v1:0",
"eu.anthropic.claude-haiku-4-5-20251001-v1:0",
"eu.anthropic.claude-opus-4-5-20251101-v1:0",
"eu.anthropic.claude-sonnet-4-5-20250929-v1:0",
"qwen.qwen3-235b-a22b-2507-v1:0",
"qwen.qwen3-coder-480b-a35b-v1:0",
]);
/** Models where our maxTokens config exceeds the model's actual limit */
const MAX_TOKENS_EXCEEDED = new Set([
"us.meta.llama4-maverick-17b-instruct-v1:0",
"us.meta.llama4-scout-17b-instruct-v1:0",
]);
/**
* Models that reject reasoning content in user messages (when replaying conversation).
* These work for multi-turn but fail when synthetic thinking is injected.
*/
const NO_REASONING_IN_USER_MESSAGES = new Set([
// Mistral models
"mistral.ministral-3-14b-instruct",
"mistral.ministral-3-8b-instruct",
"mistral.mistral-large-2402-v1:0",
"mistral.voxtral-mini-3b-2507",
"mistral.voxtral-small-24b-2507",
// Nvidia models
"nvidia.nemotron-nano-12b-v2",
"nvidia.nemotron-nano-9b-v2",
// Qwen models
"qwen.qwen3-coder-30b-a3b-v1:0",
// Amazon Nova models
"us.amazon.nova-lite-v1:0",
"us.amazon.nova-micro-v1:0",
"us.amazon.nova-premier-v1:0",
"us.amazon.nova-pro-v1:0",
// Meta Llama models
"us.meta.llama3-2-11b-instruct-v1:0",
"us.meta.llama3-2-1b-instruct-v1:0",
"us.meta.llama3-2-3b-instruct-v1:0",
"us.meta.llama3-2-90b-instruct-v1:0",
"us.meta.llama3-3-70b-instruct-v1:0",
// DeepSeek
"us.deepseek.r1-v1:0",
// Older Anthropic models
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
// Cohere models
"cohere.command-r-plus-v1:0",
"cohere.command-r-v1:0",
// Google models
"google.gemma-3-27b-it",
"google.gemma-3-4b-it",
// Non-Anthropic models that don't support signatures (now handled by omitting signature)
// but still reject reasoning content in user messages
"global.amazon.nova-2-lite-v1:0",
"minimax.minimax-m2",
"moonshot.kimi-k2-thinking",
"openai.gpt-oss-120b-1:0",
"openai.gpt-oss-20b-1:0",
"openai.gpt-oss-safeguard-120b",
"openai.gpt-oss-safeguard-20b",
"qwen.qwen3-32b-v1:0",
"qwen.qwen3-next-80b-a3b",
"qwen.qwen3-vl-235b-a22b",
]);
/**
* Models that validate signature format (Anthropic newer models).
* These work for multi-turn but fail when synthetic/invalid signature is injected.
*/
const VALIDATES_SIGNATURE_FORMAT = new Set([
"global.anthropic.claude-haiku-4-5-20251001-v1:0",
"global.anthropic.claude-opus-4-5-20251101-v1:0",
"global.anthropic.claude-sonnet-4-20250514-v1:0",
"global.anthropic.claude-sonnet-4-5-20250929-v1:0",
"us.anthropic.claude-3-7-sonnet-20250219-v1:0",
"us.anthropic.claude-opus-4-1-20250805-v1:0",
"us.anthropic.claude-opus-4-20250514-v1:0",
]);
/**
* DeepSeek R1 fails multi-turn because it rejects reasoning in the replayed assistant message.
*/
const REJECTS_REASONING_ON_REPLAY = new Set(["us.deepseek.r1-v1:0"]);
// =============================================================================
// Helper Functions
// =============================================================================
function isModelUnavailable(modelId: string): boolean {
return (
REQUIRES_INFERENCE_PROFILE.has(modelId) ||
INVALID_MODEL_ID.has(modelId) ||
MAX_TOKENS_EXCEEDED.has(modelId)
);
}
function failsMultiTurnWithThinking(modelId: string): boolean {
return REJECTS_REASONING_ON_REPLAY.has(modelId);
}
function failsSyntheticSignature(modelId: string): boolean {
return (
NO_REASONING_IN_USER_MESSAGES.has(modelId) ||
VALIDATES_SIGNATURE_FORMAT.has(modelId)
);
}
// =============================================================================
// Tests
// =============================================================================
describe("Amazon Bedrock Models - Agent Loop", () => {
const shouldRunExtensiveTests =
hasBedrockCredentials() && process.env.BEDROCK_EXTENSIVE_MODEL_TEST;
// Get all Amazon Bedrock models
const allBedrockModels = getModels("amazon-bedrock");
if (shouldRunExtensiveTests) {
for (const model of allBedrockModels) {
const modelId = model.id;
describe(`Model: ${modelId}`, () => {
// Skip entirely unavailable models
const unavailable = isModelUnavailable(modelId);
it.skipIf(unavailable)(
"should handle basic text prompt",
{ timeout: 60_000 },
async () => {
const agent = new Agent({
initialState: {
systemPrompt:
"You are a helpful assistant. Be extremely concise.",
model,
thinkingLevel: "off",
tools: [],
},
});
await agent.prompt("Reply with exactly: 'OK'");
if (agent.state.error) {
throw new Error(`Basic prompt error: ${agent.state.error}`);
}
expect(agent.state.isStreaming).toBe(false);
expect(agent.state.messages.length).toBe(2);
const assistantMessage = agent.state.messages[1];
if (assistantMessage.role !== "assistant")
throw new Error("Expected assistant message");
console.log(`${modelId}: OK`);
},
);
// Skip if model is unavailable or known to fail multi-turn with thinking
const skipMultiTurn =
unavailable || failsMultiTurnWithThinking(modelId);
it.skipIf(skipMultiTurn)(
"should handle multi-turn conversation with thinking content in history",
{ timeout: 120_000 },
async () => {
const agent = new Agent({
initialState: {
systemPrompt:
"You are a helpful assistant. Be extremely concise.",
model,
thinkingLevel: "medium",
tools: [],
},
});
// First turn
await agent.prompt("My name is Alice.");
if (agent.state.error) {
throw new Error(`First turn error: ${agent.state.error}`);
}
// Second turn - this should replay the first assistant message which may contain thinking
await agent.prompt("What is my name?");
if (agent.state.error) {
throw new Error(`Second turn error: ${agent.state.error}`);
}
expect(agent.state.messages.length).toBe(4);
console.log(`${modelId}: multi-turn OK`);
},
);
// Skip if model is unavailable or known to fail synthetic signature
const skipSynthetic = unavailable || failsSyntheticSignature(modelId);
it.skipIf(skipSynthetic)(
"should handle conversation with synthetic thinking signature in history",
{ timeout: 60_000 },
async () => {
const agent = new Agent({
initialState: {
systemPrompt:
"You are a helpful assistant. Be extremely concise.",
model,
thinkingLevel: "off",
tools: [],
},
});
// Inject a message with a thinking block that has a signature
const syntheticAssistantMessage: AssistantMessage = {
role: "assistant",
content: [
{
type: "thinking",
thinking: "I need to remember the user's name.",
thinkingSignature: "synthetic-signature-123",
},
{ type: "text", text: "Nice to meet you, Alice!" },
],
api: "bedrock-converse-stream",
provider: "amazon-bedrock",
model: modelId,
usage: {
input: 10,
output: 20,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 30,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
total: 0,
},
},
stopReason: "stop",
timestamp: Date.now(),
};
agent.replaceMessages([
{
role: "user",
content: "My name is Alice.",
timestamp: Date.now(),
},
syntheticAssistantMessage,
]);
await agent.prompt("What is my name?");
if (agent.state.error) {
throw new Error(
`Synthetic signature error: ${agent.state.error}`,
);
}
expect(agent.state.messages.length).toBe(4);
console.log(`${modelId}: synthetic signature OK`);
},
);
});
}
} else {
it.skip("skipped - set AWS credentials and BEDROCK_EXTENSIVE_MODEL_TEST=1 to run", () => {});
}
});

View file

@ -0,0 +1,18 @@
/**
* Utility functions for Amazon Bedrock tests
*/
/**
* Check if any valid AWS credentials are configured for Bedrock.
* Returns true if any of the following are set:
* - AWS_PROFILE (named profile from ~/.aws/credentials)
* - AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY (IAM keys)
* - AWS_BEARER_TOKEN_BEDROCK (Bedrock API key)
*/
export function hasBedrockCredentials(): boolean {
return !!(
process.env.AWS_PROFILE ||
(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
process.env.AWS_BEARER_TOKEN_BEDROCK
);
}

View file

@ -0,0 +1,571 @@
import type {
AssistantMessage,
Model,
ToolResultMessage,
UserMessage,
} from "@mariozechner/pi-ai";
import { getModel } from "@mariozechner/pi-ai";
import { describe, expect, it } from "vitest";
import { Agent } from "../src/index.js";
import { hasBedrockCredentials } from "./bedrock-utils.js";
import { calculateTool } from "./utils/calculate.js";
async function basicPrompt(model: Model<any>) {
const agent = new Agent({
initialState: {
systemPrompt: "You are a helpful assistant. Keep your responses concise.",
model,
thinkingLevel: "off",
tools: [],
},
});
await agent.prompt("What is 2+2? Answer with just the number.");
expect(agent.state.isStreaming).toBe(false);
expect(agent.state.messages.length).toBe(2);
expect(agent.state.messages[0].role).toBe("user");
expect(agent.state.messages[1].role).toBe("assistant");
const assistantMessage = agent.state.messages[1];
if (assistantMessage.role !== "assistant")
throw new Error("Expected assistant message");
expect(assistantMessage.content.length).toBeGreaterThan(0);
const textContent = assistantMessage.content.find((c) => c.type === "text");
expect(textContent).toBeDefined();
if (textContent?.type !== "text") throw new Error("Expected text content");
expect(textContent.text).toContain("4");
}
async function toolExecution(model: Model<any>) {
const agent = new Agent({
initialState: {
systemPrompt:
"You are a helpful assistant. Always use the calculator tool for math.",
model,
thinkingLevel: "off",
tools: [calculateTool],
},
});
await agent.prompt("Calculate 123 * 456 using the calculator tool.");
expect(agent.state.isStreaming).toBe(false);
expect(agent.state.messages.length).toBeGreaterThanOrEqual(3);
const toolResultMsg = agent.state.messages.find(
(m) => m.role === "toolResult",
);
expect(toolResultMsg).toBeDefined();
if (toolResultMsg?.role !== "toolResult")
throw new Error("Expected tool result message");
const textContent =
toolResultMsg.content
?.filter((c) => c.type === "text")
.map((c: any) => c.text)
.join("\n") || "";
expect(textContent).toBeDefined();
const expectedResult = 123 * 456;
expect(textContent).toContain(String(expectedResult));
const finalMessage = agent.state.messages[agent.state.messages.length - 1];
if (finalMessage.role !== "assistant")
throw new Error("Expected final assistant message");
const finalText = finalMessage.content.find((c) => c.type === "text");
expect(finalText).toBeDefined();
if (finalText?.type !== "text") throw new Error("Expected text content");
// Check for number with or without comma formatting
const hasNumber =
finalText.text.includes(String(expectedResult)) ||
finalText.text.includes("56,088") ||
finalText.text.includes("56088");
expect(hasNumber).toBe(true);
}
async function abortExecution(model: Model<any>) {
const agent = new Agent({
initialState: {
systemPrompt: "You are a helpful assistant.",
model,
thinkingLevel: "off",
tools: [calculateTool],
},
});
const promptPromise = agent.prompt(
"Calculate 100 * 200, then 300 * 400, then sum the results.",
);
setTimeout(() => {
agent.abort();
}, 100);
await promptPromise;
expect(agent.state.isStreaming).toBe(false);
expect(agent.state.messages.length).toBeGreaterThanOrEqual(2);
const lastMessage = agent.state.messages[agent.state.messages.length - 1];
if (lastMessage.role !== "assistant")
throw new Error("Expected assistant message");
expect(lastMessage.stopReason).toBe("aborted");
expect(lastMessage.errorMessage).toBeDefined();
expect(agent.state.error).toBeDefined();
expect(agent.state.error).toBe(lastMessage.errorMessage);
}
async function stateUpdates(model: Model<any>) {
const agent = new Agent({
initialState: {
systemPrompt: "You are a helpful assistant.",
model,
thinkingLevel: "off",
tools: [],
},
});
const events: Array<string> = [];
agent.subscribe((event) => {
events.push(event.type);
});
await agent.prompt("Count from 1 to 5.");
// Should have received lifecycle events
expect(events).toContain("agent_start");
expect(events).toContain("agent_end");
expect(events).toContain("message_start");
expect(events).toContain("message_end");
// May have message_update events during streaming
const hasMessageUpdates = events.some((e) => e === "message_update");
expect(hasMessageUpdates).toBe(true);
// Check final state
expect(agent.state.isStreaming).toBe(false);
expect(agent.state.messages.length).toBe(2); // User message + assistant response
}
async function multiTurnConversation(model: Model<any>) {
const agent = new Agent({
initialState: {
systemPrompt: "You are a helpful assistant.",
model,
thinkingLevel: "off",
tools: [],
},
});
await agent.prompt("My name is Alice.");
expect(agent.state.messages.length).toBe(2);
await agent.prompt("What is my name?");
expect(agent.state.messages.length).toBe(4);
const lastMessage = agent.state.messages[3];
if (lastMessage.role !== "assistant")
throw new Error("Expected assistant message");
const lastText = lastMessage.content.find((c) => c.type === "text");
if (lastText?.type !== "text") throw new Error("Expected text content");
expect(lastText.text.toLowerCase()).toContain("alice");
}
describe("Agent E2E Tests", () => {
describe.skipIf(!process.env.GEMINI_API_KEY)(
"Google Provider (gemini-2.5-flash)",
() => {
const model = getModel("google", "gemini-2.5-flash");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
},
);
describe.skipIf(!process.env.OPENAI_API_KEY)(
"OpenAI Provider (gpt-4o-mini)",
() => {
const model = getModel("openai", "gpt-4o-mini");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
},
);
describe.skipIf(!process.env.ANTHROPIC_API_KEY)(
"Anthropic Provider (claude-haiku-4-5)",
() => {
const model = getModel("anthropic", "claude-haiku-4-5");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
},
);
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-3)", () => {
const model = getModel("xai", "grok-3");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
});
describe.skipIf(!process.env.GROQ_API_KEY)(
"Groq Provider (openai/gpt-oss-20b)",
() => {
const model = getModel("groq", "openai/gpt-oss-20b");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
},
);
describe.skipIf(!process.env.CEREBRAS_API_KEY)(
"Cerebras Provider (gpt-oss-120b)",
() => {
const model = getModel("cerebras", "gpt-oss-120b");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
},
);
describe.skipIf(!process.env.ZAI_API_KEY)(
"zAI Provider (glm-4.5-air)",
() => {
const model = getModel("zai", "glm-4.5-air");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
},
);
describe.skipIf(!hasBedrockCredentials())(
"Amazon Bedrock Provider (claude-sonnet-4-5)",
() => {
const model = getModel(
"amazon-bedrock",
"global.anthropic.claude-sonnet-4-5-20250929-v1:0",
);
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
},
);
});
describe("Agent.continue()", () => {
describe("validation", () => {
it("should throw when no messages in context", async () => {
const agent = new Agent({
initialState: {
systemPrompt: "Test",
model: getModel("anthropic", "claude-haiku-4-5"),
},
});
await expect(agent.continue()).rejects.toThrow(
"No messages to continue from",
);
});
it("should throw when last message is assistant", async () => {
const agent = new Agent({
initialState: {
systemPrompt: "Test",
model: getModel("anthropic", "claude-haiku-4-5"),
},
});
const assistantMessage: AssistantMessage = {
role: "assistant",
content: [{ type: "text", text: "Hello" }],
api: "anthropic-messages",
provider: "anthropic",
model: "claude-haiku-4-5",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
agent.replaceMessages([assistantMessage]);
await expect(agent.continue()).rejects.toThrow(
"Cannot continue from message role: assistant",
);
});
});
describe.skipIf(!process.env.ANTHROPIC_API_KEY)(
"continue from user message",
() => {
const model = getModel("anthropic", "claude-haiku-4-5");
it("should continue and get response when last message is user", async () => {
const agent = new Agent({
initialState: {
systemPrompt:
"You are a helpful assistant. Follow instructions exactly.",
model,
thinkingLevel: "off",
tools: [],
},
});
// Manually add a user message without calling prompt()
const userMessage: UserMessage = {
role: "user",
content: [{ type: "text", text: "Say exactly: HELLO WORLD" }],
timestamp: Date.now(),
};
agent.replaceMessages([userMessage]);
// Continue from the user message
await agent.continue();
expect(agent.state.isStreaming).toBe(false);
expect(agent.state.messages.length).toBe(2);
expect(agent.state.messages[0].role).toBe("user");
expect(agent.state.messages[1].role).toBe("assistant");
const assistantMsg = agent.state.messages[1] as AssistantMessage;
const textContent = assistantMsg.content.find((c) => c.type === "text");
expect(textContent).toBeDefined();
if (textContent?.type === "text") {
expect(textContent.text.toUpperCase()).toContain("HELLO WORLD");
}
});
},
);
describe.skipIf(!process.env.ANTHROPIC_API_KEY)(
"continue from tool result",
() => {
const model = getModel("anthropic", "claude-haiku-4-5");
it("should continue and process tool results", async () => {
const agent = new Agent({
initialState: {
systemPrompt:
"You are a helpful assistant. After getting a calculation result, state the answer clearly.",
model,
thinkingLevel: "off",
tools: [calculateTool],
},
});
// Set up a conversation state as if tool was just executed
const userMessage: UserMessage = {
role: "user",
content: [{ type: "text", text: "What is 5 + 3?" }],
timestamp: Date.now(),
};
const assistantMessage: AssistantMessage = {
role: "assistant",
content: [
{ type: "text", text: "Let me calculate that." },
{
type: "toolCall",
id: "calc-1",
name: "calculate",
arguments: { expression: "5 + 3" },
},
],
api: "anthropic-messages",
provider: "anthropic",
model: "claude-haiku-4-5",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
total: 0,
},
},
stopReason: "toolUse",
timestamp: Date.now(),
};
const toolResult: ToolResultMessage = {
role: "toolResult",
toolCallId: "calc-1",
toolName: "calculate",
content: [{ type: "text", text: "5 + 3 = 8" }],
isError: false,
timestamp: Date.now(),
};
agent.replaceMessages([userMessage, assistantMessage, toolResult]);
// Continue from the tool result
await agent.continue();
expect(agent.state.isStreaming).toBe(false);
// Should have added an assistant response
expect(agent.state.messages.length).toBeGreaterThanOrEqual(4);
const lastMessage =
agent.state.messages[agent.state.messages.length - 1];
expect(lastMessage.role).toBe("assistant");
if (lastMessage.role === "assistant") {
const textContent = lastMessage.content
.filter((c) => c.type === "text")
.map((c) => (c as { type: "text"; text: string }).text)
.join(" ");
// Should mention 8 in the response
expect(textContent).toMatch(/8/);
}
});
},
);
});

View file

@ -0,0 +1,37 @@
import { type Static, Type } from "@sinclair/typebox";
import type { AgentTool, AgentToolResult } from "../../src/types.js";
export interface CalculateResult extends AgentToolResult<undefined> {
content: Array<{ type: "text"; text: string }>;
details: undefined;
}
export function calculate(expression: string): CalculateResult {
try {
const result = new Function(`return ${expression}`)();
return {
content: [{ type: "text", text: `${expression} = ${result}` }],
details: undefined,
};
} catch (e: any) {
throw new Error(e.message || String(e));
}
}
const calculateSchema = Type.Object({
expression: Type.String({
description: "The mathematical expression to evaluate",
}),
});
type CalculateParams = Static<typeof calculateSchema>;
export const calculateTool: AgentTool<typeof calculateSchema, undefined> = {
label: "Calculator",
name: "calculate",
description: "Evaluate mathematical expressions",
parameters: calculateSchema,
execute: async (_toolCallId: string, args: CalculateParams) => {
return calculate(args.expression);
},
};

View file

@ -0,0 +1,61 @@
import { type Static, Type } from "@sinclair/typebox";
import type { AgentTool, AgentToolResult } from "../../src/types.js";
export interface GetCurrentTimeResult extends AgentToolResult<{
utcTimestamp: number;
}> {}
export async function getCurrentTime(
timezone?: string,
): Promise<GetCurrentTimeResult> {
const date = new Date();
if (timezone) {
try {
const timeStr = date.toLocaleString("en-US", {
timeZone: timezone,
dateStyle: "full",
timeStyle: "long",
});
return {
content: [{ type: "text", text: timeStr }],
details: { utcTimestamp: date.getTime() },
};
} catch (_e) {
throw new Error(
`Invalid timezone: ${timezone}. Current UTC time: ${date.toISOString()}`,
);
}
}
const timeStr = date.toLocaleString("en-US", {
dateStyle: "full",
timeStyle: "long",
});
return {
content: [{ type: "text", text: timeStr }],
details: { utcTimestamp: date.getTime() },
};
}
const getCurrentTimeSchema = Type.Object({
timezone: Type.Optional(
Type.String({
description:
"Optional timezone (e.g., 'America/New_York', 'Europe/London')",
}),
),
});
type GetCurrentTimeParams = Static<typeof getCurrentTimeSchema>;
export const getCurrentTimeTool: AgentTool<
typeof getCurrentTimeSchema,
{ utcTimestamp: number }
> = {
label: "Current Time",
name: "get_current_time",
description: "Get the current date and time",
parameters: getCurrentTimeSchema,
execute: async (_toolCallId: string, args: GetCurrentTimeParams) => {
return getCurrentTime(args.timezone);
},
};

View file

@ -0,0 +1,9 @@
{
"extends": "../../tsconfig.base.json",
"compilerOptions": {
"outDir": "./dist",
"rootDir": "./src"
},
"include": ["src/**/*.ts"],
"exclude": ["node_modules", "dist", "**/*.d.ts", "src/**/*.d.ts"]
}

View file

@ -0,0 +1,9 @@
import { defineConfig } from "vitest/config";
export default defineConfig({
test: {
globals: true,
environment: "node",
testTimeout: 30000, // 30 seconds for API calls
},
});

787
packages/ai/CHANGELOG.md Normal file
View file

@ -0,0 +1,787 @@
# Changelog
## [Unreleased]
## [0.56.2] - 2026-03-05
### Added
- Added `gpt-5.4` model support for `openai`, `openai-codex`, `azure-openai-responses`, and `opencode` providers, with GPT-5.4 treated as xhigh-capable and capped to a 272000 context window in built-in metadata.
- Added `gpt-5.3-codex` fallback model availability for `github-copilot` until upstream model catalogs include it ([#1853](https://github.com/badlogic/pi-mono/issues/1853)).
### Fixed
- Preserved OpenAI Responses assistant `phase` metadata (`commentary`, `final_answer`) across turns by encoding `id` and `phase` in `textSignature` for session persistence and replay, with backward compatibility for legacy plain signatures ([#1819](https://github.com/badlogic/pi-mono/issues/1819)).
- Fixed OpenAI Responses replay to omit empty thinking blocks, avoiding invalid no-op reasoning items in follow-up turns.
- Switched the Mistral provider from the OpenAI-compatible completions path to Mistral's native SDK and conversations API, preserving native thinking blocks and Mistral-specific message semantics across turns ([#1716](https://github.com/badlogic/pi-mono/issues/1716)).
- Fixed Antigravity endpoint fallback: 403/404 responses now cascade to the next endpoint instead of throwing immediately, added `autopush-cloudcode-pa.sandbox` endpoint to the fallback list, and removed extra fingerprint headers (`X-Goog-Api-Client`, `Client-Metadata`) from Antigravity requests ([#1830](https://github.com/badlogic/pi-mono/issues/1830)).
- Fixed `@mariozechner/pi-ai/oauth` package exports to point directly at built `dist` files, avoiding broken TypeScript resolution through unpublished wrapper targets ([#1856](https://github.com/badlogic/pi-mono/issues/1856)).
- Fixed Gemini 3 unsigned tool call replay: use `skip_thought_signature_validator` sentinel instead of converting function calls to text, preserving structured tool call context across multi-turn conversations ([#1829](https://github.com/badlogic/pi-mono/issues/1829)).
## [0.56.1] - 2026-03-05
## [0.56.0] - 2026-03-04
### Breaking Changes
- Moved Node OAuth runtime exports off the top-level package entry. Import OAuth login/refresh functions from `@mariozechner/pi-ai/oauth` instead of `@mariozechner/pi-ai` ([#1814](https://github.com/badlogic/pi-mono/issues/1814))
### Added
- Added `gemini-3.1-flash-lite-preview` fallback model entry for the `google` provider so it remains selectable until upstream model catalogs include it ([#1785](https://github.com/badlogic/pi-mono/issues/1785), thanks [@n-WN](https://github.com/n-WN)).
- Added OpenCode Go provider support with `opencode-go` model catalog entries and `OPENCODE_API_KEY` environment variable support ([#1757](https://github.com/badlogic/pi-mono/issues/1757)).
### Changed
- Updated Antigravity Gemini 3.1 model metadata and request headers to match current upstream behavior.
### Fixed
- Fixed Gemini 3.1 thinking-level detection in `google` and `google-vertex` providers so `gemini-3.1-*` models use Gemini 3 level-based thinking config instead of budget fallback ([#1785](https://github.com/badlogic/pi-mono/issues/1785), thanks [@n-WN](https://github.com/n-WN)).
- Fixed browser bundling failures by lazy-loading the Bedrock provider and removing Node-only side effects from the default browser import graph ([#1814](https://github.com/badlogic/pi-mono/issues/1814)).
- Fixed `ERR_VM_DYNAMIC_IMPORT_CALLBACK_MISSING` failures by replacing `Function`-based dynamic imports with module dynamic imports in browser-safe provider loading paths ([#1814](https://github.com/badlogic/pi-mono/issues/1814)).
- Fixed Bedrock region resolution for `AWS_PROFILE` by honoring `region` from the selected profile when present ([#1800](https://github.com/badlogic/pi-mono/issues/1800)).
- Fixed Groq Qwen3 reasoning effort mapping by translating unsupported effort values to provider-supported values ([#1745](https://github.com/badlogic/pi-mono/issues/1745)).
## [0.55.4] - 2026-03-02
## [0.55.3] - 2026-02-27
## [0.55.2] - 2026-02-27
### Fixed
- Restored built-in OAuth providers when unregistering dynamically registered provider IDs and added `resetOAuthProviders()` for registry reset flows.
- Fixed Z.ai thinking control using wrong parameter name (`thinking` instead of `enable_thinking`), causing thinking to always be enabled and wasting tokens/latency ([#1674](https://github.com/badlogic/pi-mono/pull/1674) by [@okuyam2y](https://github.com/okuyam2y))
- Fixed `redacted_thinking` blocks being silently dropped during Anthropic streaming. They are now captured as `ThinkingContent` with `redacted: true`, passed back to the API in multi-turn conversations, and handled in cross-model message transformation ([#1665](https://github.com/badlogic/pi-mono/pull/1665) by [@tctev](https://github.com/tctev))
- Fixed `interleaved-thinking-2025-05-14` beta header being sent for adaptive thinking models (Opus 4.6, Sonnet 4.6) where the header is deprecated or redundant ([#1665](https://github.com/badlogic/pi-mono/pull/1665) by [@tctev](https://github.com/tctev))
- Fixed temperature being sent alongside extended thinking, which is incompatible with both adaptive and budget-based thinking modes ([#1665](https://github.com/badlogic/pi-mono/pull/1665) by [@tctev](https://github.com/tctev))
- Fixed `(external, cli)` user-agent flag causing 401 errors on Anthropic setup-token endpoint ([#1677](https://github.com/badlogic/pi-mono/pull/1677) by [@LazerLance777](https://github.com/LazerLance777))
- Fixed crash when OpenAI-compatible provider returns a chunk with no `choices` array by adding optional chaining ([#1671](https://github.com/badlogic/pi-mono/issues/1671))
## [0.55.1] - 2026-02-26
### Added
- Added `gemini-3.1-pro-preview` model support to the `google-gemini-cli` provider ([#1599](https://github.com/badlogic/pi-mono/pull/1599) by [@audichuang](https://github.com/audichuang))
### Fixed
- Fixed adaptive thinking for Claude Sonnet 4.6 in Anthropic and Bedrock providers, and clamped unsupported `xhigh` effort values to supported levels ([#1548](https://github.com/badlogic/pi-mono/pull/1548) by [@tctev](https://github.com/tctev))
- Fixed Vertex ADC credential detection race by avoiding caching a false negative during async import initialization ([#1550](https://github.com/badlogic/pi-mono/pull/1550) by [@jeremiahgaylord-web](https://github.com/jeremiahgaylord-web))
## [0.55.0] - 2026-02-24
## [0.54.2] - 2026-02-23
## [0.54.1] - 2026-02-22
## [0.54.0] - 2026-02-19
## [0.53.1] - 2026-02-19
## [0.53.0] - 2026-02-17
### Added
- Added Anthropic `claude-sonnet-4-6` fallback model entry to generated model definitions.
## [0.52.12] - 2026-02-13
### Added
- Added `transport` to `StreamOptions` with values `"sse"`, `"websocket"`, and `"auto"` (currently supported by `openai-codex-responses`).
- Added WebSocket transport support for OpenAI Codex Responses (`openai-codex-responses`).
### Changed
- OpenAI Codex Responses now defaults to SSE transport unless `transport` is explicitly set.
- OpenAI Codex Responses WebSocket connections are cached per `sessionId` and expire after 5 minutes of inactivity.
## [0.52.11] - 2026-02-13
### Added
- Added MiniMax M2.5 model entries for `minimax`, `minimax-cn`, `openrouter`, and `vercel-ai-gateway` providers, plus `minimax-m2.5-free` for `opencode`.
## [0.52.10] - 2026-02-12
### Added
- Added optional `metadata` field to `StreamOptions` for passing provider-specific metadata (e.g. Anthropic `user_id` for abuse tracking/rate limiting) ([#1384](https://github.com/badlogic/pi-mono/pull/1384) by [@7Sageer](https://github.com/7Sageer))
- Added `gpt-5.3-codex-spark` model definition for OpenAI and OpenAI Codex providers (128k context, text-only, research preview). Not yet functional, may become available in the next few hours or days.
### Changed
- Routed GitHub Copilot Claude 4.x models through Anthropic Messages API, centralized Copilot dynamic header handling, and added Copilot Claude Anthropic stream coverage ([#1353](https://github.com/badlogic/pi-mono/pull/1353) by [@NateSmyth](https://github.com/NateSmyth))
### Fixed
- Fixed OpenAI completions and responses streams to tolerate malformed trailing tool-call JSON without failing parsing ([#1424](https://github.com/badlogic/pi-mono/issues/1424))
## [0.52.9] - 2026-02-08
### Changed
- Updated the Antigravity system instruction to a more compact version for Google Gemini CLI compatibility
### Fixed
- Use `parametersJsonSchema` for Google provider tool declarations to support full JSON Schema (anyOf, oneOf, const, etc.) ([#1398](https://github.com/badlogic/pi-mono/issues/1398) by [@jarib](https://github.com/jarib))
- Reverted incorrect Antigravity model change: `claude-opus-4-6-thinking` back to `claude-opus-4-5-thinking` (model doesn't exist on Antigravity endpoint)
- Corrected opencode context windows for Claude Sonnet 4 and 4.5 ([#1383](https://github.com/badlogic/pi-mono/issues/1383))
## [0.52.8] - 2026-02-07
### Added
- Added OpenRouter `auto` model alias for automatic model routing ([#1361](https://github.com/badlogic/pi-mono/pull/1361) by [@yogasanas](https://github.com/yogasanas))
### Changed
- Replaced Claude Opus 4.5 with Opus 4.6 in model definitions ([#1345](https://github.com/badlogic/pi-mono/pull/1345) by [@calvin-hpnet](https://github.com/calvin-hpnet))
## [0.52.7] - 2026-02-06
### Added
- Added `AWS_BEDROCK_SKIP_AUTH` and `AWS_BEDROCK_FORCE_HTTP1` environment variables for connecting to unauthenticated Bedrock proxies ([#1320](https://github.com/badlogic/pi-mono/pull/1320) by [@virtuald](https://github.com/virtuald))
### Fixed
- Set OpenAI Responses API requests to `store: false` by default to avoid server-side history logging ([#1308](https://github.com/badlogic/pi-mono/issues/1308))
- Re-exported TypeBox `Type`, `Static`, and `TSchema` from `@mariozechner/pi-ai` to match documentation and avoid duplicate TypeBox type identity issues in pnpm setups ([#1338](https://github.com/badlogic/pi-mono/issues/1338))
- Fixed Bedrock adaptive thinking handling for Claude Opus 4.6 with interleaved thinking beta responses ([#1323](https://github.com/badlogic/pi-mono/pull/1323) by [@markusylisiurunen](https://github.com/markusylisiurunen))
- Fixed `AWS_BEDROCK_SKIP_AUTH` environment detection to avoid `process` access in non-Node.js environments
## [0.52.6] - 2026-02-05
## [0.52.5] - 2026-02-05
### Fixed
- Fixed `supportsXhigh()` to treat Anthropic Messages Opus 4.6 models as xhigh-capable so `streamSimple` can map `xhigh` to adaptive effort `max`
## [0.52.4] - 2026-02-05
## [0.52.3] - 2026-02-05
### Fixed
- Fixed Bedrock Opus 4.6 model IDs (removed `:0` suffix) and cache pricing for `us.*` and `eu.*` variants
- Added missing `eu.anthropic.claude-opus-4-6-v1` inference profile to model catalog
- Fixed Claude Opus 4.6 context window metadata to 200000 for Anthropic and OpenCode providers
## [0.52.2] - 2026-02-05
## [0.52.1] - 2026-02-05
### Added
- Added adaptive thinking support for Claude Opus 4.6 with effort levels (`low`, `medium`, `high`, `max`)
- Added `effort` option to `AnthropicOptions` for controlling adaptive thinking depth
- `thinkingEnabled` now automatically uses adaptive thinking for Opus 4.6+ models and budget-based thinking for older models
- `streamSimple`/`completeSimple` automatically map `ThinkingLevel` to effort levels for Opus 4.6
### Changed
- Updated `@anthropic-ai/sdk` to 0.73.0
- Updated `@aws-sdk/client-bedrock-runtime` to 3.983.0
- Updated `@google/genai` to 1.40.0
- Removed `fast-xml-parser` override (no longer needed)
## [0.52.0] - 2026-02-05
### Added
- Added Claude Opus 4.6 model to the generated model catalog
- Added GPT-5.3 Codex model to the generated model catalog (OpenAI Codex provider only)
## [0.51.6] - 2026-02-04
### Fixed
- Fixed OpenAI Codex Responses provider to respect configured baseUrl ([#1244](https://github.com/badlogic/pi-mono/issues/1244))
## [0.51.5] - 2026-02-04
### Changed
- Changed Bedrock model generation to drop legacy workarounds now handled upstream ([#1239](https://github.com/badlogic/pi-mono/pull/1239) by [@unexge](https://github.com/unexge))
## [0.51.4] - 2026-02-03
## [0.51.3] - 2026-02-03
### Fixed
- Fixed xhigh thinking level support check to accept gpt-5.2 model IDs ([#1209](https://github.com/badlogic/pi-mono/issues/1209))
## [0.51.2] - 2026-02-03
## [0.51.1] - 2026-02-02
### Fixed
- Fixed `cache_control` not being applied to string-format user messages in Anthropic provider
## [0.51.0] - 2026-02-01
### Fixed
- Fixed `cacheRetention` option not being passed through in `buildBaseOptions` ([#1154](https://github.com/badlogic/pi-mono/issues/1154))
- Fixed OAuth login/refresh not using HTTP proxy settings (`HTTP_PROXY`, `HTTPS_PROXY` env vars) ([#1132](https://github.com/badlogic/pi-mono/issues/1132))
- Fixed OpenAI-compatible completions to omit unsupported `strict` tool fields for providers that reject them ([#1172](https://github.com/badlogic/pi-mono/issues/1172))
## [0.50.9] - 2026-02-01
### Added
- Added `PI_AI_ANTIGRAVITY_VERSION` environment variable to override the Antigravity User-Agent version when Google updates their version requirements ([#1129](https://github.com/badlogic/pi-mono/issues/1129))
- Added `cacheRetention` stream option with provider-specific mappings for prompt cache controls, defaulting to short retention ([#1134](https://github.com/badlogic/pi-mono/issues/1134))
## [0.50.8] - 2026-02-01
### Added
- Added `maxRetryDelayMs` option to `StreamOptions` to cap server-requested retry delays. When a provider (e.g., Google Gemini CLI) requests a delay longer than this value, the request fails immediately with an informative error instead of waiting silently. Default: 60000ms (60 seconds). Set to 0 to disable the cap. ([#1123](https://github.com/badlogic/pi-mono/issues/1123))
- Added Qwen thinking format support for OpenAI-compatible completions via `enable_thinking`. ([#940](https://github.com/badlogic/pi-mono/pull/940) by [@4h9fbZ](https://github.com/4h9fbZ))
## [0.50.7] - 2026-01-31
## [0.50.6] - 2026-01-30
## [0.50.5] - 2026-01-30
## [0.50.4] - 2026-01-30
### Added
- Added Vercel AI Gateway routing support via `vercelGatewayRouting` option in model config ([#1051](https://github.com/badlogic/pi-mono/pull/1051) by [@ben-vargas](https://github.com/ben-vargas))
### Fixed
- Updated Antigravity User-Agent from 1.11.5 to 1.15.8 to fix rejected requests ([#1079](https://github.com/badlogic/pi-mono/issues/1079))
- Fixed tool call argument defaults for Anthropic and Google history conversion when providers omit inputs ([#1065](https://github.com/badlogic/pi-mono/issues/1065))
## [0.50.3] - 2026-01-29
### Added
- Added Kimi For Coding provider support (Moonshot AI's Anthropic-compatible coding API)
## [0.50.2] - 2026-01-29
### Added
- Added Hugging Face provider support via OpenAI-compatible Inference Router ([#994](https://github.com/badlogic/pi-mono/issues/994))
- Added `PI_CACHE_RETENTION` environment variable to control cache TTL for Anthropic (5m vs 1h) and OpenAI (in-memory vs 24h). Set to `long` for extended retention. Only applies to direct API calls (api.anthropic.com, api.openai.com). ([#967](https://github.com/badlogic/pi-mono/issues/967))
### Fixed
- Fixed OpenAI completions `toolChoice` handling to correctly set `type: "function"` wrapper ([#998](https://github.com/badlogic/pi-mono/pull/998) by [@williamtwomey](https://github.com/williamtwomey))
- Fixed cross-provider handoff failing when switching from OpenAI Responses API providers (github-copilot, openai-codex) to other providers due to pipe-separated tool call IDs not being normalized, and trailing underscores in truncated IDs being rejected by OpenAI Codex ([#1022](https://github.com/badlogic/pi-mono/issues/1022))
- Fixed 429 rate limit errors incorrectly triggering auto-compaction instead of retry with backoff ([#1038](https://github.com/badlogic/pi-mono/issues/1038))
- Fixed Anthropic provider to handle `sensitive` stop_reason returned by API ([#978](https://github.com/badlogic/pi-mono/issues/978))
- Fixed DeepSeek API compatibility by detecting `deepseek.com` URLs and disabling unsupported `developer` role ([#1048](https://github.com/badlogic/pi-mono/issues/1048))
- Fixed Anthropic provider to preserve input token counts when proxies omit them in `message_delta` events ([#1045](https://github.com/badlogic/pi-mono/issues/1045))
## [0.50.1] - 2026-01-26
### Fixed
- Fixed OpenCode Zen model generation to exclude deprecated models ([#970](https://github.com/badlogic/pi-mono/pull/970) by [@DanielTatarkin](https://github.com/DanielTatarkin))
## [0.50.0] - 2026-01-26
### Added
- Added OpenRouter provider routing support for custom models via `openRouterRouting` compat field ([#859](https://github.com/badlogic/pi-mono/pull/859) by [@v01dpr1mr0s3](https://github.com/v01dpr1mr0s3))
- Added `azure-openai-responses` provider support for Azure OpenAI Responses API. ([#890](https://github.com/badlogic/pi-mono/pull/890) by [@markusylisiurunen](https://github.com/markusylisiurunen))
- Added HTTP proxy environment variable support for API requests ([#942](https://github.com/badlogic/pi-mono/pull/942) by [@haoqixu](https://github.com/haoqixu))
- Added `createAssistantMessageEventStream()` factory function for use in extensions.
- Added `resetApiProviders()` to clear and re-register built-in API providers.
### Changed
- Refactored API streaming dispatch to use an API registry with provider-owned `streamSimple` mapping.
- Moved environment API key resolution to `env-api-keys.ts` and re-exported it from the package entrypoint.
- Azure OpenAI Responses provider now uses base URL configuration with deployment-aware model mapping and no longer includes service tier handling.
### Fixed
- Fixed Bun runtime detection for dynamic imports in browser-compatible modules (stream.ts, openai-codex-responses.ts, openai-codex.ts) ([#922](https://github.com/badlogic/pi-mono/pull/922) by [@dannote](https://github.com/dannote))
- Fixed streaming functions to use `model.api` instead of hardcoded API types
- Fixed Google providers to default tool call arguments to an empty object when omitted
- Fixed OpenAI Responses streaming to handle `arguments.done` events on OpenAI-compatible endpoints ([#917](https://github.com/badlogic/pi-mono/pull/917) by [@williballenthin](https://github.com/williballenthin))
- Fixed OpenAI Codex Responses tool strictness handling after the shared responses refactor
- Fixed Azure OpenAI Responses streaming to guard deltas before content parts and correct metadata and handoff gating
- Fixed OpenAI completions tool-result image batching after consecutive tool results ([#902](https://github.com/badlogic/pi-mono/pull/902) by [@terrorobe](https://github.com/terrorobe))
## [0.49.3] - 2026-01-22
### Added
- Added `headers` option to `StreamOptions` for custom HTTP headers in API requests. Supported by all providers except Amazon Bedrock (which uses AWS SDK auth). Headers are merged with provider defaults and `model.headers`, with `options.headers` taking precedence.
- Added `originator` option to `loginOpenAICodex()` for custom OAuth client identification
- Browser compatibility for pi-ai: replaced top-level Node.js imports with dynamic imports for browser environments ([#873](https://github.com/badlogic/pi-mono/issues/873))
### Fixed
- Fixed OpenAI Responses API 400 error "function_call without required reasoning item" when switching between models (same provider, different model). The fix omits the `id` field for function_calls from different models to avoid triggering OpenAI's reasoning/function_call pairing validation ([#886](https://github.com/badlogic/pi-mono/issues/886))
## [0.49.2] - 2026-01-19
### Added
- Added AWS credential detection for ECS/Kubernetes environments: `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI`, `AWS_CONTAINER_CREDENTIALS_FULL_URI`, `AWS_WEB_IDENTITY_TOKEN_FILE` ([#848](https://github.com/badlogic/pi-mono/issues/848))
### Fixed
- Fixed OpenAI Responses 400 error "reasoning without following item" by skipping errored/aborted assistant messages entirely in transform-messages.ts ([#838](https://github.com/badlogic/pi-mono/pull/838))
### Removed
- Removed `strictResponsesPairing` compat option (no longer needed after the transform-messages fix)
## [0.49.1] - 2026-01-18
### Added
- Added `OpenAIResponsesCompat` interface with `strictResponsesPairing` option for Azure OpenAI Responses API, which requires strict reasoning/message pairing in history replay ([#768](https://github.com/badlogic/pi-mono/pull/768) by [@prateekmedia](https://github.com/prateekmedia))
### Changed
- Split `OpenAICompat` into `OpenAICompletionsCompat` and `OpenAIResponsesCompat` for type-safe API-specific compat settings
### Fixed
- Fixed tool call ID normalization for cross-provider handoffs (e.g., Codex to Antigravity Claude) ([#821](https://github.com/badlogic/pi-mono/issues/821))
## [0.49.0] - 2026-01-17
### Changed
- OpenAI Codex responses now use the context system prompt directly in the instructions field.
### Fixed
- Fixed orphaned tool results after errored assistant messages causing Codex API errors. When an assistant message has `stopReason: "error"`, its tool calls are now excluded from pending tool tracking, preventing synthetic tool results from being generated for calls that will be dropped by provider-specific converters. ([#812](https://github.com/badlogic/pi-mono/issues/812))
- Fixed Bedrock Claude max_tokens handling to always exceed thinking budget tokens, preventing compaction failures. ([#797](https://github.com/badlogic/pi-mono/pull/797) by [@pjtf93](https://github.com/pjtf93))
- Fixed Claude Code tool name normalization to match the Claude Code tool list case-insensitively and remove invalid mappings.
## [0.48.0] - 2026-01-16
### Fixed
- Fixed OpenAI-compatible provider feature detection to use `model.provider` in addition to URL, allowing custom base URLs (e.g., proxies) to work correctly with provider-specific settings ([#774](https://github.com/badlogic/pi-mono/issues/774))
- Fixed Gemini 3 context loss when switching from providers without thought signatures: unsigned tool calls are now converted to text with anti-mimicry notes instead of being skipped
- Fixed string numbers in tool arguments not being coerced to numbers during validation ([#786](https://github.com/badlogic/pi-mono/pull/786) by [@dannote](https://github.com/dannote))
- Fixed Bedrock tool call IDs to use only alphanumeric characters, avoiding API errors from invalid characters ([#781](https://github.com/badlogic/pi-mono/pull/781) by [@pjtf93](https://github.com/pjtf93))
- Fixed empty error assistant messages (from 429/500 errors) breaking the tool_use to tool_result chain by filtering them in `transformMessages`
## [0.47.0] - 2026-01-16
### Fixed
- Fixed OpenCode provider's `/v1` endpoint to use `system` role instead of `developer` role, fixing `400 Incorrect role information` error for models using `openai-completions` API ([#755](https://github.com/badlogic/pi-mono/pull/755) by [@melihmucuk](https://github.com/melihmucuk))
- Added retry logic to OpenAI Codex provider for transient errors (429, 5xx, connection failures). Uses exponential backoff with up to 3 retries. ([#733](https://github.com/badlogic/pi-mono/issues/733))
## [0.46.0] - 2026-01-15
### Added
- Added MiniMax China (`minimax-cn`) provider support ([#725](https://github.com/badlogic/pi-mono/pull/725) by [@tallshort](https://github.com/tallshort))
- Added `gpt-5.2-codex` models for GitHub Copilot and OpenCode Zen providers ([#734](https://github.com/badlogic/pi-mono/pull/734) by [@aadishv](https://github.com/aadishv))
### Fixed
- Avoid unsigned Gemini 3 tool calls ([#741](https://github.com/badlogic/pi-mono/pull/741) by [@roshanasingh4](https://github.com/roshanasingh4))
- Fixed signature support for non-Anthropic models in Amazon Bedrock provider ([#727](https://github.com/badlogic/pi-mono/pull/727) by [@unexge](https://github.com/unexge))
## [0.45.7] - 2026-01-13
### Fixed
- Fixed OpenAI Responses timeout option handling ([#706](https://github.com/badlogic/pi-mono/pull/706) by [@markusylisiurunen](https://github.com/markusylisiurunen))
- Fixed Bedrock tool call conversion to apply message transforms ([#707](https://github.com/badlogic/pi-mono/pull/707) by [@pjtf93](https://github.com/pjtf93))
## [0.45.6] - 2026-01-13
### Fixed
- Export `parseStreamingJson` from main package for tsx dev mode compatibility
## [0.45.5] - 2026-01-13
## [0.45.4] - 2026-01-13
### Added
- Added Vercel AI Gateway provider with model discovery and `AI_GATEWAY_API_KEY` env support ([#689](https://github.com/badlogic/pi-mono/pull/689) by [@timolins](https://github.com/timolins))
### Fixed
- Fixed z.ai thinking/reasoning: z.ai uses `thinking: { type: "enabled" }` instead of OpenAI's `reasoning_effort`. Added `thinkingFormat` compat flag to handle this. ([#688](https://github.com/badlogic/pi-mono/issues/688))
## [0.45.3] - 2026-01-13
## [0.45.2] - 2026-01-13
## [0.45.1] - 2026-01-13
## [0.45.0] - 2026-01-13
### Added
- MiniMax provider support with M2 and M2.1 models via Anthropic-compatible API ([#656](https://github.com/badlogic/pi-mono/pull/656) by [@dannote](https://github.com/dannote))
- Add Amazon Bedrock provider with prompt caching for Claude models (experimental, tested with Anthropic Claude models only) ([#494](https://github.com/badlogic/pi-mono/pull/494) by [@unexge](https://github.com/unexge))
- Added `serviceTier` option for OpenAI Responses requests ([#672](https://github.com/badlogic/pi-mono/pull/672) by [@markusylisiurunen](https://github.com/markusylisiurunen))
- **Anthropic caching on OpenRouter**: Interactions with Anthropic models via OpenRouter now set a 5-minute cache point using Anthropic-style `cache_control` breakpoints on the last assistant or user message. ([#584](https://github.com/badlogic/pi-mono/pull/584) by [@nathyong](https://github.com/nathyong))
- **Google Gemini CLI provider improvements**: Added Antigravity endpoint fallback (tries daily sandbox then prod when `baseUrl` is unset), header-based retry delay parsing (`Retry-After`, `x-ratelimit-reset`, `x-ratelimit-reset-after`), stable `sessionId` derivation from first user message for cache affinity, empty SSE stream retry with backoff, and `anthropic-beta` header for Claude thinking models ([#670](https://github.com/badlogic/pi-mono/pull/670) by [@kim0](https://github.com/kim0))
## [0.44.0] - 2026-01-12
## [0.43.0] - 2026-01-11
### Fixed
- Fixed Google provider thinking detection: `isThinkingPart()` now only checks `thought === true`, not `thoughtSignature`. Per Google docs, `thoughtSignature` is for context replay and can appear on any part type. Also removed `id` field from `functionCall`/`functionResponse` (rejected by Vertex AI and Cloud Code Assist), and added `textSignature` round-trip for multi-turn reasoning context. ([#631](https://github.com/badlogic/pi-mono/pull/631) by [@theBucky](https://github.com/theBucky))
## [0.42.5] - 2026-01-11
## [0.42.4] - 2026-01-10
## [0.42.3] - 2026-01-10
### Changed
- OpenAI Codex: switched to bundled system prompt matching opencode, changed originator to "pi", simplified prompt handling
## [0.42.2] - 2026-01-10
### Added
- Added `GOOGLE_APPLICATION_CREDENTIALS` env var support for Vertex AI credential detection (standard for CI/production).
- Added `supportsUsageInStreaming` compatibility flag for OpenAI-compatible providers that reject `stream_options: { include_usage: true }`. Defaults to `true`. Set to `false` in model config for providers like gatewayz.ai. ([#596](https://github.com/badlogic/pi-mono/pull/596) by [@XesGaDeus](https://github.com/XesGaDeus))
- Improved Google model pricing info ([#588](https://github.com/badlogic/pi-mono/pull/588) by [@aadishv](https://github.com/aadishv))
### Fixed
- Fixed `os.homedir()` calls at module load time; now resolved lazily when needed.
- Fixed OpenAI Responses tool strict flag to use a boolean for LM Studio compatibility ([#598](https://github.com/badlogic/pi-mono/pull/598) by [@gnattu](https://github.com/gnattu))
- Fixed Google Cloud Code Assist OAuth for paid subscriptions: properly handles long-running operations for project provisioning, supports `GOOGLE_CLOUD_PROJECT` / `GOOGLE_CLOUD_PROJECT_ID` env vars for paid tiers, and handles VPC-SC affected users ([#582](https://github.com/badlogic/pi-mono/pull/582) by [@cmf](https://github.com/cmf))
## [0.42.1] - 2026-01-09
## [0.42.0] - 2026-01-09
### Added
- Added OpenCode Zen provider support with 26 models (Claude, GPT, Gemini, Grok, Kimi, GLM, Qwen, etc.). Set `OPENCODE_API_KEY` env var to use.
## [0.41.0] - 2026-01-09
## [0.40.1] - 2026-01-09
## [0.40.0] - 2026-01-08
## [0.39.1] - 2026-01-08
## [0.39.0] - 2026-01-08
### Fixed
- Fixed Gemini CLI abort handling: detect native `AbortError` in retry catch block, cancel SSE reader when abort signal fires ([#568](https://github.com/badlogic/pi-mono/pull/568) by [@tmustier](https://github.com/tmustier))
- Fixed Antigravity provider 429 errors by aligning request payload with CLIProxyAPI v6.6.89: inject Antigravity system instruction with `role: "user"`, set `requestType: "agent"`, and use `antigravity` userAgent. Added bridge prompt to override Antigravity behavior (identity, paths, web dev guidelines) with Pi defaults. ([#571](https://github.com/badlogic/pi-mono/pull/571) by [@ben-vargas](https://github.com/ben-vargas))
- Fixed thinking block handling for cross-model conversations: thinking blocks are now converted to plain text (no `<thinking>` tags) when switching models. Previously, `<thinking>` tags caused models to mimic the pattern and output literal tags. Also fixed empty thinking blocks causing API errors. ([#561](https://github.com/badlogic/pi-mono/issues/561))
## [0.38.0] - 2026-01-08
### Added
- `thinkingBudgets` option in `SimpleStreamOptions` for customizing token budgets per thinking level on token-based providers ([#529](https://github.com/badlogic/pi-mono/pull/529) by [@melihmucuk](https://github.com/melihmucuk))
### Breaking Changes
- Removed OpenAI Codex model aliases (`gpt-5`, `gpt-5-mini`, `gpt-5-nano`, `codex-mini-latest`, `gpt-5-codex`, `gpt-5.1-codex`, `gpt-5.1-chat-latest`). Use canonical model IDs: `gpt-5.1`, `gpt-5.1-codex-max`, `gpt-5.1-codex-mini`, `gpt-5.2`, `gpt-5.2-codex`. ([#536](https://github.com/badlogic/pi-mono/pull/536) by [@ghoulr](https://github.com/ghoulr))
### Fixed
- Fixed OpenAI Codex context window from 400,000 to 272,000 tokens to match Codex CLI defaults and prevent 400 errors. ([#536](https://github.com/badlogic/pi-mono/pull/536) by [@ghoulr](https://github.com/ghoulr))
- Fixed Codex SSE error events to surface message, code, and status. ([#551](https://github.com/badlogic/pi-mono/pull/551) by [@tmustier](https://github.com/tmustier))
- Fixed context overflow detection for `context_length_exceeded` error codes.
## [0.37.8] - 2026-01-07
## [0.37.7] - 2026-01-07
## [0.37.6] - 2026-01-06
### Added
- Exported OpenAI Codex utilities: `CacheMetadata`, `getCodexInstructions`, `getModelFamily`, `ModelFamily`, `buildCodexPiBridge`, `buildCodexSystemPrompt`, `CodexSystemPrompt` ([#510](https://github.com/badlogic/pi-mono/pull/510) by [@mitsuhiko](https://github.com/mitsuhiko))
## [0.37.5] - 2026-01-06
## [0.37.4] - 2026-01-06
## [0.37.3] - 2026-01-06
### Added
- `sessionId` option in `StreamOptions` for providers that support session-based caching. OpenAI Codex provider uses this to set `prompt_cache_key` and routing headers.
## [0.37.2] - 2026-01-05
### Fixed
- Codex provider now always includes `reasoning.encrypted_content` even when custom `include` options are passed ([#484](https://github.com/badlogic/pi-mono/pull/484) by [@kim0](https://github.com/kim0))
## [0.37.1] - 2026-01-05
## [0.37.0] - 2026-01-05
### Breaking Changes
- OpenAI Codex models no longer have per-thinking-level variants (e.g., `gpt-5.2-codex-high`). Use the base model ID and set thinking level separately. The Codex provider clamps reasoning effort to what each model supports internally. (initial implementation by [@ben-vargas](https://github.com/ben-vargas) in [#472](https://github.com/badlogic/pi-mono/pull/472))
### Added
- Headless OAuth support for all callback-server providers (Google Gemini CLI, Antigravity, OpenAI Codex): paste redirect URL when browser callback is unreachable ([#428](https://github.com/badlogic/pi-mono/pull/428) by [@ben-vargas](https://github.com/ben-vargas), [#468](https://github.com/badlogic/pi-mono/pull/468) by [@crcatala](https://github.com/crcatala))
- Cancellable GitHub Copilot device code polling via AbortSignal
### Fixed
- Codex requests now omit the `reasoning` field entirely when thinking is off, letting the backend use its default instead of forcing a value. ([#472](https://github.com/badlogic/pi-mono/pull/472))
## [0.36.0] - 2026-01-05
### Added
- OpenAI Codex OAuth provider with Responses API streaming support: `openai-codex-responses` streaming provider with SSE parsing, tool-call handling, usage/cost tracking, and PKCE OAuth flow ([#451](https://github.com/badlogic/pi-mono/pull/451) by [@kim0](https://github.com/kim0))
### Fixed
- Vertex AI dummy value for `getEnvApiKey()`: Returns `"<authenticated>"` when Application Default Credentials are configured (`~/.config/gcloud/application_default_credentials.json` exists) and both `GOOGLE_CLOUD_PROJECT` (or `GCLOUD_PROJECT`) and `GOOGLE_CLOUD_LOCATION` are set. This allows `streamSimple()` to work with Vertex AI without explicit `apiKey` option. The ADC credentials file existence check is cached per-process to avoid repeated filesystem access.
## [0.35.0] - 2026-01-05
## [0.34.2] - 2026-01-04
## [0.34.1] - 2026-01-04
## [0.34.0] - 2026-01-04
## [0.33.0] - 2026-01-04
## [0.32.3] - 2026-01-03
### Fixed
- Google Vertex AI models no longer appear in available models list without explicit authentication. Previously, `getEnvApiKey()` returned a dummy value for `google-vertex`, causing models to show up even when Google Cloud ADC was not configured.
## [0.32.2] - 2026-01-03
## [0.32.1] - 2026-01-03
## [0.32.0] - 2026-01-03
### Added
- Vertex AI provider with ADC (Application Default Credentials) support. Authenticate with `gcloud auth application-default login`, set `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION`, and access Gemini models via Vertex AI. ([#300](https://github.com/badlogic/pi-mono/pull/300) by [@default-anton](https://github.com/default-anton))
### Fixed
- **Gemini CLI rate limit handling**: Added automatic retry with server-provided delay for 429 errors. Parses delay from error messages like "Your quota will reset after 39s" and waits accordingly. Falls back to exponential backoff for other transient errors. ([#370](https://github.com/badlogic/pi-mono/issues/370))
## [0.31.1] - 2026-01-02
## [0.31.0] - 2026-01-02
### Breaking Changes
- **Agent API moved**: All agent functionality (`agentLoop`, `agentLoopContinue`, `AgentContext`, `AgentEvent`, `AgentTool`, `AgentToolResult`, etc.) has moved to `@mariozechner/pi-agent-core`. Import from that package instead of `@mariozechner/pi-ai`.
### Added
- **`GoogleThinkingLevel` type**: Exported type that mirrors Google's `ThinkingLevel` enum values (`"THINKING_LEVEL_UNSPECIFIED" | "MINIMAL" | "LOW" | "MEDIUM" | "HIGH"`). Allows configuring Gemini thinking levels without importing from `@google/genai`.
- **`ANTHROPIC_OAUTH_TOKEN` env var**: Now checked before `ANTHROPIC_API_KEY` in `getEnvApiKey()`, allowing OAuth tokens to take precedence.
- **`event-stream.js` export**: `AssistantMessageEventStream` utility now exported from package index.
### Changed
- **OAuth uses Web Crypto API**: PKCE generation and OAuth flows now use Web Crypto API (`crypto.subtle`) instead of Node.js `crypto` module. This improves browser compatibility while still working in Node.js 20+.
- **Deterministic model generation**: `generate-models.ts` now sorts providers and models alphabetically for consistent output across runs. ([#332](https://github.com/badlogic/pi-mono/pull/332) by [@mrexodia](https://github.com/mrexodia))
### Fixed
- **OpenAI completions empty content blocks**: Empty text or thinking blocks in assistant messages are now filtered out before sending to the OpenAI completions API, preventing validation errors. ([#344](https://github.com/badlogic/pi-mono/pull/344) by [@default-anton](https://github.com/default-anton))
- **Thinking token duplication**: Fixed thinking content duplication with chutes.ai provider. The provider was returning thinking content in both `reasoning_content` and `reasoning` fields, causing each chunk to be processed twice. Now only the first non-empty reasoning field is used.
- **zAi provider API mapping**: Fixed zAi models to use `openai-completions` API with correct base URL (`https://api.z.ai/api/coding/paas/v4`) instead of incorrect Anthropic API mapping. ([#344](https://github.com/badlogic/pi-mono/pull/344), [#358](https://github.com/badlogic/pi-mono/pull/358) by [@default-anton](https://github.com/default-anton))
## [0.28.0] - 2025-12-25
### Breaking Changes
- **OAuth storage removed** ([#296](https://github.com/badlogic/pi-mono/issues/296)): All storage functions (`loadOAuthCredentials`, `saveOAuthCredentials`, `setOAuthStorage`, etc.) removed. Callers are responsible for storing credentials.
- **OAuth login functions**: `loginAnthropic`, `loginGitHubCopilot`, `loginGeminiCli`, `loginAntigravity` now return `OAuthCredentials` instead of saving to disk.
- **refreshOAuthToken**: Now takes `(provider, credentials)` and returns new `OAuthCredentials` instead of saving.
- **getOAuthApiKey**: Now takes `(provider, credentials)` and returns `{ newCredentials, apiKey }` or null.
- **OAuthCredentials type**: No longer includes `type: "oauth"` discriminator. Callers add discriminator when storing.
- **setApiKey, resolveApiKey**: Removed. Callers must manage their own API key storage/resolution.
- **getApiKey**: Renamed to `getEnvApiKey`. Only checks environment variables for known providers.
## [0.27.7] - 2025-12-24
### Fixed
- **Thinking tag leakage**: Fixed Claude mimicking literal `</thinking>` tags in responses. Unsigned thinking blocks (from aborted streams) are now converted to plain text without `<thinking>` tags. The TUI still displays them as thinking blocks. ([#302](https://github.com/badlogic/pi-mono/pull/302) by [@nicobailon](https://github.com/nicobailon))
## [0.25.1] - 2025-12-21
### Added
- **xhigh thinking level support**: Added `supportsXhigh()` function to check if a model supports xhigh reasoning level. Also clamps xhigh to high for OpenAI models that don't support it. ([#236](https://github.com/badlogic/pi-mono/pull/236) by [@theBucky](https://github.com/theBucky))
### Fixed
- **Gemini multimodal tool results**: Fixed images in tool results causing flaky/broken responses with Gemini models. For Gemini 3, images are now nested inside `functionResponse.parts` per the [docs](https://ai.google.dev/gemini-api/docs/function-calling#multimodal). For older models (which don't support multimodal function responses), images are sent in a separate user message.
- **Queued message steering**: When `getQueuedMessages` is provided, the agent loop now checks for queued user messages after each tool call and skips remaining tool calls in the current assistant message when a queued message arrives (emitting error tool results).
- **Double API version path in Google provider URL**: Fixed Gemini API calls returning 404 after baseUrl support was added. The SDK was appending its default apiVersion to baseUrl which already included the version path. ([#251](https://github.com/badlogic/pi-mono/pull/251) by [@shellfyred](https://github.com/shellfyred))
- **Anthropic SDK retries disabled**: Re-enabled SDK-level retries (default 2) for transient HTTP failures. ([#252](https://github.com/badlogic/pi-mono/issues/252))
## [0.23.5] - 2025-12-19
### Added
- **Gemini 3 Flash thinking support**: Extended thinking level support for Gemini 3 Flash models (MINIMAL, LOW, MEDIUM, HIGH) to match Pro models' capabilities. ([#212](https://github.com/badlogic/pi-mono/pull/212) by [@markusylisiurunen](https://github.com/markusylisiurunen))
- **GitHub Copilot thinking models**: Added thinking support for additional Copilot models (o3-mini, o1-mini, o1-preview). ([#234](https://github.com/badlogic/pi-mono/pull/234) by [@aadishv](https://github.com/aadishv))
### Fixed
- **Gemini tool result format**: Fixed tool result format for Gemini 3 Flash Preview which strictly requires `{ output: value }` for success and `{ error: value }` for errors. Previous format using `{ result, isError }` was rejected by newer Gemini models. Also improved type safety by removing `as any` casts. ([#213](https://github.com/badlogic/pi-mono/issues/213), [#220](https://github.com/badlogic/pi-mono/pull/220))
- **Google baseUrl configuration**: Google provider now respects `baseUrl` configuration for custom endpoints or API proxies. ([#216](https://github.com/badlogic/pi-mono/issues/216), [#221](https://github.com/badlogic/pi-mono/pull/221) by [@theBucky](https://github.com/theBucky))
- **GitHub Copilot vision requests**: Added `Copilot-Vision-Request` header when sending images to GitHub Copilot models. ([#222](https://github.com/badlogic/pi-mono/issues/222))
- **GitHub Copilot X-Initiator header**: Fixed X-Initiator logic to check last message role instead of any message in history. This ensures proper billing when users send follow-up messages. ([#209](https://github.com/badlogic/pi-mono/issues/209))
## [0.22.3] - 2025-12-16
### Added
- **Image limits test suite**: Added comprehensive tests for provider-specific image limitations (max images, max size, max dimensions). Discovered actual limits: Anthropic (100 images, 5MB, 8000px), OpenAI (500 images, ≥25MB), Gemini (~2500 images, ≥40MB), Mistral (8 images, ~15MB), OpenRouter (~40 images context-limited, ~15MB). ([#120](https://github.com/badlogic/pi-mono/pull/120))
- **Tool result streaming**: Added `tool_execution_update` event and optional `onUpdate` callback to `AgentTool.execute()` for streaming tool output during execution. Tools can now emit partial results (e.g., bash stdout) that are forwarded to subscribers. ([#44](https://github.com/badlogic/pi-mono/issues/44))
- **X-Initiator header for GitHub Copilot**: Added X-Initiator header handling for GitHub Copilot provider to ensure correct call accounting (agent calls are not deducted from quota). Sets initiator based on last message role. ([#200](https://github.com/badlogic/pi-mono/pull/200) by [@kim0](https://github.com/kim0))
### Changed
- **Normalized tool_execution_end result**: `tool_execution_end` event now always contains `AgentToolResult` (no longer `AgentToolResult | string`). Errors are wrapped in the standard result format.
### Fixed
- **Reasoning disabled by default**: When `reasoning` option is not specified, thinking is now explicitly disabled for all providers. Previously, some providers like Gemini with "dynamic thinking" would use their default (thinking ON), causing unexpected token usage. This was the original intended behavior. ([#180](https://github.com/badlogic/pi-mono/pull/180) by [@markusylisiurunen](https://github.com/markusylisiurunen))
## [0.22.2] - 2025-12-15
### Added
- **Interleaved thinking for Anthropic**: Added `interleavedThinking` option to `AnthropicOptions`. When enabled, Claude 4 models can think between tool calls and reason after receiving tool results. Enabled by default (no extra token cost, just unlocks the capability). Set `interleavedThinking: false` to disable.
## [0.22.1] - 2025-12-15
_Dedicated to Peter's shoulder ([@steipete](https://twitter.com/steipete))_
### Added
- **Interleaved thinking for Anthropic**: Enabled interleaved thinking in the Anthropic provider, allowing Claude models to output thinking blocks interspersed with text responses.
## [0.22.0] - 2025-12-15
### Added
- **GitHub Copilot provider**: Added `github-copilot` as a known provider with models sourced from models.dev. Includes Claude, GPT, Gemini, Grok, and other models available through GitHub Copilot. ([#191](https://github.com/badlogic/pi-mono/pull/191) by [@cau1k](https://github.com/cau1k))
### Fixed
- **GitHub Copilot gpt-5 models**: Fixed API selection for gpt-5 models to use `openai-responses` instead of `openai-completions` (gpt-5 models are not accessible via completions endpoint)
- **GitHub Copilot cross-model context handoff**: Fixed context handoff failing when switching between GitHub Copilot models using different APIs (e.g., gpt-5 to claude-sonnet-4). Tool call IDs from OpenAI Responses API were incompatible with other models. ([#198](https://github.com/badlogic/pi-mono/issues/198))
- **Gemini 3 Pro thinking levels**: Thinking level configuration now works correctly for Gemini 3 Pro models. Previously all levels mapped to -1 (minimal thinking). Now LOW/MEDIUM/HIGH properly control test-time computation. ([#176](https://github.com/badlogic/pi-mono/pull/176) by [@markusylisiurunen](https://github.com/markusylisiurunen))
## [0.18.2] - 2025-12-11
### Changed
- **Anthropic SDK retries disabled**: Set `maxRetries: 0` on Anthropic client to allow application-level retry handling. The SDK's built-in retries were interfering with coding-agent's retry logic. ([#157](https://github.com/badlogic/pi-mono/issues/157))
## [0.18.1] - 2025-12-10
### Added
- **Mistral provider**: Added support for Mistral AI models via the OpenAI-compatible API. Includes automatic handling of Mistral-specific requirements (tool call ID format). Set `MISTRAL_API_KEY` environment variable to use.
### Fixed
- Fixed Mistral 400 errors after aborted assistant messages by skipping empty assistant messages (no content, no tool calls) ([#165](https://github.com/badlogic/pi-mono/issues/165))
- Removed synthetic assistant bridge message after tool results for Mistral (no longer required as of Dec 2025) ([#165](https://github.com/badlogic/pi-mono/issues/165))
- Fixed bug where `ANTHROPIC_API_KEY` environment variable was deleted globally after first OAuth token usage, causing subsequent prompts to fail ([#164](https://github.com/badlogic/pi-mono/pull/164))
## [0.17.0] - 2025-12-09
### Added
- **`agentLoopContinue` function**: Continue an agent loop from existing context without adding a new user message. Validates that the last message is `user` or `toolResult`. Useful for retry after context overflow or resuming from manually-added tool results.
### Breaking Changes
- Removed provider-level tool argument validation. Validation now happens in `agentLoop` via `executeToolCalls`, allowing models to retry on validation errors. For manual tool execution, use `validateToolCall(tools, toolCall)` or `validateToolArguments(tool, toolCall)`.
### Added
- Added `validateToolCall(tools, toolCall)` helper that finds the tool by name and validates arguments.
- **OpenAI compatibility overrides**: Added `compat` field to `Model` for `openai-completions` API, allowing explicit configuration of provider quirks (`supportsStore`, `supportsDeveloperRole`, `supportsReasoningEffort`, `maxTokensField`). Falls back to URL-based detection if not set. Useful for LiteLLM, custom proxies, and other non-standard endpoints. ([#133](https://github.com/badlogic/pi-mono/issues/133), thanks @fink-andreas for the initial idea and PR)
- **xhigh reasoning level**: Added `xhigh` to `ReasoningEffort` type for OpenAI codex-max models. For non-OpenAI providers (Anthropic, Google), `xhigh` is automatically mapped to `high`. ([#143](https://github.com/badlogic/pi-mono/issues/143))
### Changed
- **Updated SDK versions**: OpenAI SDK 5.21.0 → 6.10.0, Anthropic SDK 0.61.0 → 0.71.2, Google GenAI SDK 1.30.0 → 1.31.0
## [0.13.0] - 2025-12-06
### Breaking Changes
- **Added `totalTokens` field to `Usage` type**: All code that constructs `Usage` objects must now include the `totalTokens` field. This field represents the total tokens processed by the LLM (input + output + cache). For OpenAI and Google, this uses native API values (`total_tokens`, `totalTokenCount`). For Anthropic, it's computed as `input + output + cacheRead + cacheWrite`.
## [0.12.10] - 2025-12-04
### Added
- Added `gpt-5.1-codex-max` model support
### Fixed
- **OpenAI Token Counting**: Fixed `usage.input` to exclude cached tokens for OpenAI providers. Previously, `input` included cached tokens, causing double-counting when calculating total context size via `input + cacheRead`. Now `input` represents non-cached input tokens across all providers, making `input + output + cacheRead + cacheWrite` the correct formula for total context size.
- **Fixed Claude Opus 4.5 cache pricing** (was 3x too expensive)
- Corrected cache_read: $1.50 → $0.50 per MTok
- Corrected cache_write: $18.75 → $6.25 per MTok
- Added manual override in `scripts/generate-models.ts` until upstream fix is merged
- Submitted PR to models.dev: https://github.com/sst/models.dev/pull/439
## [0.9.4] - 2025-11-26
Initial release with multi-provider LLM support.

1253
packages/ai/README.md Normal file

File diff suppressed because it is too large Load diff

1
packages/ai/bedrock-provider.d.ts vendored Normal file
View file

@ -0,0 +1 @@
export * from "./dist/bedrock-provider.js";

View file

@ -0,0 +1 @@
export * from "./dist/bedrock-provider.js";

80
packages/ai/package.json Normal file
View file

@ -0,0 +1,80 @@
{
"name": "@mariozechner/pi-ai",
"version": "0.56.2",
"description": "Unified LLM API with automatic model discovery and provider configuration",
"type": "module",
"main": "./dist/index.js",
"types": "./dist/index.d.ts",
"exports": {
".": {
"types": "./dist/index.d.ts",
"import": "./dist/index.js"
},
"./oauth": {
"types": "./dist/oauth.d.ts",
"import": "./dist/oauth.js"
},
"./bedrock-provider": {
"types": "./bedrock-provider.d.ts",
"import": "./bedrock-provider.js"
}
},
"bin": {
"pi-ai": "./dist/cli.js"
},
"files": [
"dist",
"bedrock-provider.js",
"bedrock-provider.d.ts",
"README.md"
],
"scripts": {
"clean": "shx rm -rf dist",
"generate-models": "npx tsx scripts/generate-models.ts",
"build": "npm run generate-models && tsgo -p tsconfig.build.json",
"dev": "tsgo -p tsconfig.build.json --watch --preserveWatchOutput",
"dev:tsc": "tsgo -p tsconfig.build.json --watch --preserveWatchOutput",
"test": "vitest --run",
"prepublishOnly": "npm run clean && npm run build"
},
"dependencies": {
"@anthropic-ai/sdk": "^0.73.0",
"@aws-sdk/client-bedrock-runtime": "^3.983.0",
"@google/genai": "^1.40.0",
"@mistralai/mistralai": "1.14.1",
"@sinclair/typebox": "^0.34.41",
"ajv": "^8.17.1",
"ajv-formats": "^3.0.1",
"chalk": "^5.6.2",
"openai": "6.26.0",
"partial-json": "^0.1.7",
"proxy-agent": "^6.5.0",
"undici": "^7.19.1",
"zod-to-json-schema": "^3.24.6"
},
"keywords": [
"ai",
"llm",
"openai",
"anthropic",
"gemini",
"bedrock",
"unified",
"api"
],
"author": "Mario Zechner",
"license": "MIT",
"repository": {
"type": "git",
"url": "git+https://github.com/getcompanion-ai/co-mono.git",
"directory": "packages/ai"
},
"engines": {
"node": ">=20.0.0"
},
"devDependencies": {
"@types/node": "^24.3.0",
"canvas": "^3.2.0",
"vitest": "^3.2.4"
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,34 @@
#!/usr/bin/env tsx
import { createCanvas } from "canvas";
import { writeFileSync } from "fs";
import { join, dirname } from "path";
import { fileURLToPath } from "url";
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
// Create a 200x200 canvas
const canvas = createCanvas(200, 200);
const ctx = canvas.getContext("2d");
// Fill background with white
ctx.fillStyle = "white";
ctx.fillRect(0, 0, 200, 200);
// Draw a red circle in the center
ctx.fillStyle = "red";
ctx.beginPath();
ctx.arc(100, 100, 50, 0, Math.PI * 2);
ctx.fill();
// Save the image
const buffer = canvas.toBuffer("image/png");
const outputPath = join(__dirname, "..", "test", "data", "red-circle.png");
// Ensure the directory exists
import { mkdirSync } from "fs";
mkdirSync(join(__dirname, "..", "test", "data"), { recursive: true });
writeFileSync(outputPath, buffer);
console.log(`Generated test image at: ${outputPath}`);

View file

@ -0,0 +1,101 @@
import type {
Api,
AssistantMessageEventStream,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
} from "./types.js";
export type ApiStreamFunction = (
model: Model<Api>,
context: Context,
options?: StreamOptions,
) => AssistantMessageEventStream;
export type ApiStreamSimpleFunction = (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
) => AssistantMessageEventStream;
export interface ApiProvider<
TApi extends Api = Api,
TOptions extends StreamOptions = StreamOptions,
> {
api: TApi;
stream: StreamFunction<TApi, TOptions>;
streamSimple: StreamFunction<TApi, SimpleStreamOptions>;
}
interface ApiProviderInternal {
api: Api;
stream: ApiStreamFunction;
streamSimple: ApiStreamSimpleFunction;
}
type RegisteredApiProvider = {
provider: ApiProviderInternal;
sourceId?: string;
};
const apiProviderRegistry = new Map<string, RegisteredApiProvider>();
function wrapStream<TApi extends Api, TOptions extends StreamOptions>(
api: TApi,
stream: StreamFunction<TApi, TOptions>,
): ApiStreamFunction {
return (model, context, options) => {
if (model.api !== api) {
throw new Error(`Mismatched api: ${model.api} expected ${api}`);
}
return stream(model as Model<TApi>, context, options as TOptions);
};
}
function wrapStreamSimple<TApi extends Api>(
api: TApi,
streamSimple: StreamFunction<TApi, SimpleStreamOptions>,
): ApiStreamSimpleFunction {
return (model, context, options) => {
if (model.api !== api) {
throw new Error(`Mismatched api: ${model.api} expected ${api}`);
}
return streamSimple(model as Model<TApi>, context, options);
};
}
export function registerApiProvider<
TApi extends Api,
TOptions extends StreamOptions,
>(provider: ApiProvider<TApi, TOptions>, sourceId?: string): void {
apiProviderRegistry.set(provider.api, {
provider: {
api: provider.api,
stream: wrapStream(provider.api, provider.stream),
streamSimple: wrapStreamSimple(provider.api, provider.streamSimple),
},
sourceId,
});
}
export function getApiProvider(api: Api): ApiProviderInternal | undefined {
return apiProviderRegistry.get(api)?.provider;
}
export function getApiProviders(): ApiProviderInternal[] {
return Array.from(apiProviderRegistry.values(), (entry) => entry.provider);
}
export function unregisterApiProviders(sourceId: string): void {
for (const [api, entry] of apiProviderRegistry.entries()) {
if (entry.sourceId === sourceId) {
apiProviderRegistry.delete(api);
}
}
}
export function clearApiProviders(): void {
apiProviderRegistry.clear();
}

View file

@ -0,0 +1,9 @@
import {
streamBedrock,
streamSimpleBedrock,
} from "./providers/amazon-bedrock.js";
export const bedrockProviderModule = {
streamBedrock,
streamSimpleBedrock,
};

152
packages/ai/src/cli.ts Normal file
View file

@ -0,0 +1,152 @@
#!/usr/bin/env node
import { existsSync, readFileSync, writeFileSync } from "fs";
import { createInterface } from "readline";
import { getOAuthProvider, getOAuthProviders } from "./utils/oauth/index.js";
import type { OAuthCredentials, OAuthProviderId } from "./utils/oauth/types.js";
const AUTH_FILE = "auth.json";
const PROVIDERS = getOAuthProviders();
function prompt(
rl: ReturnType<typeof createInterface>,
question: string,
): Promise<string> {
return new Promise((resolve) => rl.question(question, resolve));
}
function loadAuth(): Record<string, { type: "oauth" } & OAuthCredentials> {
if (!existsSync(AUTH_FILE)) return {};
try {
return JSON.parse(readFileSync(AUTH_FILE, "utf-8"));
} catch {
return {};
}
}
function saveAuth(
auth: Record<string, { type: "oauth" } & OAuthCredentials>,
): void {
writeFileSync(AUTH_FILE, JSON.stringify(auth, null, 2), "utf-8");
}
async function login(providerId: OAuthProviderId): Promise<void> {
const provider = getOAuthProvider(providerId);
if (!provider) {
console.error(`Unknown provider: ${providerId}`);
process.exit(1);
}
const rl = createInterface({ input: process.stdin, output: process.stdout });
const promptFn = (msg: string) => prompt(rl, `${msg} `);
try {
const credentials = await provider.login({
onAuth: (info) => {
console.log(`\nOpen this URL in your browser:\n${info.url}`);
if (info.instructions) console.log(info.instructions);
console.log();
},
onPrompt: async (p) => {
return await promptFn(
`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`,
);
},
onProgress: (msg) => console.log(msg),
});
const auth = loadAuth();
auth[providerId] = { type: "oauth", ...credentials };
saveAuth(auth);
console.log(`\nCredentials saved to ${AUTH_FILE}`);
} finally {
rl.close();
}
}
async function main(): Promise<void> {
const args = process.argv.slice(2);
const command = args[0];
if (
!command ||
command === "help" ||
command === "--help" ||
command === "-h"
) {
const providerList = PROVIDERS.map(
(p) => ` ${p.id.padEnd(20)} ${p.name}`,
).join("\n");
console.log(`Usage: npx @mariozechner/pi-ai <command> [provider]
Commands:
login [provider] Login to an OAuth provider
list List available providers
Providers:
${providerList}
Examples:
npx @mariozechner/pi-ai login # interactive provider selection
npx @mariozechner/pi-ai login anthropic # login to specific provider
npx @mariozechner/pi-ai list # list providers
`);
return;
}
if (command === "list") {
console.log("Available OAuth providers:\n");
for (const p of PROVIDERS) {
console.log(` ${p.id.padEnd(20)} ${p.name}`);
}
return;
}
if (command === "login") {
let provider = args[1] as OAuthProviderId | undefined;
if (!provider) {
const rl = createInterface({
input: process.stdin,
output: process.stdout,
});
console.log("Select a provider:\n");
for (let i = 0; i < PROVIDERS.length; i++) {
console.log(` ${i + 1}. ${PROVIDERS[i].name}`);
}
console.log();
const choice = await prompt(rl, `Enter number (1-${PROVIDERS.length}): `);
rl.close();
const index = parseInt(choice, 10) - 1;
if (index < 0 || index >= PROVIDERS.length) {
console.error("Invalid selection");
process.exit(1);
}
provider = PROVIDERS[index].id;
}
if (!PROVIDERS.some((p) => p.id === provider)) {
console.error(`Unknown provider: ${provider}`);
console.error(
`Use 'npx @mariozechner/pi-ai list' to see available providers`,
);
process.exit(1);
}
console.log(`Logging in to ${provider}...`);
await login(provider);
return;
}
console.error(`Unknown command: ${command}`);
console.error(`Use 'npx @mariozechner/pi-ai --help' for usage`);
process.exit(1);
}
main().catch((err) => {
console.error("Error:", err.message);
process.exit(1);
});

View file

@ -0,0 +1,145 @@
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
let _existsSync: typeof import("node:fs").existsSync | null = null;
let _homedir: typeof import("node:os").homedir | null = null;
let _join: typeof import("node:path").join | null = null;
type DynamicImport = (specifier: string) => Promise<unknown>;
const dynamicImport: DynamicImport = (specifier) => import(specifier);
const NODE_FS_SPECIFIER = "node:" + "fs";
const NODE_OS_SPECIFIER = "node:" + "os";
const NODE_PATH_SPECIFIER = "node:" + "path";
// Eagerly load in Node.js/Bun environment only
if (
typeof process !== "undefined" &&
(process.versions?.node || process.versions?.bun)
) {
dynamicImport(NODE_FS_SPECIFIER).then((m) => {
_existsSync = (m as typeof import("node:fs")).existsSync;
});
dynamicImport(NODE_OS_SPECIFIER).then((m) => {
_homedir = (m as typeof import("node:os")).homedir;
});
dynamicImport(NODE_PATH_SPECIFIER).then((m) => {
_join = (m as typeof import("node:path")).join;
});
}
import type { KnownProvider } from "./types.js";
let cachedVertexAdcCredentialsExists: boolean | null = null;
function hasVertexAdcCredentials(): boolean {
if (cachedVertexAdcCredentialsExists === null) {
// If node modules haven't loaded yet (async import race at startup),
// return false WITHOUT caching so the next call retries once they're ready.
// Only cache false permanently in a browser environment where fs is never available.
if (!_existsSync || !_homedir || !_join) {
const isNode =
typeof process !== "undefined" &&
(process.versions?.node || process.versions?.bun);
if (!isNode) {
// Definitively in a browser — safe to cache false permanently
cachedVertexAdcCredentialsExists = false;
}
return false;
}
// Check GOOGLE_APPLICATION_CREDENTIALS env var first (standard way)
const gacPath = process.env.GOOGLE_APPLICATION_CREDENTIALS;
if (gacPath) {
cachedVertexAdcCredentialsExists = _existsSync(gacPath);
} else {
// Fall back to default ADC path (lazy evaluation)
cachedVertexAdcCredentialsExists = _existsSync(
_join(
_homedir(),
".config",
"gcloud",
"application_default_credentials.json",
),
);
}
}
return cachedVertexAdcCredentialsExists;
}
/**
* Get API key for provider from known environment variables, e.g. OPENAI_API_KEY.
*
* Will not return API keys for providers that require OAuth tokens.
*/
export function getEnvApiKey(provider: KnownProvider): string | undefined;
export function getEnvApiKey(provider: string): string | undefined;
export function getEnvApiKey(provider: any): string | undefined {
// Fall back to environment variables
if (provider === "github-copilot") {
return (
process.env.COPILOT_GITHUB_TOKEN ||
process.env.GH_TOKEN ||
process.env.GITHUB_TOKEN
);
}
// ANTHROPIC_OAUTH_TOKEN takes precedence over ANTHROPIC_API_KEY
if (provider === "anthropic") {
return process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY;
}
// Vertex AI uses Application Default Credentials, not API keys.
// Auth is configured via `gcloud auth application-default login`.
if (provider === "google-vertex") {
const hasCredentials = hasVertexAdcCredentials();
const hasProject = !!(
process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT
);
const hasLocation = !!process.env.GOOGLE_CLOUD_LOCATION;
if (hasCredentials && hasProject && hasLocation) {
return "<authenticated>";
}
}
if (provider === "amazon-bedrock") {
// Amazon Bedrock supports multiple credential sources:
// 1. AWS_PROFILE - named profile from ~/.aws/credentials
// 2. AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY - standard IAM keys
// 3. AWS_BEARER_TOKEN_BEDROCK - Bedrock API keys (bearer token)
// 4. AWS_CONTAINER_CREDENTIALS_RELATIVE_URI - ECS task roles
// 5. AWS_CONTAINER_CREDENTIALS_FULL_URI - ECS task roles (full URI)
// 6. AWS_WEB_IDENTITY_TOKEN_FILE - IRSA (IAM Roles for Service Accounts)
if (
process.env.AWS_PROFILE ||
(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
process.env.AWS_BEARER_TOKEN_BEDROCK ||
process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI ||
process.env.AWS_CONTAINER_CREDENTIALS_FULL_URI ||
process.env.AWS_WEB_IDENTITY_TOKEN_FILE
) {
return "<authenticated>";
}
}
const envMap: Record<string, string> = {
openai: "OPENAI_API_KEY",
"azure-openai-responses": "AZURE_OPENAI_API_KEY",
google: "GEMINI_API_KEY",
groq: "GROQ_API_KEY",
cerebras: "CEREBRAS_API_KEY",
xai: "XAI_API_KEY",
openrouter: "OPENROUTER_API_KEY",
"vercel-ai-gateway": "AI_GATEWAY_API_KEY",
zai: "ZAI_API_KEY",
mistral: "MISTRAL_API_KEY",
minimax: "MINIMAX_API_KEY",
"minimax-cn": "MINIMAX_CN_API_KEY",
huggingface: "HF_TOKEN",
opencode: "OPENCODE_API_KEY",
"opencode-go": "OPENCODE_API_KEY",
"kimi-coding": "KIMI_API_KEY",
};
const envVar = envMap[provider];
return envVar ? process.env[envVar] : undefined;
}

32
packages/ai/src/index.ts Normal file
View file

@ -0,0 +1,32 @@
export type { Static, TSchema } from "@sinclair/typebox";
export { Type } from "@sinclair/typebox";
export * from "./api-registry.js";
export * from "./env-api-keys.js";
export * from "./models.js";
export * from "./providers/anthropic.js";
export * from "./providers/azure-openai-responses.js";
export * from "./providers/google.js";
export * from "./providers/google-gemini-cli.js";
export * from "./providers/google-vertex.js";
export * from "./providers/mistral.js";
export * from "./providers/openai-completions.js";
export * from "./providers/openai-responses.js";
export * from "./providers/register-builtins.js";
export * from "./stream.js";
export * from "./types.js";
export * from "./utils/event-stream.js";
export * from "./utils/json-parse.js";
export type {
OAuthAuthInfo,
OAuthCredentials,
OAuthLoginCallbacks,
OAuthPrompt,
OAuthProvider,
OAuthProviderId,
OAuthProviderInfo,
OAuthProviderInterface,
} from "./utils/oauth/types.js";
export * from "./utils/overflow.js";
export * from "./utils/typebox-helpers.js";
export * from "./utils/validation.js";

File diff suppressed because it is too large Load diff

101
packages/ai/src/models.ts Normal file
View file

@ -0,0 +1,101 @@
import { MODELS } from "./models.generated.js";
import type { Api, KnownProvider, Model, Usage } from "./types.js";
const modelRegistry: Map<string, Map<string, Model<Api>>> = new Map();
// Initialize registry from MODELS on module load
for (const [provider, models] of Object.entries(MODELS)) {
const providerModels = new Map<string, Model<Api>>();
for (const [id, model] of Object.entries(models)) {
providerModels.set(id, model as Model<Api>);
}
modelRegistry.set(provider, providerModels);
}
type ModelApi<
TProvider extends KnownProvider,
TModelId extends keyof (typeof MODELS)[TProvider],
> = (typeof MODELS)[TProvider][TModelId] extends { api: infer TApi }
? TApi extends Api
? TApi
: never
: never;
export function getModel<
TProvider extends KnownProvider,
TModelId extends keyof (typeof MODELS)[TProvider],
>(
provider: TProvider,
modelId: TModelId,
): Model<ModelApi<TProvider, TModelId>> {
const providerModels = modelRegistry.get(provider);
return providerModels?.get(modelId as string) as Model<
ModelApi<TProvider, TModelId>
>;
}
export function getProviders(): KnownProvider[] {
return Array.from(modelRegistry.keys()) as KnownProvider[];
}
export function getModels<TProvider extends KnownProvider>(
provider: TProvider,
): Model<ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>>[] {
const models = modelRegistry.get(provider);
return models
? (Array.from(models.values()) as Model<
ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>
>[])
: [];
}
export function calculateCost<TApi extends Api>(
model: Model<TApi>,
usage: Usage,
): Usage["cost"] {
usage.cost.input = (model.cost.input / 1000000) * usage.input;
usage.cost.output = (model.cost.output / 1000000) * usage.output;
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
usage.cost.cacheWrite = (model.cost.cacheWrite / 1000000) * usage.cacheWrite;
usage.cost.total =
usage.cost.input +
usage.cost.output +
usage.cost.cacheRead +
usage.cost.cacheWrite;
return usage.cost;
}
/**
* Check if a model supports xhigh thinking level.
*
* Supported today:
* - GPT-5.2 / GPT-5.3 / GPT-5.4 model families
* - Anthropic Messages API Opus 4.6 models (xhigh maps to adaptive effort "max")
*/
export function supportsXhigh<TApi extends Api>(model: Model<TApi>): boolean {
if (
model.id.includes("gpt-5.2") ||
model.id.includes("gpt-5.3") ||
model.id.includes("gpt-5.4")
) {
return true;
}
if (model.api === "anthropic-messages") {
return model.id.includes("opus-4-6") || model.id.includes("opus-4.6");
}
return false;
}
/**
* Check if two models are equal by comparing both their id and provider.
* Returns false if either model is null or undefined.
*/
export function modelsAreEqual<TApi extends Api>(
a: Model<TApi> | null | undefined,
b: Model<TApi> | null | undefined,
): boolean {
if (!a || !b) return false;
return a.id === b.id && a.provider === b.provider;
}

1
packages/ai/src/oauth.ts Normal file
View file

@ -0,0 +1 @@
export * from "./utils/oauth/index.js";

View file

@ -0,0 +1,894 @@
import {
BedrockRuntimeClient,
type BedrockRuntimeClientConfig,
StopReason as BedrockStopReason,
type Tool as BedrockTool,
CachePointType,
CacheTTL,
type ContentBlock,
type ContentBlockDeltaEvent,
type ContentBlockStartEvent,
type ContentBlockStopEvent,
ConversationRole,
ConverseStreamCommand,
type ConverseStreamMetadataEvent,
ImageFormat,
type Message,
type SystemContentBlock,
type ToolChoice,
type ToolConfiguration,
ToolResultStatus,
} from "@aws-sdk/client-bedrock-runtime";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
CacheRetention,
Context,
Model,
SimpleStreamOptions,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingBudgets,
ThinkingContent,
ThinkingLevel,
Tool,
ToolCall,
ToolResultMessage,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import {
adjustMaxTokensForThinking,
buildBaseOptions,
clampReasoning,
} from "./simple-options.js";
import { transformMessages } from "./transform-messages.js";
export interface BedrockOptions extends StreamOptions {
region?: string;
profile?: string;
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
/* See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-reasoning.html for supported models. */
reasoning?: ThinkingLevel;
/* Custom token budgets per thinking level. Overrides default budgets. */
thinkingBudgets?: ThinkingBudgets;
/* Only supported by Claude 4.x models, see https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html#claude-messages-extended-thinking-tool-use-interleaved */
interleavedThinking?: boolean;
}
type Block = (TextContent | ThinkingContent | ToolCall) & {
index?: number;
partialJson?: string;
};
export const streamBedrock: StreamFunction<
"bedrock-converse-stream",
BedrockOptions
> = (
model: Model<"bedrock-converse-stream">,
context: Context,
options: BedrockOptions = {},
): AssistantMessageEventStream => {
const stream = new AssistantMessageEventStream();
(async () => {
const output: AssistantMessage = {
role: "assistant",
content: [],
api: "bedrock-converse-stream" as Api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
const blocks = output.content as Block[];
const config: BedrockRuntimeClientConfig = {
profile: options.profile,
};
// in Node.js/Bun environment only
if (
typeof process !== "undefined" &&
(process.versions?.node || process.versions?.bun)
) {
// Region resolution: explicit option > env vars > SDK default chain.
// When AWS_PROFILE is set, we leave region undefined so the SDK can
// resovle it from aws profile configs. Otherwise fall back to us-east-1.
const explicitRegion =
options.region ||
process.env.AWS_REGION ||
process.env.AWS_DEFAULT_REGION;
if (explicitRegion) {
config.region = explicitRegion;
} else if (!process.env.AWS_PROFILE) {
config.region = "us-east-1";
}
// Support proxies that don't need authentication
if (process.env.AWS_BEDROCK_SKIP_AUTH === "1") {
config.credentials = {
accessKeyId: "dummy-access-key",
secretAccessKey: "dummy-secret-key",
};
}
if (
process.env.HTTP_PROXY ||
process.env.HTTPS_PROXY ||
process.env.NO_PROXY ||
process.env.http_proxy ||
process.env.https_proxy ||
process.env.no_proxy
) {
const nodeHttpHandler = await import("@smithy/node-http-handler");
const proxyAgent = await import("proxy-agent");
const agent = new proxyAgent.ProxyAgent();
// Bedrock runtime uses NodeHttp2Handler by default since v3.798.0, which is based
// on `http2` module and has no support for http agent.
// Use NodeHttpHandler to support http agent.
config.requestHandler = new nodeHttpHandler.NodeHttpHandler({
httpAgent: agent,
httpsAgent: agent,
});
} else if (process.env.AWS_BEDROCK_FORCE_HTTP1 === "1") {
// Some custom endpoints require HTTP/1.1 instead of HTTP/2
const nodeHttpHandler = await import("@smithy/node-http-handler");
config.requestHandler = new nodeHttpHandler.NodeHttpHandler();
}
} else {
// Non-Node environment (browser): fall back to us-east-1 since
// there's no config file resolution available.
config.region = options.region || "us-east-1";
}
try {
const client = new BedrockRuntimeClient(config);
const cacheRetention = resolveCacheRetention(options.cacheRetention);
const commandInput = {
modelId: model.id,
messages: convertMessages(context, model, cacheRetention),
system: buildSystemPrompt(context.systemPrompt, model, cacheRetention),
inferenceConfig: {
maxTokens: options.maxTokens,
temperature: options.temperature,
},
toolConfig: convertToolConfig(context.tools, options.toolChoice),
additionalModelRequestFields: buildAdditionalModelRequestFields(
model,
options,
),
};
options?.onPayload?.(commandInput);
const command = new ConverseStreamCommand(commandInput);
const response = await client.send(command, {
abortSignal: options.signal,
});
for await (const item of response.stream!) {
if (item.messageStart) {
if (item.messageStart.role !== ConversationRole.ASSISTANT) {
throw new Error(
"Unexpected assistant message start but got user message start instead",
);
}
stream.push({ type: "start", partial: output });
} else if (item.contentBlockStart) {
handleContentBlockStart(
item.contentBlockStart,
blocks,
output,
stream,
);
} else if (item.contentBlockDelta) {
handleContentBlockDelta(
item.contentBlockDelta,
blocks,
output,
stream,
);
} else if (item.contentBlockStop) {
handleContentBlockStop(item.contentBlockStop, blocks, output, stream);
} else if (item.messageStop) {
output.stopReason = mapStopReason(item.messageStop.stopReason);
} else if (item.metadata) {
handleMetadata(item.metadata, model, output);
} else if (item.internalServerException) {
throw new Error(
`Internal server error: ${item.internalServerException.message}`,
);
} else if (item.modelStreamErrorException) {
throw new Error(
`Model stream error: ${item.modelStreamErrorException.message}`,
);
} else if (item.validationException) {
throw new Error(
`Validation error: ${item.validationException.message}`,
);
} else if (item.throttlingException) {
throw new Error(
`Throttling error: ${item.throttlingException.message}`,
);
} else if (item.serviceUnavailableException) {
throw new Error(
`Service unavailable: ${item.serviceUnavailableException.message}`,
);
}
}
if (options.signal?.aborted) {
throw new Error("Request was aborted");
}
if (output.stopReason === "error" || output.stopReason === "aborted") {
throw new Error("An unknown error occurred");
}
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
for (const block of output.content) {
delete (block as Block).index;
delete (block as Block).partialJson;
}
output.stopReason = options.signal?.aborted ? "aborted" : "error";
output.errorMessage =
error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", reason: output.stopReason, error: output });
stream.end();
}
})();
return stream;
};
export const streamSimpleBedrock: StreamFunction<
"bedrock-converse-stream",
SimpleStreamOptions
> = (
model: Model<"bedrock-converse-stream">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
const base = buildBaseOptions(model, options, undefined);
if (!options?.reasoning) {
return streamBedrock(model, context, {
...base,
reasoning: undefined,
} satisfies BedrockOptions);
}
if (
model.id.includes("anthropic.claude") ||
model.id.includes("anthropic/claude")
) {
if (supportsAdaptiveThinking(model.id)) {
return streamBedrock(model, context, {
...base,
reasoning: options.reasoning,
thinkingBudgets: options.thinkingBudgets,
} satisfies BedrockOptions);
}
const adjusted = adjustMaxTokensForThinking(
base.maxTokens || 0,
model.maxTokens,
options.reasoning,
options.thinkingBudgets,
);
return streamBedrock(model, context, {
...base,
maxTokens: adjusted.maxTokens,
reasoning: options.reasoning,
thinkingBudgets: {
...(options.thinkingBudgets || {}),
[clampReasoning(options.reasoning)!]: adjusted.thinkingBudget,
},
} satisfies BedrockOptions);
}
return streamBedrock(model, context, {
...base,
reasoning: options.reasoning,
thinkingBudgets: options.thinkingBudgets,
} satisfies BedrockOptions);
};
function handleContentBlockStart(
event: ContentBlockStartEvent,
blocks: Block[],
output: AssistantMessage,
stream: AssistantMessageEventStream,
): void {
const index = event.contentBlockIndex!;
const start = event.start;
if (start?.toolUse) {
const block: Block = {
type: "toolCall",
id: start.toolUse.toolUseId || "",
name: start.toolUse.name || "",
arguments: {},
partialJson: "",
index,
};
output.content.push(block);
stream.push({
type: "toolcall_start",
contentIndex: blocks.length - 1,
partial: output,
});
}
}
function handleContentBlockDelta(
event: ContentBlockDeltaEvent,
blocks: Block[],
output: AssistantMessage,
stream: AssistantMessageEventStream,
): void {
const contentBlockIndex = event.contentBlockIndex!;
const delta = event.delta;
let index = blocks.findIndex((b) => b.index === contentBlockIndex);
let block = blocks[index];
if (delta?.text !== undefined) {
// If no text block exists yet, create one, as `handleContentBlockStart` is not sent for text blocks
if (!block) {
const newBlock: Block = {
type: "text",
text: "",
index: contentBlockIndex,
};
output.content.push(newBlock);
index = blocks.length - 1;
block = blocks[index];
stream.push({ type: "text_start", contentIndex: index, partial: output });
}
if (block.type === "text") {
block.text += delta.text;
stream.push({
type: "text_delta",
contentIndex: index,
delta: delta.text,
partial: output,
});
}
} else if (delta?.toolUse && block?.type === "toolCall") {
block.partialJson = (block.partialJson || "") + (delta.toolUse.input || "");
block.arguments = parseStreamingJson(block.partialJson);
stream.push({
type: "toolcall_delta",
contentIndex: index,
delta: delta.toolUse.input || "",
partial: output,
});
} else if (delta?.reasoningContent) {
let thinkingBlock = block;
let thinkingIndex = index;
if (!thinkingBlock) {
const newBlock: Block = {
type: "thinking",
thinking: "",
thinkingSignature: "",
index: contentBlockIndex,
};
output.content.push(newBlock);
thinkingIndex = blocks.length - 1;
thinkingBlock = blocks[thinkingIndex];
stream.push({
type: "thinking_start",
contentIndex: thinkingIndex,
partial: output,
});
}
if (thinkingBlock?.type === "thinking") {
if (delta.reasoningContent.text) {
thinkingBlock.thinking += delta.reasoningContent.text;
stream.push({
type: "thinking_delta",
contentIndex: thinkingIndex,
delta: delta.reasoningContent.text,
partial: output,
});
}
if (delta.reasoningContent.signature) {
thinkingBlock.thinkingSignature =
(thinkingBlock.thinkingSignature || "") +
delta.reasoningContent.signature;
}
}
}
}
function handleMetadata(
event: ConverseStreamMetadataEvent,
model: Model<"bedrock-converse-stream">,
output: AssistantMessage,
): void {
if (event.usage) {
output.usage.input = event.usage.inputTokens || 0;
output.usage.output = event.usage.outputTokens || 0;
output.usage.cacheRead = event.usage.cacheReadInputTokens || 0;
output.usage.cacheWrite = event.usage.cacheWriteInputTokens || 0;
output.usage.totalTokens =
event.usage.totalTokens || output.usage.input + output.usage.output;
calculateCost(model, output.usage);
}
}
function handleContentBlockStop(
event: ContentBlockStopEvent,
blocks: Block[],
output: AssistantMessage,
stream: AssistantMessageEventStream,
): void {
const index = blocks.findIndex((b) => b.index === event.contentBlockIndex);
const block = blocks[index];
if (!block) return;
delete (block as Block).index;
switch (block.type) {
case "text":
stream.push({
type: "text_end",
contentIndex: index,
content: block.text,
partial: output,
});
break;
case "thinking":
stream.push({
type: "thinking_end",
contentIndex: index,
content: block.thinking,
partial: output,
});
break;
case "toolCall":
block.arguments = parseStreamingJson(block.partialJson);
delete (block as Block).partialJson;
stream.push({
type: "toolcall_end",
contentIndex: index,
toolCall: block,
partial: output,
});
break;
}
}
/**
* Check if the model supports adaptive thinking (Opus 4.6 and Sonnet 4.6).
*/
function supportsAdaptiveThinking(modelId: string): boolean {
return (
modelId.includes("opus-4-6") ||
modelId.includes("opus-4.6") ||
modelId.includes("sonnet-4-6") ||
modelId.includes("sonnet-4.6")
);
}
function mapThinkingLevelToEffort(
level: SimpleStreamOptions["reasoning"],
modelId: string,
): "low" | "medium" | "high" | "max" {
switch (level) {
case "minimal":
case "low":
return "low";
case "medium":
return "medium";
case "high":
return "high";
case "xhigh":
return modelId.includes("opus-4-6") || modelId.includes("opus-4.6")
? "max"
: "high";
default:
return "high";
}
}
/**
* Resolve cache retention preference.
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
*/
function resolveCacheRetention(
cacheRetention?: CacheRetention,
): CacheRetention {
if (cacheRetention) {
return cacheRetention;
}
if (
typeof process !== "undefined" &&
process.env.PI_CACHE_RETENTION === "long"
) {
return "long";
}
return "short";
}
/**
* Check if the model supports prompt caching.
* Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models
*/
function supportsPromptCaching(
model: Model<"bedrock-converse-stream">,
): boolean {
if (model.cost.cacheRead || model.cost.cacheWrite) {
return true;
}
const id = model.id.toLowerCase();
// Claude 4.x models (opus-4, sonnet-4, haiku-4)
if (id.includes("claude") && (id.includes("-4-") || id.includes("-4.")))
return true;
// Claude 3.7 Sonnet
if (id.includes("claude-3-7-sonnet")) return true;
// Claude 3.5 Haiku
if (id.includes("claude-3-5-haiku")) return true;
return false;
}
/**
* Check if the model supports thinking signatures in reasoningContent.
* Only Anthropic Claude models support the signature field.
* Other models (OpenAI, Qwen, Minimax, Moonshot, etc.) reject it with:
* "This model doesn't support the reasoningContent.reasoningText.signature field"
*/
function supportsThinkingSignature(
model: Model<"bedrock-converse-stream">,
): boolean {
const id = model.id.toLowerCase();
return id.includes("anthropic.claude") || id.includes("anthropic/claude");
}
function buildSystemPrompt(
systemPrompt: string | undefined,
model: Model<"bedrock-converse-stream">,
cacheRetention: CacheRetention,
): SystemContentBlock[] | undefined {
if (!systemPrompt) return undefined;
const blocks: SystemContentBlock[] = [
{ text: sanitizeSurrogates(systemPrompt) },
];
// Add cache point for supported Claude models when caching is enabled
if (cacheRetention !== "none" && supportsPromptCaching(model)) {
blocks.push({
cachePoint: {
type: CachePointType.DEFAULT,
...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
},
});
}
return blocks;
}
function normalizeToolCallId(id: string): string {
const sanitized = id.replace(/[^a-zA-Z0-9_-]/g, "_");
return sanitized.length > 64 ? sanitized.slice(0, 64) : sanitized;
}
function convertMessages(
context: Context,
model: Model<"bedrock-converse-stream">,
cacheRetention: CacheRetention,
): Message[] {
const result: Message[] = [];
const transformedMessages = transformMessages(
context.messages,
model,
normalizeToolCallId,
);
for (let i = 0; i < transformedMessages.length; i++) {
const m = transformedMessages[i];
switch (m.role) {
case "user":
result.push({
role: ConversationRole.USER,
content:
typeof m.content === "string"
? [{ text: sanitizeSurrogates(m.content) }]
: m.content.map((c) => {
switch (c.type) {
case "text":
return { text: sanitizeSurrogates(c.text) };
case "image":
return { image: createImageBlock(c.mimeType, c.data) };
default:
throw new Error("Unknown user content type");
}
}),
});
break;
case "assistant": {
// Skip assistant messages with empty content (e.g., from aborted requests)
// Bedrock rejects messages with empty content arrays
if (m.content.length === 0) {
continue;
}
const contentBlocks: ContentBlock[] = [];
for (const c of m.content) {
switch (c.type) {
case "text":
// Skip empty text blocks
if (c.text.trim().length === 0) continue;
contentBlocks.push({ text: sanitizeSurrogates(c.text) });
break;
case "toolCall":
contentBlocks.push({
toolUse: { toolUseId: c.id, name: c.name, input: c.arguments },
});
break;
case "thinking":
// Skip empty thinking blocks
if (c.thinking.trim().length === 0) continue;
// Only Anthropic models support the signature field in reasoningText.
// For other models, we omit the signature to avoid errors like:
// "This model doesn't support the reasoningContent.reasoningText.signature field"
if (supportsThinkingSignature(model)) {
contentBlocks.push({
reasoningContent: {
reasoningText: {
text: sanitizeSurrogates(c.thinking),
signature: c.thinkingSignature,
},
},
});
} else {
contentBlocks.push({
reasoningContent: {
reasoningText: { text: sanitizeSurrogates(c.thinking) },
},
});
}
break;
default:
throw new Error("Unknown assistant content type");
}
}
// Skip if all content blocks were filtered out
if (contentBlocks.length === 0) {
continue;
}
result.push({
role: ConversationRole.ASSISTANT,
content: contentBlocks,
});
break;
}
case "toolResult": {
// Collect all consecutive toolResult messages into a single user message
// Bedrock requires all tool results to be in one message
const toolResults: ContentBlock.ToolResultMember[] = [];
// Add current tool result with all content blocks combined
toolResults.push({
toolResult: {
toolUseId: m.toolCallId,
content: m.content.map((c) =>
c.type === "image"
? { image: createImageBlock(c.mimeType, c.data) }
: { text: sanitizeSurrogates(c.text) },
),
status: m.isError
? ToolResultStatus.ERROR
: ToolResultStatus.SUCCESS,
},
});
// Look ahead for consecutive toolResult messages
let j = i + 1;
while (
j < transformedMessages.length &&
transformedMessages[j].role === "toolResult"
) {
const nextMsg = transformedMessages[j] as ToolResultMessage;
toolResults.push({
toolResult: {
toolUseId: nextMsg.toolCallId,
content: nextMsg.content.map((c) =>
c.type === "image"
? { image: createImageBlock(c.mimeType, c.data) }
: { text: sanitizeSurrogates(c.text) },
),
status: nextMsg.isError
? ToolResultStatus.ERROR
: ToolResultStatus.SUCCESS,
},
});
j++;
}
// Skip the messages we've already processed
i = j - 1;
result.push({
role: ConversationRole.USER,
content: toolResults,
});
break;
}
default:
throw new Error("Unknown message role");
}
}
// Add cache point to the last user message for supported Claude models when caching is enabled
if (
cacheRetention !== "none" &&
supportsPromptCaching(model) &&
result.length > 0
) {
const lastMessage = result[result.length - 1];
if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
(lastMessage.content as ContentBlock[]).push({
cachePoint: {
type: CachePointType.DEFAULT,
...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
},
});
}
}
return result;
}
function convertToolConfig(
tools: Tool[] | undefined,
toolChoice: BedrockOptions["toolChoice"],
): ToolConfiguration | undefined {
if (!tools?.length || toolChoice === "none") return undefined;
const bedrockTools: BedrockTool[] = tools.map((tool) => ({
toolSpec: {
name: tool.name,
description: tool.description,
inputSchema: { json: tool.parameters },
},
}));
let bedrockToolChoice: ToolChoice | undefined;
switch (toolChoice) {
case "auto":
bedrockToolChoice = { auto: {} };
break;
case "any":
bedrockToolChoice = { any: {} };
break;
default:
if (toolChoice?.type === "tool") {
bedrockToolChoice = { tool: { name: toolChoice.name } };
}
}
return { tools: bedrockTools, toolChoice: bedrockToolChoice };
}
function mapStopReason(reason: string | undefined): StopReason {
switch (reason) {
case BedrockStopReason.END_TURN:
case BedrockStopReason.STOP_SEQUENCE:
return "stop";
case BedrockStopReason.MAX_TOKENS:
case BedrockStopReason.MODEL_CONTEXT_WINDOW_EXCEEDED:
return "length";
case BedrockStopReason.TOOL_USE:
return "toolUse";
default:
return "error";
}
}
function buildAdditionalModelRequestFields(
model: Model<"bedrock-converse-stream">,
options: BedrockOptions,
): Record<string, any> | undefined {
if (!options.reasoning || !model.reasoning) {
return undefined;
}
if (
model.id.includes("anthropic.claude") ||
model.id.includes("anthropic/claude")
) {
const result: Record<string, any> = supportsAdaptiveThinking(model.id)
? {
thinking: { type: "adaptive" },
output_config: {
effort: mapThinkingLevelToEffort(options.reasoning, model.id),
},
}
: (() => {
const defaultBudgets: Record<ThinkingLevel, number> = {
minimal: 1024,
low: 2048,
medium: 8192,
high: 16384,
xhigh: 16384, // Claude doesn't support xhigh, clamp to high
};
// Custom budgets override defaults (xhigh not in ThinkingBudgets, use high)
const level =
options.reasoning === "xhigh" ? "high" : options.reasoning;
const budget =
options.thinkingBudgets?.[level] ??
defaultBudgets[options.reasoning];
return {
thinking: {
type: "enabled",
budget_tokens: budget,
},
};
})();
if (
!supportsAdaptiveThinking(model.id) &&
(options.interleavedThinking ?? true)
) {
result.anthropic_beta = ["interleaved-thinking-2025-05-14"];
}
return result;
}
return undefined;
}
function createImageBlock(mimeType: string, data: string) {
let format: ImageFormat;
switch (mimeType) {
case "image/jpeg":
case "image/jpg":
format = ImageFormat.JPEG;
break;
case "image/png":
format = ImageFormat.PNG;
break;
case "image/gif":
format = ImageFormat.GIF;
break;
case "image/webp":
format = ImageFormat.WEBP;
break;
default:
throw new Error(`Unknown image type: ${mimeType}`);
}
const binaryString = atob(data);
const bytes = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {
bytes[i] = binaryString.charCodeAt(i);
}
return { source: { bytes }, format };
}

View file

@ -0,0 +1,989 @@
import Anthropic from "@anthropic-ai/sdk";
import type {
ContentBlockParam,
MessageCreateParamsStreaming,
MessageParam,
} from "@anthropic-ai/sdk/resources/messages.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
CacheRetention,
Context,
ImageContent,
Message,
Model,
SimpleStreamOptions,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
ToolResultMessage,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import {
buildCopilotDynamicHeaders,
hasCopilotVisionInput,
} from "./github-copilot-headers.js";
import {
adjustMaxTokensForThinking,
buildBaseOptions,
} from "./simple-options.js";
import { transformMessages } from "./transform-messages.js";
/**
* Resolve cache retention preference.
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
*/
function resolveCacheRetention(
cacheRetention?: CacheRetention,
): CacheRetention {
if (cacheRetention) {
return cacheRetention;
}
if (
typeof process !== "undefined" &&
process.env.PI_CACHE_RETENTION === "long"
) {
return "long";
}
return "short";
}
function getCacheControl(
baseUrl: string,
cacheRetention?: CacheRetention,
): {
retention: CacheRetention;
cacheControl?: { type: "ephemeral"; ttl?: "1h" };
} {
const retention = resolveCacheRetention(cacheRetention);
if (retention === "none") {
return { retention };
}
const ttl =
retention === "long" && baseUrl.includes("api.anthropic.com")
? "1h"
: undefined;
return {
retention,
cacheControl: { type: "ephemeral", ...(ttl && { ttl }) },
};
}
// Stealth mode: Mimic Claude Code's tool naming exactly
const claudeCodeVersion = "2.1.62";
// Claude Code 2.x tool names (canonical casing)
// Source: https://cchistory.mariozechner.at/data/prompts-2.1.11.md
// To update: https://github.com/badlogic/cchistory
const claudeCodeTools = [
"Read",
"Write",
"Edit",
"Bash",
"Grep",
"Glob",
"AskUserQuestion",
"EnterPlanMode",
"ExitPlanMode",
"KillShell",
"NotebookEdit",
"Skill",
"Task",
"TaskOutput",
"TodoWrite",
"WebFetch",
"WebSearch",
];
const ccToolLookup = new Map(claudeCodeTools.map((t) => [t.toLowerCase(), t]));
// Convert tool name to CC canonical casing if it matches (case-insensitive)
const toClaudeCodeName = (name: string) =>
ccToolLookup.get(name.toLowerCase()) ?? name;
const fromClaudeCodeName = (name: string, tools?: Tool[]) => {
if (tools && tools.length > 0) {
const lowerName = name.toLowerCase();
const matchedTool = tools.find(
(tool) => tool.name.toLowerCase() === lowerName,
);
if (matchedTool) return matchedTool.name;
}
return name;
};
/**
* Convert content blocks to Anthropic API format
*/
function convertContentBlocks(content: (TextContent | ImageContent)[]):
| string
| Array<
| { type: "text"; text: string }
| {
type: "image";
source: {
type: "base64";
media_type: "image/jpeg" | "image/png" | "image/gif" | "image/webp";
data: string;
};
}
> {
// If only text blocks, return as concatenated string for simplicity
const hasImages = content.some((c) => c.type === "image");
if (!hasImages) {
return sanitizeSurrogates(
content.map((c) => (c as TextContent).text).join("\n"),
);
}
// If we have images, convert to content block array
const blocks = content.map((block) => {
if (block.type === "text") {
return {
type: "text" as const,
text: sanitizeSurrogates(block.text),
};
}
return {
type: "image" as const,
source: {
type: "base64" as const,
media_type: block.mimeType as
| "image/jpeg"
| "image/png"
| "image/gif"
| "image/webp",
data: block.data,
},
};
});
// If only images (no text), add placeholder text block
const hasText = blocks.some((b) => b.type === "text");
if (!hasText) {
blocks.unshift({
type: "text" as const,
text: "(see attached image)",
});
}
return blocks;
}
export type AnthropicEffort = "low" | "medium" | "high" | "max";
export interface AnthropicOptions extends StreamOptions {
/**
* Enable extended thinking.
* For Opus 4.6 and Sonnet 4.6: uses adaptive thinking (model decides when/how much to think).
* For older models: uses budget-based thinking with thinkingBudgetTokens.
*/
thinkingEnabled?: boolean;
/**
* Token budget for extended thinking (older models only).
* Ignored for Opus 4.6 and Sonnet 4.6, which use adaptive thinking.
*/
thinkingBudgetTokens?: number;
/**
* Effort level for adaptive thinking (Opus 4.6 and Sonnet 4.6).
* Controls how much thinking Claude allocates:
* - "max": Always thinks with no constraints (Opus 4.6 only)
* - "high": Always thinks, deep reasoning (default)
* - "medium": Moderate thinking, may skip for simple queries
* - "low": Minimal thinking, skips for simple tasks
* Ignored for older models.
*/
effort?: AnthropicEffort;
interleavedThinking?: boolean;
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
}
function mergeHeaders(
...headerSources: (Record<string, string> | undefined)[]
): Record<string, string> {
const merged: Record<string, string> = {};
for (const headers of headerSources) {
if (headers) {
Object.assign(merged, headers);
}
}
return merged;
}
export const streamAnthropic: StreamFunction<
"anthropic-messages",
AnthropicOptions
> = (
model: Model<"anthropic-messages">,
context: Context,
options?: AnthropicOptions,
): AssistantMessageEventStream => {
const stream = new AssistantMessageEventStream();
(async () => {
const output: AssistantMessage = {
role: "assistant",
content: [],
api: model.api as Api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
try {
const apiKey = options?.apiKey ?? getEnvApiKey(model.provider) ?? "";
let copilotDynamicHeaders: Record<string, string> | undefined;
if (model.provider === "github-copilot") {
const hasImages = hasCopilotVisionInput(context.messages);
copilotDynamicHeaders = buildCopilotDynamicHeaders({
messages: context.messages,
hasImages,
});
}
const { client, isOAuthToken } = createClient(
model,
apiKey,
options?.interleavedThinking ?? true,
options?.headers,
copilotDynamicHeaders,
);
const params = buildParams(model, context, isOAuthToken, options);
options?.onPayload?.(params);
const anthropicStream = client.messages.stream(
{ ...params, stream: true },
{ signal: options?.signal },
);
stream.push({ type: "start", partial: output });
type Block = (
| ThinkingContent
| TextContent
| (ToolCall & { partialJson: string })
) & { index: number };
const blocks = output.content as Block[];
for await (const event of anthropicStream) {
if (event.type === "message_start") {
// Capture initial token usage from message_start event
// This ensures we have input token counts even if the stream is aborted early
output.usage.input = event.message.usage.input_tokens || 0;
output.usage.output = event.message.usage.output_tokens || 0;
output.usage.cacheRead =
event.message.usage.cache_read_input_tokens || 0;
output.usage.cacheWrite =
event.message.usage.cache_creation_input_tokens || 0;
// Anthropic doesn't provide total_tokens, compute from components
output.usage.totalTokens =
output.usage.input +
output.usage.output +
output.usage.cacheRead +
output.usage.cacheWrite;
calculateCost(model, output.usage);
} else if (event.type === "content_block_start") {
if (event.content_block.type === "text") {
const block: Block = {
type: "text",
text: "",
index: event.index,
};
output.content.push(block);
stream.push({
type: "text_start",
contentIndex: output.content.length - 1,
partial: output,
});
} else if (event.content_block.type === "thinking") {
const block: Block = {
type: "thinking",
thinking: "",
thinkingSignature: "",
index: event.index,
};
output.content.push(block);
stream.push({
type: "thinking_start",
contentIndex: output.content.length - 1,
partial: output,
});
} else if (event.content_block.type === "redacted_thinking") {
const block: Block = {
type: "thinking",
thinking: "[Reasoning redacted]",
thinkingSignature: event.content_block.data,
redacted: true,
index: event.index,
};
output.content.push(block);
stream.push({
type: "thinking_start",
contentIndex: output.content.length - 1,
partial: output,
});
} else if (event.content_block.type === "tool_use") {
const block: Block = {
type: "toolCall",
id: event.content_block.id,
name: isOAuthToken
? fromClaudeCodeName(event.content_block.name, context.tools)
: event.content_block.name,
arguments:
(event.content_block.input as Record<string, any>) ?? {},
partialJson: "",
index: event.index,
};
output.content.push(block);
stream.push({
type: "toolcall_start",
contentIndex: output.content.length - 1,
partial: output,
});
}
} else if (event.type === "content_block_delta") {
if (event.delta.type === "text_delta") {
const index = blocks.findIndex((b) => b.index === event.index);
const block = blocks[index];
if (block && block.type === "text") {
block.text += event.delta.text;
stream.push({
type: "text_delta",
contentIndex: index,
delta: event.delta.text,
partial: output,
});
}
} else if (event.delta.type === "thinking_delta") {
const index = blocks.findIndex((b) => b.index === event.index);
const block = blocks[index];
if (block && block.type === "thinking") {
block.thinking += event.delta.thinking;
stream.push({
type: "thinking_delta",
contentIndex: index,
delta: event.delta.thinking,
partial: output,
});
}
} else if (event.delta.type === "input_json_delta") {
const index = blocks.findIndex((b) => b.index === event.index);
const block = blocks[index];
if (block && block.type === "toolCall") {
block.partialJson += event.delta.partial_json;
block.arguments = parseStreamingJson(block.partialJson);
stream.push({
type: "toolcall_delta",
contentIndex: index,
delta: event.delta.partial_json,
partial: output,
});
}
} else if (event.delta.type === "signature_delta") {
const index = blocks.findIndex((b) => b.index === event.index);
const block = blocks[index];
if (block && block.type === "thinking") {
block.thinkingSignature = block.thinkingSignature || "";
block.thinkingSignature += event.delta.signature;
}
}
} else if (event.type === "content_block_stop") {
const index = blocks.findIndex((b) => b.index === event.index);
const block = blocks[index];
if (block) {
delete (block as any).index;
if (block.type === "text") {
stream.push({
type: "text_end",
contentIndex: index,
content: block.text,
partial: output,
});
} else if (block.type === "thinking") {
stream.push({
type: "thinking_end",
contentIndex: index,
content: block.thinking,
partial: output,
});
} else if (block.type === "toolCall") {
block.arguments = parseStreamingJson(block.partialJson);
delete (block as any).partialJson;
stream.push({
type: "toolcall_end",
contentIndex: index,
toolCall: block,
partial: output,
});
}
}
} else if (event.type === "message_delta") {
if (event.delta.stop_reason) {
output.stopReason = mapStopReason(event.delta.stop_reason);
}
// Only update usage fields if present (not null).
// Preserves input_tokens from message_start when proxies omit it in message_delta.
if (event.usage.input_tokens != null) {
output.usage.input = event.usage.input_tokens;
}
if (event.usage.output_tokens != null) {
output.usage.output = event.usage.output_tokens;
}
if (event.usage.cache_read_input_tokens != null) {
output.usage.cacheRead = event.usage.cache_read_input_tokens;
}
if (event.usage.cache_creation_input_tokens != null) {
output.usage.cacheWrite = event.usage.cache_creation_input_tokens;
}
// Anthropic doesn't provide total_tokens, compute from components
output.usage.totalTokens =
output.usage.input +
output.usage.output +
output.usage.cacheRead +
output.usage.cacheWrite;
calculateCost(model, output.usage);
}
}
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
if (output.stopReason === "aborted" || output.stopReason === "error") {
throw new Error("An unknown error occurred");
}
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
for (const block of output.content) delete (block as any).index;
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
output.errorMessage =
error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", reason: output.stopReason, error: output });
stream.end();
}
})();
return stream;
};
/**
* Check if a model supports adaptive thinking (Opus 4.6 and Sonnet 4.6)
*/
function supportsAdaptiveThinking(modelId: string): boolean {
// Opus 4.6 and Sonnet 4.6 model IDs (with or without date suffix)
return (
modelId.includes("opus-4-6") ||
modelId.includes("opus-4.6") ||
modelId.includes("sonnet-4-6") ||
modelId.includes("sonnet-4.6")
);
}
/**
* Map ThinkingLevel to Anthropic effort levels for adaptive thinking.
* Note: effort "max" is only valid on Opus 4.6.
*/
function mapThinkingLevelToEffort(
level: SimpleStreamOptions["reasoning"],
modelId: string,
): AnthropicEffort {
switch (level) {
case "minimal":
return "low";
case "low":
return "low";
case "medium":
return "medium";
case "high":
return "high";
case "xhigh":
return modelId.includes("opus-4-6") || modelId.includes("opus-4.6")
? "max"
: "high";
default:
return "high";
}
}
export const streamSimpleAnthropic: StreamFunction<
"anthropic-messages",
SimpleStreamOptions
> = (
model: Model<"anthropic-messages">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
}
const base = buildBaseOptions(model, options, apiKey);
if (!options?.reasoning) {
return streamAnthropic(model, context, {
...base,
thinkingEnabled: false,
} satisfies AnthropicOptions);
}
// For Opus 4.6 and Sonnet 4.6: use adaptive thinking with effort level
// For older models: use budget-based thinking
if (supportsAdaptiveThinking(model.id)) {
const effort = mapThinkingLevelToEffort(options.reasoning, model.id);
return streamAnthropic(model, context, {
...base,
thinkingEnabled: true,
effort,
} satisfies AnthropicOptions);
}
const adjusted = adjustMaxTokensForThinking(
base.maxTokens || 0,
model.maxTokens,
options.reasoning,
options.thinkingBudgets,
);
return streamAnthropic(model, context, {
...base,
maxTokens: adjusted.maxTokens,
thinkingEnabled: true,
thinkingBudgetTokens: adjusted.thinkingBudget,
} satisfies AnthropicOptions);
};
function isOAuthToken(apiKey: string): boolean {
return apiKey.includes("sk-ant-oat");
}
function createClient(
model: Model<"anthropic-messages">,
apiKey: string,
interleavedThinking: boolean,
optionsHeaders?: Record<string, string>,
dynamicHeaders?: Record<string, string>,
): { client: Anthropic; isOAuthToken: boolean } {
// Adaptive thinking models (Opus 4.6, Sonnet 4.6) have interleaved thinking built-in.
// The beta header is deprecated on Opus 4.6 and redundant on Sonnet 4.6, so skip it.
const needsInterleavedBeta =
interleavedThinking && !supportsAdaptiveThinking(model.id);
// Copilot: Bearer auth, selective betas (no fine-grained-tool-streaming)
if (model.provider === "github-copilot") {
const betaFeatures: string[] = [];
if (needsInterleavedBeta) {
betaFeatures.push("interleaved-thinking-2025-05-14");
}
const client = new Anthropic({
apiKey: null,
authToken: apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders: mergeHeaders(
{
accept: "application/json",
"anthropic-dangerous-direct-browser-access": "true",
...(betaFeatures.length > 0
? { "anthropic-beta": betaFeatures.join(",") }
: {}),
},
model.headers,
dynamicHeaders,
optionsHeaders,
),
});
return { client, isOAuthToken: false };
}
const betaFeatures = ["fine-grained-tool-streaming-2025-05-14"];
if (needsInterleavedBeta) {
betaFeatures.push("interleaved-thinking-2025-05-14");
}
// OAuth: Bearer auth, Claude Code identity headers
if (isOAuthToken(apiKey)) {
const client = new Anthropic({
apiKey: null,
authToken: apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders: mergeHeaders(
{
accept: "application/json",
"anthropic-dangerous-direct-browser-access": "true",
"anthropic-beta": `claude-code-20250219,oauth-2025-04-20,${betaFeatures.join(",")}`,
"user-agent": `claude-cli/${claudeCodeVersion}`,
"x-app": "cli",
},
model.headers,
optionsHeaders,
),
});
return { client, isOAuthToken: true };
}
// API key auth
const client = new Anthropic({
apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders: mergeHeaders(
{
accept: "application/json",
"anthropic-dangerous-direct-browser-access": "true",
"anthropic-beta": betaFeatures.join(","),
},
model.headers,
optionsHeaders,
),
});
return { client, isOAuthToken: false };
}
function buildParams(
model: Model<"anthropic-messages">,
context: Context,
isOAuthToken: boolean,
options?: AnthropicOptions,
): MessageCreateParamsStreaming {
const { cacheControl } = getCacheControl(
model.baseUrl,
options?.cacheRetention,
);
const params: MessageCreateParamsStreaming = {
model: model.id,
messages: convertMessages(
context.messages,
model,
isOAuthToken,
cacheControl,
),
max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0,
stream: true,
};
// For OAuth tokens, we MUST include Claude Code identity
if (isOAuthToken) {
params.system = [
{
type: "text",
text: "You are Claude Code, Anthropic's official CLI for Claude.",
...(cacheControl ? { cache_control: cacheControl } : {}),
},
];
if (context.systemPrompt) {
params.system.push({
type: "text",
text: sanitizeSurrogates(context.systemPrompt),
...(cacheControl ? { cache_control: cacheControl } : {}),
});
}
} else if (context.systemPrompt) {
// Add cache control to system prompt for non-OAuth tokens
params.system = [
{
type: "text",
text: sanitizeSurrogates(context.systemPrompt),
...(cacheControl ? { cache_control: cacheControl } : {}),
},
];
}
// Temperature is incompatible with extended thinking (adaptive or budget-based).
if (options?.temperature !== undefined && !options?.thinkingEnabled) {
params.temperature = options.temperature;
}
if (context.tools) {
params.tools = convertTools(context.tools, isOAuthToken);
}
// Configure thinking mode: adaptive (Opus 4.6 and Sonnet 4.6) or budget-based (older models)
if (options?.thinkingEnabled && model.reasoning) {
if (supportsAdaptiveThinking(model.id)) {
// Adaptive thinking: Claude decides when and how much to think
params.thinking = { type: "adaptive" };
if (options.effort) {
params.output_config = { effort: options.effort };
}
} else {
// Budget-based thinking for older models
params.thinking = {
type: "enabled",
budget_tokens: options.thinkingBudgetTokens || 1024,
};
}
}
if (options?.metadata) {
const userId = options.metadata.user_id;
if (typeof userId === "string") {
params.metadata = { user_id: userId };
}
}
if (options?.toolChoice) {
if (typeof options.toolChoice === "string") {
params.tool_choice = { type: options.toolChoice };
} else {
params.tool_choice = options.toolChoice;
}
}
return params;
}
// Normalize tool call IDs to match Anthropic's required pattern and length
function normalizeToolCallId(id: string): string {
return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
}
function convertMessages(
messages: Message[],
model: Model<"anthropic-messages">,
isOAuthToken: boolean,
cacheControl?: { type: "ephemeral"; ttl?: "1h" },
): MessageParam[] {
const params: MessageParam[] = [];
// Transform messages for cross-provider compatibility
const transformedMessages = transformMessages(
messages,
model,
normalizeToolCallId,
);
for (let i = 0; i < transformedMessages.length; i++) {
const msg = transformedMessages[i];
if (msg.role === "user") {
if (typeof msg.content === "string") {
if (msg.content.trim().length > 0) {
params.push({
role: "user",
content: sanitizeSurrogates(msg.content),
});
}
} else {
const blocks: ContentBlockParam[] = msg.content.map((item) => {
if (item.type === "text") {
return {
type: "text",
text: sanitizeSurrogates(item.text),
};
} else {
return {
type: "image",
source: {
type: "base64",
media_type: item.mimeType as
| "image/jpeg"
| "image/png"
| "image/gif"
| "image/webp",
data: item.data,
},
};
}
});
let filteredBlocks = !model?.input.includes("image")
? blocks.filter((b) => b.type !== "image")
: blocks;
filteredBlocks = filteredBlocks.filter((b) => {
if (b.type === "text") {
return b.text.trim().length > 0;
}
return true;
});
if (filteredBlocks.length === 0) continue;
params.push({
role: "user",
content: filteredBlocks,
});
}
} else if (msg.role === "assistant") {
const blocks: ContentBlockParam[] = [];
for (const block of msg.content) {
if (block.type === "text") {
if (block.text.trim().length === 0) continue;
blocks.push({
type: "text",
text: sanitizeSurrogates(block.text),
});
} else if (block.type === "thinking") {
// Redacted thinking: pass the opaque payload back as redacted_thinking
if (block.redacted) {
blocks.push({
type: "redacted_thinking",
data: block.thinkingSignature!,
});
continue;
}
if (block.thinking.trim().length === 0) continue;
// If thinking signature is missing/empty (e.g., from aborted stream),
// convert to plain text block without <thinking> tags to avoid API rejection
// and prevent Claude from mimicking the tags in responses
if (
!block.thinkingSignature ||
block.thinkingSignature.trim().length === 0
) {
blocks.push({
type: "text",
text: sanitizeSurrogates(block.thinking),
});
} else {
blocks.push({
type: "thinking",
thinking: sanitizeSurrogates(block.thinking),
signature: block.thinkingSignature,
});
}
} else if (block.type === "toolCall") {
blocks.push({
type: "tool_use",
id: block.id,
name: isOAuthToken ? toClaudeCodeName(block.name) : block.name,
input: block.arguments ?? {},
});
}
}
if (blocks.length === 0) continue;
params.push({
role: "assistant",
content: blocks,
});
} else if (msg.role === "toolResult") {
// Collect all consecutive toolResult messages, needed for z.ai Anthropic endpoint
const toolResults: ContentBlockParam[] = [];
// Add the current tool result
toolResults.push({
type: "tool_result",
tool_use_id: msg.toolCallId,
content: convertContentBlocks(msg.content),
is_error: msg.isError,
});
// Look ahead for consecutive toolResult messages
let j = i + 1;
while (
j < transformedMessages.length &&
transformedMessages[j].role === "toolResult"
) {
const nextMsg = transformedMessages[j] as ToolResultMessage; // We know it's a toolResult
toolResults.push({
type: "tool_result",
tool_use_id: nextMsg.toolCallId,
content: convertContentBlocks(nextMsg.content),
is_error: nextMsg.isError,
});
j++;
}
// Skip the messages we've already processed
i = j - 1;
// Add a single user message with all tool results
params.push({
role: "user",
content: toolResults,
});
}
}
// Add cache_control to the last user message to cache conversation history
if (cacheControl && params.length > 0) {
const lastMessage = params[params.length - 1];
if (lastMessage.role === "user") {
if (Array.isArray(lastMessage.content)) {
const lastBlock = lastMessage.content[lastMessage.content.length - 1];
if (
lastBlock &&
(lastBlock.type === "text" ||
lastBlock.type === "image" ||
lastBlock.type === "tool_result")
) {
(lastBlock as any).cache_control = cacheControl;
}
} else if (typeof lastMessage.content === "string") {
lastMessage.content = [
{
type: "text",
text: lastMessage.content,
cache_control: cacheControl,
},
] as any;
}
}
}
return params;
}
function convertTools(
tools: Tool[],
isOAuthToken: boolean,
): Anthropic.Messages.Tool[] {
if (!tools) return [];
return tools.map((tool) => {
const jsonSchema = tool.parameters as any; // TypeBox already generates JSON Schema
return {
name: isOAuthToken ? toClaudeCodeName(tool.name) : tool.name,
description: tool.description,
input_schema: {
type: "object" as const,
properties: jsonSchema.properties || {},
required: jsonSchema.required || [],
},
};
});
}
function mapStopReason(
reason: Anthropic.Messages.StopReason | string,
): StopReason {
switch (reason) {
case "end_turn":
return "stop";
case "max_tokens":
return "length";
case "tool_use":
return "toolUse";
case "refusal":
return "error";
case "pause_turn": // Stop is good enough -> resubmit
return "stop";
case "stop_sequence":
return "stop"; // We don't supply stop sequences, so this should never happen
case "sensitive": // Content flagged by safety filters (not yet in SDK types)
return "error";
default:
// Handle unknown stop reasons gracefully (API may add new values)
throw new Error(`Unhandled stop reason: ${reason}`);
}
}

View file

@ -0,0 +1,297 @@
import { AzureOpenAI } from "openai";
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { supportsXhigh } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import {
convertResponsesMessages,
convertResponsesTools,
processResponsesStream,
} from "./openai-responses-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
const DEFAULT_AZURE_API_VERSION = "v1";
const AZURE_TOOL_CALL_PROVIDERS = new Set([
"openai",
"openai-codex",
"opencode",
"azure-openai-responses",
]);
function parseDeploymentNameMap(
value: string | undefined,
): Map<string, string> {
const map = new Map<string, string>();
if (!value) return map;
for (const entry of value.split(",")) {
const trimmed = entry.trim();
if (!trimmed) continue;
const [modelId, deploymentName] = trimmed.split("=", 2);
if (!modelId || !deploymentName) continue;
map.set(modelId.trim(), deploymentName.trim());
}
return map;
}
function resolveDeploymentName(
model: Model<"azure-openai-responses">,
options?: AzureOpenAIResponsesOptions,
): string {
if (options?.azureDeploymentName) {
return options.azureDeploymentName;
}
const mappedDeployment = parseDeploymentNameMap(
process.env.AZURE_OPENAI_DEPLOYMENT_NAME_MAP,
).get(model.id);
return mappedDeployment || model.id;
}
// Azure OpenAI Responses-specific options
export interface AzureOpenAIResponsesOptions extends StreamOptions {
reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
reasoningSummary?: "auto" | "detailed" | "concise" | null;
azureApiVersion?: string;
azureResourceName?: string;
azureBaseUrl?: string;
azureDeploymentName?: string;
}
/**
* Generate function for Azure OpenAI Responses API
*/
export const streamAzureOpenAIResponses: StreamFunction<
"azure-openai-responses",
AzureOpenAIResponsesOptions
> = (
model: Model<"azure-openai-responses">,
context: Context,
options?: AzureOpenAIResponsesOptions,
): AssistantMessageEventStream => {
const stream = new AssistantMessageEventStream();
// Start async processing
(async () => {
const deploymentName = resolveDeploymentName(model, options);
const output: AssistantMessage = {
role: "assistant",
content: [],
api: "azure-openai-responses" as Api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
try {
// Create Azure OpenAI client
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
const client = createClient(model, apiKey, options);
const params = buildParams(model, context, options, deploymentName);
options?.onPayload?.(params);
const openaiStream = await client.responses.create(
params,
options?.signal ? { signal: options.signal } : undefined,
);
stream.push({ type: "start", partial: output });
await processResponsesStream(openaiStream, output, stream, model);
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
if (output.stopReason === "aborted" || output.stopReason === "error") {
throw new Error("An unknown error occurred");
}
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
for (const block of output.content)
delete (block as { index?: number }).index;
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
output.errorMessage =
error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", reason: output.stopReason, error: output });
stream.end();
}
})();
return stream;
};
export const streamSimpleAzureOpenAIResponses: StreamFunction<
"azure-openai-responses",
SimpleStreamOptions
> = (
model: Model<"azure-openai-responses">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
}
const base = buildBaseOptions(model, options, apiKey);
const reasoningEffort = supportsXhigh(model)
? options?.reasoning
: clampReasoning(options?.reasoning);
return streamAzureOpenAIResponses(model, context, {
...base,
reasoningEffort,
} satisfies AzureOpenAIResponsesOptions);
};
function normalizeAzureBaseUrl(baseUrl: string): string {
return baseUrl.replace(/\/+$/, "");
}
function buildDefaultBaseUrl(resourceName: string): string {
return `https://${resourceName}.openai.azure.com/openai/v1`;
}
function resolveAzureConfig(
model: Model<"azure-openai-responses">,
options?: AzureOpenAIResponsesOptions,
): { baseUrl: string; apiVersion: string } {
const apiVersion =
options?.azureApiVersion ||
process.env.AZURE_OPENAI_API_VERSION ||
DEFAULT_AZURE_API_VERSION;
const baseUrl =
options?.azureBaseUrl?.trim() ||
process.env.AZURE_OPENAI_BASE_URL?.trim() ||
undefined;
const resourceName =
options?.azureResourceName || process.env.AZURE_OPENAI_RESOURCE_NAME;
let resolvedBaseUrl = baseUrl;
if (!resolvedBaseUrl && resourceName) {
resolvedBaseUrl = buildDefaultBaseUrl(resourceName);
}
if (!resolvedBaseUrl && model.baseUrl) {
resolvedBaseUrl = model.baseUrl;
}
if (!resolvedBaseUrl) {
throw new Error(
"Azure OpenAI base URL is required. Set AZURE_OPENAI_BASE_URL or AZURE_OPENAI_RESOURCE_NAME, or pass azureBaseUrl, azureResourceName, or model.baseUrl.",
);
}
return {
baseUrl: normalizeAzureBaseUrl(resolvedBaseUrl),
apiVersion,
};
}
function createClient(
model: Model<"azure-openai-responses">,
apiKey: string,
options?: AzureOpenAIResponsesOptions,
) {
if (!apiKey) {
if (!process.env.AZURE_OPENAI_API_KEY) {
throw new Error(
"Azure OpenAI API key is required. Set AZURE_OPENAI_API_KEY environment variable or pass it as an argument.",
);
}
apiKey = process.env.AZURE_OPENAI_API_KEY;
}
const headers = { ...model.headers };
if (options?.headers) {
Object.assign(headers, options.headers);
}
const { baseUrl, apiVersion } = resolveAzureConfig(model, options);
return new AzureOpenAI({
apiKey,
apiVersion,
dangerouslyAllowBrowser: true,
defaultHeaders: headers,
baseURL: baseUrl,
});
}
function buildParams(
model: Model<"azure-openai-responses">,
context: Context,
options: AzureOpenAIResponsesOptions | undefined,
deploymentName: string,
) {
const messages = convertResponsesMessages(
model,
context,
AZURE_TOOL_CALL_PROVIDERS,
);
const params: ResponseCreateParamsStreaming = {
model: deploymentName,
input: messages,
stream: true,
prompt_cache_key: options?.sessionId,
};
if (options?.maxTokens) {
params.max_output_tokens = options?.maxTokens;
}
if (options?.temperature !== undefined) {
params.temperature = options?.temperature;
}
if (context.tools) {
params.tools = convertResponsesTools(context.tools);
}
if (model.reasoning) {
if (options?.reasoningEffort || options?.reasoningSummary) {
params.reasoning = {
effort: options?.reasoningEffort || "medium",
summary: options?.reasoningSummary || "auto",
};
params.include = ["reasoning.encrypted_content"];
} else {
if (model.name.toLowerCase().startsWith("gpt-5")) {
// Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
messages.push({
role: "developer",
content: [
{
type: "input_text",
text: "# Juice: 0 !important",
},
],
});
}
}
}
return params;
}

View file

@ -0,0 +1,37 @@
import type { Message } from "../types.js";
// Copilot expects X-Initiator to indicate whether the request is user-initiated
// or agent-initiated (e.g. follow-up after assistant/tool messages).
export function inferCopilotInitiator(messages: Message[]): "user" | "agent" {
const last = messages[messages.length - 1];
return last && last.role !== "user" ? "agent" : "user";
}
// Copilot requires Copilot-Vision-Request header when sending images
export function hasCopilotVisionInput(messages: Message[]): boolean {
return messages.some((msg) => {
if (msg.role === "user" && Array.isArray(msg.content)) {
return msg.content.some((c) => c.type === "image");
}
if (msg.role === "toolResult" && Array.isArray(msg.content)) {
return msg.content.some((c) => c.type === "image");
}
return false;
});
}
export function buildCopilotDynamicHeaders(params: {
messages: Message[];
hasImages: boolean;
}): Record<string, string> {
const headers: Record<string, string> = {
"X-Initiator": inferCopilotInitiator(params.messages),
"Openai-Intent": "conversation-edits",
};
if (params.hasImages) {
headers["Copilot-Vision-Request"] = "true";
}
return headers;
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,373 @@
/**
* Shared utilities for Google Generative AI and Google Cloud Code Assist providers.
*/
import {
type Content,
FinishReason,
FunctionCallingConfigMode,
type Part,
} from "@google/genai";
import type {
Context,
ImageContent,
Model,
StopReason,
TextContent,
Tool,
} from "../types.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { transformMessages } from "./transform-messages.js";
type GoogleApiType =
| "google-generative-ai"
| "google-gemini-cli"
| "google-vertex";
/**
* Determines whether a streamed Gemini `Part` should be treated as "thinking".
*
* Protocol note (Gemini / Vertex AI thought signatures):
* - `thought: true` is the definitive marker for thinking content (thought summaries).
* - `thoughtSignature` is an encrypted representation of the model's internal thought process
* used to preserve reasoning context across multi-turn interactions.
* - `thoughtSignature` can appear on ANY part type (text, functionCall, etc.) - it does NOT
* indicate the part itself is thinking content.
* - For non-functionCall responses, the signature appears on the last part for context replay.
* - When persisting/replaying model outputs, signature-bearing parts must be preserved as-is;
* do not merge/move signatures across parts.
*
* See: https://ai.google.dev/gemini-api/docs/thought-signatures
*/
export function isThinkingPart(
part: Pick<Part, "thought" | "thoughtSignature">,
): boolean {
return part.thought === true;
}
/**
* Retain thought signatures during streaming.
*
* Some backends only send `thoughtSignature` on the first delta for a given part/block; later deltas may omit it.
* This helper preserves the last non-empty signature for the current block.
*
* Note: this does NOT merge or move signatures across distinct response parts. It only prevents
* a signature from being overwritten with `undefined` within the same streamed block.
*/
export function retainThoughtSignature(
existing: string | undefined,
incoming: string | undefined,
): string | undefined {
if (typeof incoming === "string" && incoming.length > 0) return incoming;
return existing;
}
// Thought signatures must be base64 for Google APIs (TYPE_BYTES).
const base64SignaturePattern = /^[A-Za-z0-9+/]+={0,2}$/;
// Sentinel value that tells the Gemini API to skip thought signature validation.
// Used for unsigned function call parts (e.g. replayed from providers without thought signatures).
// See: https://ai.google.dev/gemini-api/docs/thought-signatures
const SKIP_THOUGHT_SIGNATURE = "skip_thought_signature_validator";
function isValidThoughtSignature(signature: string | undefined): boolean {
if (!signature) return false;
if (signature.length % 4 !== 0) return false;
return base64SignaturePattern.test(signature);
}
/**
* Only keep signatures from the same provider/model and with valid base64.
*/
function resolveThoughtSignature(
isSameProviderAndModel: boolean,
signature: string | undefined,
): string | undefined {
return isSameProviderAndModel && isValidThoughtSignature(signature)
? signature
: undefined;
}
/**
* Models via Google APIs that require explicit tool call IDs in function calls/responses.
*/
export function requiresToolCallId(modelId: string): boolean {
return modelId.startsWith("claude-") || modelId.startsWith("gpt-oss-");
}
/**
* Convert internal messages to Gemini Content[] format.
*/
export function convertMessages<T extends GoogleApiType>(
model: Model<T>,
context: Context,
): Content[] {
const contents: Content[] = [];
const normalizeToolCallId = (id: string): string => {
if (!requiresToolCallId(model.id)) return id;
return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
};
const transformedMessages = transformMessages(
context.messages,
model,
normalizeToolCallId,
);
for (const msg of transformedMessages) {
if (msg.role === "user") {
if (typeof msg.content === "string") {
contents.push({
role: "user",
parts: [{ text: sanitizeSurrogates(msg.content) }],
});
} else {
const parts: Part[] = msg.content.map((item) => {
if (item.type === "text") {
return { text: sanitizeSurrogates(item.text) };
} else {
return {
inlineData: {
mimeType: item.mimeType,
data: item.data,
},
};
}
});
const filteredParts = !model.input.includes("image")
? parts.filter((p) => p.text !== undefined)
: parts;
if (filteredParts.length === 0) continue;
contents.push({
role: "user",
parts: filteredParts,
});
}
} else if (msg.role === "assistant") {
const parts: Part[] = [];
// Check if message is from same provider and model - only then keep thinking blocks
const isSameProviderAndModel =
msg.provider === model.provider && msg.model === model.id;
for (const block of msg.content) {
if (block.type === "text") {
// Skip empty text blocks - they can cause issues with some models (e.g. Claude via Antigravity)
if (!block.text || block.text.trim() === "") continue;
const thoughtSignature = resolveThoughtSignature(
isSameProviderAndModel,
block.textSignature,
);
parts.push({
text: sanitizeSurrogates(block.text),
...(thoughtSignature && { thoughtSignature }),
});
} else if (block.type === "thinking") {
// Skip empty thinking blocks
if (!block.thinking || block.thinking.trim() === "") continue;
// Only keep as thinking block if same provider AND same model
// Otherwise convert to plain text (no tags to avoid model mimicking them)
if (isSameProviderAndModel) {
const thoughtSignature = resolveThoughtSignature(
isSameProviderAndModel,
block.thinkingSignature,
);
parts.push({
thought: true,
text: sanitizeSurrogates(block.thinking),
...(thoughtSignature && { thoughtSignature }),
});
} else {
parts.push({
text: sanitizeSurrogates(block.thinking),
});
}
} else if (block.type === "toolCall") {
const thoughtSignature = resolveThoughtSignature(
isSameProviderAndModel,
block.thoughtSignature,
);
// Gemini 3 requires thoughtSignature on all function calls when thinking mode is enabled.
// Use the skip_thought_signature_validator sentinel for unsigned function calls
// (e.g. replayed from providers without thought signatures like Claude via Antigravity).
const isGemini3 = model.id.toLowerCase().includes("gemini-3");
const effectiveSignature =
thoughtSignature ||
(isGemini3 ? SKIP_THOUGHT_SIGNATURE : undefined);
const part: Part = {
functionCall: {
name: block.name,
args: block.arguments ?? {},
...(requiresToolCallId(model.id) ? { id: block.id } : {}),
},
...(effectiveSignature && { thoughtSignature: effectiveSignature }),
};
parts.push(part);
}
}
if (parts.length === 0) continue;
contents.push({
role: "model",
parts,
});
} else if (msg.role === "toolResult") {
// Extract text and image content
const textContent = msg.content.filter(
(c): c is TextContent => c.type === "text",
);
const textResult = textContent.map((c) => c.text).join("\n");
const imageContent = model.input.includes("image")
? msg.content.filter((c): c is ImageContent => c.type === "image")
: [];
const hasText = textResult.length > 0;
const hasImages = imageContent.length > 0;
// Gemini 3 supports multimodal function responses with images nested inside functionResponse.parts
// See: https://ai.google.dev/gemini-api/docs/function-calling#multimodal
// Older models don't support this, so we put images in a separate user message.
const supportsMultimodalFunctionResponse = model.id.includes("gemini-3");
// Use "output" key for success, "error" key for errors as per SDK documentation
const responseValue = hasText
? sanitizeSurrogates(textResult)
: hasImages
? "(see attached image)"
: "";
const imageParts: Part[] = imageContent.map((imageBlock) => ({
inlineData: {
mimeType: imageBlock.mimeType,
data: imageBlock.data,
},
}));
const includeId = requiresToolCallId(model.id);
const functionResponsePart: Part = {
functionResponse: {
name: msg.toolName,
response: msg.isError
? { error: responseValue }
: { output: responseValue },
// Nest images inside functionResponse.parts for Gemini 3
...(hasImages &&
supportsMultimodalFunctionResponse && { parts: imageParts }),
...(includeId ? { id: msg.toolCallId } : {}),
},
};
// Cloud Code Assist API requires all function responses to be in a single user turn.
// Check if the last content is already a user turn with function responses and merge.
const lastContent = contents[contents.length - 1];
if (
lastContent?.role === "user" &&
lastContent.parts?.some((p) => p.functionResponse)
) {
lastContent.parts.push(functionResponsePart);
} else {
contents.push({
role: "user",
parts: [functionResponsePart],
});
}
// For older models, add images in a separate user message
if (hasImages && !supportsMultimodalFunctionResponse) {
contents.push({
role: "user",
parts: [{ text: "Tool result image:" }, ...imageParts],
});
}
}
}
return contents;
}
/**
* Convert tools to Gemini function declarations format.
*
* By default uses `parametersJsonSchema` which supports full JSON Schema (including
* anyOf, oneOf, const, etc.). Set `useParameters` to true to use the legacy `parameters`
* field instead (OpenAPI 3.03 Schema). This is needed for Cloud Code Assist with Claude
* models, where the API translates `parameters` into Anthropic's `input_schema`.
*/
export function convertTools(
tools: Tool[],
useParameters = false,
): { functionDeclarations: Record<string, unknown>[] }[] | undefined {
if (tools.length === 0) return undefined;
return [
{
functionDeclarations: tools.map((tool) => ({
name: tool.name,
description: tool.description,
...(useParameters
? { parameters: tool.parameters }
: { parametersJsonSchema: tool.parameters }),
})),
},
];
}
/**
* Map tool choice string to Gemini FunctionCallingConfigMode.
*/
export function mapToolChoice(choice: string): FunctionCallingConfigMode {
switch (choice) {
case "auto":
return FunctionCallingConfigMode.AUTO;
case "none":
return FunctionCallingConfigMode.NONE;
case "any":
return FunctionCallingConfigMode.ANY;
default:
return FunctionCallingConfigMode.AUTO;
}
}
/**
* Map Gemini FinishReason to our StopReason.
*/
export function mapStopReason(reason: FinishReason): StopReason {
switch (reason) {
case FinishReason.STOP:
return "stop";
case FinishReason.MAX_TOKENS:
return "length";
case FinishReason.BLOCKLIST:
case FinishReason.PROHIBITED_CONTENT:
case FinishReason.SPII:
case FinishReason.SAFETY:
case FinishReason.IMAGE_SAFETY:
case FinishReason.IMAGE_PROHIBITED_CONTENT:
case FinishReason.IMAGE_RECITATION:
case FinishReason.IMAGE_OTHER:
case FinishReason.RECITATION:
case FinishReason.FINISH_REASON_UNSPECIFIED:
case FinishReason.OTHER:
case FinishReason.LANGUAGE:
case FinishReason.MALFORMED_FUNCTION_CALL:
case FinishReason.UNEXPECTED_TOOL_CALL:
case FinishReason.NO_IMAGE:
return "error";
default: {
const _exhaustive: never = reason;
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
}
}
}
/**
* Map string finish reason to our StopReason (for raw API responses).
*/
export function mapStopReasonString(reason: string): StopReason {
switch (reason) {
case "STOP":
return "stop";
case "MAX_TOKENS":
return "length";
default:
return "error";
}
}

View file

@ -0,0 +1,529 @@
import {
type GenerateContentConfig,
type GenerateContentParameters,
GoogleGenAI,
type ThinkingConfig,
ThinkingLevel,
} from "@google/genai";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
Model,
ThinkingLevel as PiThinkingLevel,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
TextContent,
ThinkingBudgets,
ThinkingContent,
ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
import {
convertMessages,
convertTools,
isThinkingPart,
mapStopReason,
mapToolChoice,
retainThoughtSignature,
} from "./google-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
export interface GoogleVertexOptions extends StreamOptions {
toolChoice?: "auto" | "none" | "any";
thinking?: {
enabled: boolean;
budgetTokens?: number; // -1 for dynamic, 0 to disable
level?: GoogleThinkingLevel;
};
project?: string;
location?: string;
}
const API_VERSION = "v1";
const THINKING_LEVEL_MAP: Record<GoogleThinkingLevel, ThinkingLevel> = {
THINKING_LEVEL_UNSPECIFIED: ThinkingLevel.THINKING_LEVEL_UNSPECIFIED,
MINIMAL: ThinkingLevel.MINIMAL,
LOW: ThinkingLevel.LOW,
MEDIUM: ThinkingLevel.MEDIUM,
HIGH: ThinkingLevel.HIGH,
};
// Counter for generating unique tool call IDs
let toolCallCounter = 0;
export const streamGoogleVertex: StreamFunction<
"google-vertex",
GoogleVertexOptions
> = (
model: Model<"google-vertex">,
context: Context,
options?: GoogleVertexOptions,
): AssistantMessageEventStream => {
const stream = new AssistantMessageEventStream();
(async () => {
const output: AssistantMessage = {
role: "assistant",
content: [],
api: "google-vertex" as Api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
try {
const project = resolveProject(options);
const location = resolveLocation(options);
const client = createClient(model, project, location, options?.headers);
const params = buildParams(model, context, options);
options?.onPayload?.(params);
const googleStream = await client.models.generateContentStream(params);
stream.push({ type: "start", partial: output });
let currentBlock: TextContent | ThinkingContent | null = null;
const blocks = output.content;
const blockIndex = () => blocks.length - 1;
for await (const chunk of googleStream) {
const candidate = chunk.candidates?.[0];
if (candidate?.content?.parts) {
for (const part of candidate.content.parts) {
if (part.text !== undefined) {
const isThinking = isThinkingPart(part);
if (
!currentBlock ||
(isThinking && currentBlock.type !== "thinking") ||
(!isThinking && currentBlock.type !== "text")
) {
if (currentBlock) {
if (currentBlock.type === "text") {
stream.push({
type: "text_end",
contentIndex: blocks.length - 1,
content: currentBlock.text,
partial: output,
});
} else {
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: currentBlock.thinking,
partial: output,
});
}
}
if (isThinking) {
currentBlock = {
type: "thinking",
thinking: "",
thinkingSignature: undefined,
};
output.content.push(currentBlock);
stream.push({
type: "thinking_start",
contentIndex: blockIndex(),
partial: output,
});
} else {
currentBlock = { type: "text", text: "" };
output.content.push(currentBlock);
stream.push({
type: "text_start",
contentIndex: blockIndex(),
partial: output,
});
}
}
if (currentBlock.type === "thinking") {
currentBlock.thinking += part.text;
currentBlock.thinkingSignature = retainThoughtSignature(
currentBlock.thinkingSignature,
part.thoughtSignature,
);
stream.push({
type: "thinking_delta",
contentIndex: blockIndex(),
delta: part.text,
partial: output,
});
} else {
currentBlock.text += part.text;
currentBlock.textSignature = retainThoughtSignature(
currentBlock.textSignature,
part.thoughtSignature,
);
stream.push({
type: "text_delta",
contentIndex: blockIndex(),
delta: part.text,
partial: output,
});
}
}
if (part.functionCall) {
if (currentBlock) {
if (currentBlock.type === "text") {
stream.push({
type: "text_end",
contentIndex: blockIndex(),
content: currentBlock.text,
partial: output,
});
} else {
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: currentBlock.thinking,
partial: output,
});
}
currentBlock = null;
}
const providedId = part.functionCall.id;
const needsNewId =
!providedId ||
output.content.some(
(b) => b.type === "toolCall" && b.id === providedId,
);
const toolCallId = needsNewId
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
: providedId;
const toolCall: ToolCall = {
type: "toolCall",
id: toolCallId,
name: part.functionCall.name || "",
arguments:
(part.functionCall.args as Record<string, any>) ?? {},
...(part.thoughtSignature && {
thoughtSignature: part.thoughtSignature,
}),
};
output.content.push(toolCall);
stream.push({
type: "toolcall_start",
contentIndex: blockIndex(),
partial: output,
});
stream.push({
type: "toolcall_delta",
contentIndex: blockIndex(),
delta: JSON.stringify(toolCall.arguments),
partial: output,
});
stream.push({
type: "toolcall_end",
contentIndex: blockIndex(),
toolCall,
partial: output,
});
}
}
}
if (candidate?.finishReason) {
output.stopReason = mapStopReason(candidate.finishReason);
if (output.content.some((b) => b.type === "toolCall")) {
output.stopReason = "toolUse";
}
}
if (chunk.usageMetadata) {
output.usage = {
input: chunk.usageMetadata.promptTokenCount || 0,
output:
(chunk.usageMetadata.candidatesTokenCount || 0) +
(chunk.usageMetadata.thoughtsTokenCount || 0),
cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
cacheWrite: 0,
totalTokens: chunk.usageMetadata.totalTokenCount || 0,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
total: 0,
},
};
calculateCost(model, output.usage);
}
}
if (currentBlock) {
if (currentBlock.type === "text") {
stream.push({
type: "text_end",
contentIndex: blockIndex(),
content: currentBlock.text,
partial: output,
});
} else {
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: currentBlock.thinking,
partial: output,
});
}
}
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
if (output.stopReason === "aborted" || output.stopReason === "error") {
throw new Error("An unknown error occurred");
}
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
// Remove internal index property used during streaming
for (const block of output.content) {
if ("index" in block) {
delete (block as { index?: number }).index;
}
}
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
output.errorMessage =
error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", reason: output.stopReason, error: output });
stream.end();
}
})();
return stream;
};
export const streamSimpleGoogleVertex: StreamFunction<
"google-vertex",
SimpleStreamOptions
> = (
model: Model<"google-vertex">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
const base = buildBaseOptions(model, options, undefined);
if (!options?.reasoning) {
return streamGoogleVertex(model, context, {
...base,
thinking: { enabled: false },
} satisfies GoogleVertexOptions);
}
const effort = clampReasoning(options.reasoning)!;
const geminiModel = model as unknown as Model<"google-generative-ai">;
if (isGemini3ProModel(geminiModel) || isGemini3FlashModel(geminiModel)) {
return streamGoogleVertex(model, context, {
...base,
thinking: {
enabled: true,
level: getGemini3ThinkingLevel(effort, geminiModel),
},
} satisfies GoogleVertexOptions);
}
return streamGoogleVertex(model, context, {
...base,
thinking: {
enabled: true,
budgetTokens: getGoogleBudget(
geminiModel,
effort,
options.thinkingBudgets,
),
},
} satisfies GoogleVertexOptions);
};
function createClient(
model: Model<"google-vertex">,
project: string,
location: string,
optionsHeaders?: Record<string, string>,
): GoogleGenAI {
const httpOptions: { headers?: Record<string, string> } = {};
if (model.headers || optionsHeaders) {
httpOptions.headers = { ...model.headers, ...optionsHeaders };
}
const hasHttpOptions = Object.values(httpOptions).some(Boolean);
return new GoogleGenAI({
vertexai: true,
project,
location,
apiVersion: API_VERSION,
httpOptions: hasHttpOptions ? httpOptions : undefined,
});
}
function resolveProject(options?: GoogleVertexOptions): string {
const project =
options?.project ||
process.env.GOOGLE_CLOUD_PROJECT ||
process.env.GCLOUD_PROJECT;
if (!project) {
throw new Error(
"Vertex AI requires a project ID. Set GOOGLE_CLOUD_PROJECT/GCLOUD_PROJECT or pass project in options.",
);
}
return project;
}
function resolveLocation(options?: GoogleVertexOptions): string {
const location = options?.location || process.env.GOOGLE_CLOUD_LOCATION;
if (!location) {
throw new Error(
"Vertex AI requires a location. Set GOOGLE_CLOUD_LOCATION or pass location in options.",
);
}
return location;
}
function buildParams(
model: Model<"google-vertex">,
context: Context,
options: GoogleVertexOptions = {},
): GenerateContentParameters {
const contents = convertMessages(model, context);
const generationConfig: GenerateContentConfig = {};
if (options.temperature !== undefined) {
generationConfig.temperature = options.temperature;
}
if (options.maxTokens !== undefined) {
generationConfig.maxOutputTokens = options.maxTokens;
}
const config: GenerateContentConfig = {
...(Object.keys(generationConfig).length > 0 && generationConfig),
...(context.systemPrompt && {
systemInstruction: sanitizeSurrogates(context.systemPrompt),
}),
...(context.tools &&
context.tools.length > 0 && { tools: convertTools(context.tools) }),
};
if (context.tools && context.tools.length > 0 && options.toolChoice) {
config.toolConfig = {
functionCallingConfig: {
mode: mapToolChoice(options.toolChoice),
},
};
} else {
config.toolConfig = undefined;
}
if (options.thinking?.enabled && model.reasoning) {
const thinkingConfig: ThinkingConfig = { includeThoughts: true };
if (options.thinking.level !== undefined) {
thinkingConfig.thinkingLevel = THINKING_LEVEL_MAP[options.thinking.level];
} else if (options.thinking.budgetTokens !== undefined) {
thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
}
config.thinkingConfig = thinkingConfig;
}
if (options.signal) {
if (options.signal.aborted) {
throw new Error("Request aborted");
}
config.abortSignal = options.signal;
}
const params: GenerateContentParameters = {
model: model.id,
contents,
config,
};
return params;
}
type ClampedThinkingLevel = Exclude<PiThinkingLevel, "xhigh">;
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
return /gemini-3(?:\.\d+)?-pro/.test(model.id.toLowerCase());
}
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
return /gemini-3(?:\.\d+)?-flash/.test(model.id.toLowerCase());
}
function getGemini3ThinkingLevel(
effort: ClampedThinkingLevel,
model: Model<"google-generative-ai">,
): GoogleThinkingLevel {
if (isGemini3ProModel(model)) {
switch (effort) {
case "minimal":
case "low":
return "LOW";
case "medium":
case "high":
return "HIGH";
}
}
switch (effort) {
case "minimal":
return "MINIMAL";
case "low":
return "LOW";
case "medium":
return "MEDIUM";
case "high":
return "HIGH";
}
}
function getGoogleBudget(
model: Model<"google-generative-ai">,
effort: ClampedThinkingLevel,
customBudgets?: ThinkingBudgets,
): number {
if (customBudgets?.[effort] !== undefined) {
return customBudgets[effort]!;
}
if (model.id.includes("2.5-pro")) {
const budgets: Record<ClampedThinkingLevel, number> = {
minimal: 128,
low: 2048,
medium: 8192,
high: 32768,
};
return budgets[effort];
}
if (model.id.includes("2.5-flash")) {
const budgets: Record<ClampedThinkingLevel, number> = {
minimal: 128,
low: 2048,
medium: 8192,
high: 24576,
};
return budgets[effort];
}
return -1;
}

View file

@ -0,0 +1,501 @@
import {
type GenerateContentConfig,
type GenerateContentParameters,
GoogleGenAI,
type ThinkingConfig,
} from "@google/genai";
import { getEnvApiKey } from "../env-api-keys.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
TextContent,
ThinkingBudgets,
ThinkingContent,
ThinkingLevel,
ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
import {
convertMessages,
convertTools,
isThinkingPart,
mapStopReason,
mapToolChoice,
retainThoughtSignature,
} from "./google-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
export interface GoogleOptions extends StreamOptions {
toolChoice?: "auto" | "none" | "any";
thinking?: {
enabled: boolean;
budgetTokens?: number; // -1 for dynamic, 0 to disable
level?: GoogleThinkingLevel;
};
}
// Counter for generating unique tool call IDs
let toolCallCounter = 0;
export const streamGoogle: StreamFunction<
"google-generative-ai",
GoogleOptions
> = (
model: Model<"google-generative-ai">,
context: Context,
options?: GoogleOptions,
): AssistantMessageEventStream => {
const stream = new AssistantMessageEventStream();
(async () => {
const output: AssistantMessage = {
role: "assistant",
content: [],
api: "google-generative-ai" as Api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
try {
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
const client = createClient(model, apiKey, options?.headers);
const params = buildParams(model, context, options);
options?.onPayload?.(params);
const googleStream = await client.models.generateContentStream(params);
stream.push({ type: "start", partial: output });
let currentBlock: TextContent | ThinkingContent | null = null;
const blocks = output.content;
const blockIndex = () => blocks.length - 1;
for await (const chunk of googleStream) {
const candidate = chunk.candidates?.[0];
if (candidate?.content?.parts) {
for (const part of candidate.content.parts) {
if (part.text !== undefined) {
const isThinking = isThinkingPart(part);
if (
!currentBlock ||
(isThinking && currentBlock.type !== "thinking") ||
(!isThinking && currentBlock.type !== "text")
) {
if (currentBlock) {
if (currentBlock.type === "text") {
stream.push({
type: "text_end",
contentIndex: blocks.length - 1,
content: currentBlock.text,
partial: output,
});
} else {
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: currentBlock.thinking,
partial: output,
});
}
}
if (isThinking) {
currentBlock = {
type: "thinking",
thinking: "",
thinkingSignature: undefined,
};
output.content.push(currentBlock);
stream.push({
type: "thinking_start",
contentIndex: blockIndex(),
partial: output,
});
} else {
currentBlock = { type: "text", text: "" };
output.content.push(currentBlock);
stream.push({
type: "text_start",
contentIndex: blockIndex(),
partial: output,
});
}
}
if (currentBlock.type === "thinking") {
currentBlock.thinking += part.text;
currentBlock.thinkingSignature = retainThoughtSignature(
currentBlock.thinkingSignature,
part.thoughtSignature,
);
stream.push({
type: "thinking_delta",
contentIndex: blockIndex(),
delta: part.text,
partial: output,
});
} else {
currentBlock.text += part.text;
currentBlock.textSignature = retainThoughtSignature(
currentBlock.textSignature,
part.thoughtSignature,
);
stream.push({
type: "text_delta",
contentIndex: blockIndex(),
delta: part.text,
partial: output,
});
}
}
if (part.functionCall) {
if (currentBlock) {
if (currentBlock.type === "text") {
stream.push({
type: "text_end",
contentIndex: blockIndex(),
content: currentBlock.text,
partial: output,
});
} else {
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: currentBlock.thinking,
partial: output,
});
}
currentBlock = null;
}
// Generate unique ID if not provided or if it's a duplicate
const providedId = part.functionCall.id;
const needsNewId =
!providedId ||
output.content.some(
(b) => b.type === "toolCall" && b.id === providedId,
);
const toolCallId = needsNewId
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
: providedId;
const toolCall: ToolCall = {
type: "toolCall",
id: toolCallId,
name: part.functionCall.name || "",
arguments:
(part.functionCall.args as Record<string, any>) ?? {},
...(part.thoughtSignature && {
thoughtSignature: part.thoughtSignature,
}),
};
output.content.push(toolCall);
stream.push({
type: "toolcall_start",
contentIndex: blockIndex(),
partial: output,
});
stream.push({
type: "toolcall_delta",
contentIndex: blockIndex(),
delta: JSON.stringify(toolCall.arguments),
partial: output,
});
stream.push({
type: "toolcall_end",
contentIndex: blockIndex(),
toolCall,
partial: output,
});
}
}
}
if (candidate?.finishReason) {
output.stopReason = mapStopReason(candidate.finishReason);
if (output.content.some((b) => b.type === "toolCall")) {
output.stopReason = "toolUse";
}
}
if (chunk.usageMetadata) {
output.usage = {
input: chunk.usageMetadata.promptTokenCount || 0,
output:
(chunk.usageMetadata.candidatesTokenCount || 0) +
(chunk.usageMetadata.thoughtsTokenCount || 0),
cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
cacheWrite: 0,
totalTokens: chunk.usageMetadata.totalTokenCount || 0,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
total: 0,
},
};
calculateCost(model, output.usage);
}
}
if (currentBlock) {
if (currentBlock.type === "text") {
stream.push({
type: "text_end",
contentIndex: blockIndex(),
content: currentBlock.text,
partial: output,
});
} else {
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: currentBlock.thinking,
partial: output,
});
}
}
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
if (output.stopReason === "aborted" || output.stopReason === "error") {
throw new Error("An unknown error occurred");
}
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
// Remove internal index property used during streaming
for (const block of output.content) {
if ("index" in block) {
delete (block as { index?: number }).index;
}
}
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
output.errorMessage =
error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", reason: output.stopReason, error: output });
stream.end();
}
})();
return stream;
};
export const streamSimpleGoogle: StreamFunction<
"google-generative-ai",
SimpleStreamOptions
> = (
model: Model<"google-generative-ai">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
}
const base = buildBaseOptions(model, options, apiKey);
if (!options?.reasoning) {
return streamGoogle(model, context, {
...base,
thinking: { enabled: false },
} satisfies GoogleOptions);
}
const effort = clampReasoning(options.reasoning)!;
const googleModel = model as Model<"google-generative-ai">;
if (isGemini3ProModel(googleModel) || isGemini3FlashModel(googleModel)) {
return streamGoogle(model, context, {
...base,
thinking: {
enabled: true,
level: getGemini3ThinkingLevel(effort, googleModel),
},
} satisfies GoogleOptions);
}
return streamGoogle(model, context, {
...base,
thinking: {
enabled: true,
budgetTokens: getGoogleBudget(
googleModel,
effort,
options.thinkingBudgets,
),
},
} satisfies GoogleOptions);
};
function createClient(
model: Model<"google-generative-ai">,
apiKey?: string,
optionsHeaders?: Record<string, string>,
): GoogleGenAI {
const httpOptions: {
baseUrl?: string;
apiVersion?: string;
headers?: Record<string, string>;
} = {};
if (model.baseUrl) {
httpOptions.baseUrl = model.baseUrl;
httpOptions.apiVersion = ""; // baseUrl already includes version path, don't append
}
if (model.headers || optionsHeaders) {
httpOptions.headers = { ...model.headers, ...optionsHeaders };
}
return new GoogleGenAI({
apiKey,
httpOptions: Object.keys(httpOptions).length > 0 ? httpOptions : undefined,
});
}
function buildParams(
model: Model<"google-generative-ai">,
context: Context,
options: GoogleOptions = {},
): GenerateContentParameters {
const contents = convertMessages(model, context);
const generationConfig: GenerateContentConfig = {};
if (options.temperature !== undefined) {
generationConfig.temperature = options.temperature;
}
if (options.maxTokens !== undefined) {
generationConfig.maxOutputTokens = options.maxTokens;
}
const config: GenerateContentConfig = {
...(Object.keys(generationConfig).length > 0 && generationConfig),
...(context.systemPrompt && {
systemInstruction: sanitizeSurrogates(context.systemPrompt),
}),
...(context.tools &&
context.tools.length > 0 && { tools: convertTools(context.tools) }),
};
if (context.tools && context.tools.length > 0 && options.toolChoice) {
config.toolConfig = {
functionCallingConfig: {
mode: mapToolChoice(options.toolChoice),
},
};
} else {
config.toolConfig = undefined;
}
if (options.thinking?.enabled && model.reasoning) {
const thinkingConfig: ThinkingConfig = { includeThoughts: true };
if (options.thinking.level !== undefined) {
// Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
thinkingConfig.thinkingLevel = options.thinking.level as any;
} else if (options.thinking.budgetTokens !== undefined) {
thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
}
config.thinkingConfig = thinkingConfig;
}
if (options.signal) {
if (options.signal.aborted) {
throw new Error("Request aborted");
}
config.abortSignal = options.signal;
}
const params: GenerateContentParameters = {
model: model.id,
contents,
config,
};
return params;
}
type ClampedThinkingLevel = Exclude<ThinkingLevel, "xhigh">;
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
return /gemini-3(?:\.\d+)?-pro/.test(model.id.toLowerCase());
}
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
return /gemini-3(?:\.\d+)?-flash/.test(model.id.toLowerCase());
}
function getGemini3ThinkingLevel(
effort: ClampedThinkingLevel,
model: Model<"google-generative-ai">,
): GoogleThinkingLevel {
if (isGemini3ProModel(model)) {
switch (effort) {
case "minimal":
case "low":
return "LOW";
case "medium":
case "high":
return "HIGH";
}
}
switch (effort) {
case "minimal":
return "MINIMAL";
case "low":
return "LOW";
case "medium":
return "MEDIUM";
case "high":
return "HIGH";
}
}
function getGoogleBudget(
model: Model<"google-generative-ai">,
effort: ClampedThinkingLevel,
customBudgets?: ThinkingBudgets,
): number {
if (customBudgets?.[effort] !== undefined) {
return customBudgets[effort]!;
}
if (model.id.includes("2.5-pro")) {
const budgets: Record<ClampedThinkingLevel, number> = {
minimal: 128,
low: 2048,
medium: 8192,
high: 32768,
};
return budgets[effort];
}
if (model.id.includes("2.5-flash")) {
const budgets: Record<ClampedThinkingLevel, number> = {
minimal: 128,
low: 2048,
medium: 8192,
high: 24576,
};
return budgets[effort];
}
return -1;
}

View file

@ -0,0 +1,688 @@
import { Mistral } from "@mistralai/mistralai";
import type { RequestOptions } from "@mistralai/mistralai/lib/sdks.js";
import type {
ChatCompletionStreamRequest,
ChatCompletionStreamRequestMessages,
CompletionEvent,
ContentChunk,
FunctionTool,
} from "@mistralai/mistralai/models/components/index.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { calculateCost } from "../models.js";
import type {
AssistantMessage,
Context,
Message,
Model,
SimpleStreamOptions,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { shortHash } from "../utils/hash.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
import { transformMessages } from "./transform-messages.js";
const MISTRAL_TOOL_CALL_ID_LENGTH = 9;
const MAX_MISTRAL_ERROR_BODY_CHARS = 4000;
/**
* Provider-specific options for the Mistral API.
*/
export interface MistralOptions extends StreamOptions {
toolChoice?:
| "auto"
| "none"
| "any"
| "required"
| { type: "function"; function: { name: string } };
promptMode?: "reasoning";
}
/**
* Stream responses from Mistral using `chat.stream`.
*/
export const streamMistral: StreamFunction<
"mistral-conversations",
MistralOptions
> = (
model: Model<"mistral-conversations">,
context: Context,
options?: MistralOptions,
): AssistantMessageEventStream => {
const stream = new AssistantMessageEventStream();
(async () => {
const output = createOutput(model);
try {
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
}
// Intentionally per-request: avoids shared SDK mutable state across concurrent consumers.
const mistral = new Mistral({
apiKey,
serverURL: model.baseUrl,
});
const normalizeMistralToolCallId = createMistralToolCallIdNormalizer();
const transformedMessages = transformMessages(
context.messages,
model,
(id) => normalizeMistralToolCallId(id),
);
const payload = buildChatPayload(
model,
context,
transformedMessages,
options,
);
options?.onPayload?.(payload);
const mistralStream = await mistral.chat.stream(
payload,
buildRequestOptions(model, options),
);
stream.push({ type: "start", partial: output });
await consumeChatStream(model, output, stream, mistralStream);
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
if (output.stopReason === "aborted" || output.stopReason === "error") {
throw new Error("An unknown error occurred");
}
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
output.errorMessage = formatMistralError(error);
stream.push({ type: "error", reason: output.stopReason, error: output });
stream.end();
}
})();
return stream;
};
/**
* Maps provider-agnostic `SimpleStreamOptions` to Mistral options.
*/
export const streamSimpleMistral: StreamFunction<
"mistral-conversations",
SimpleStreamOptions
> = (
model: Model<"mistral-conversations">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
}
const base = buildBaseOptions(model, options, apiKey);
const reasoning = clampReasoning(options?.reasoning);
return streamMistral(model, context, {
...base,
promptMode: model.reasoning && reasoning ? "reasoning" : undefined,
} satisfies MistralOptions);
};
function createOutput(model: Model<"mistral-conversations">): AssistantMessage {
return {
role: "assistant",
content: [],
api: model.api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
}
function createMistralToolCallIdNormalizer(): (id: string) => string {
const idMap = new Map<string, string>();
const reverseMap = new Map<string, string>();
return (id: string): string => {
const existing = idMap.get(id);
if (existing) return existing;
let attempt = 0;
while (true) {
const candidate = deriveMistralToolCallId(id, attempt);
const owner = reverseMap.get(candidate);
if (!owner || owner === id) {
idMap.set(id, candidate);
reverseMap.set(candidate, id);
return candidate;
}
attempt++;
}
};
}
function deriveMistralToolCallId(id: string, attempt: number): string {
const normalized = id.replace(/[^a-zA-Z0-9]/g, "");
if (attempt === 0 && normalized.length === MISTRAL_TOOL_CALL_ID_LENGTH)
return normalized;
const seedBase = normalized || id;
const seed = attempt === 0 ? seedBase : `${seedBase}:${attempt}`;
return shortHash(seed)
.replace(/[^a-zA-Z0-9]/g, "")
.slice(0, MISTRAL_TOOL_CALL_ID_LENGTH);
}
function formatMistralError(error: unknown): string {
if (error instanceof Error) {
const sdkError = error as Error & { statusCode?: unknown; body?: unknown };
const statusCode =
typeof sdkError.statusCode === "number" ? sdkError.statusCode : undefined;
const bodyText =
typeof sdkError.body === "string" ? sdkError.body.trim() : undefined;
if (statusCode !== undefined && bodyText) {
return `Mistral API error (${statusCode}): ${truncateErrorText(bodyText, MAX_MISTRAL_ERROR_BODY_CHARS)}`;
}
if (statusCode !== undefined)
return `Mistral API error (${statusCode}): ${error.message}`;
return error.message;
}
return safeJsonStringify(error);
}
function truncateErrorText(text: string, maxChars: number): string {
if (text.length <= maxChars) return text;
return `${text.slice(0, maxChars)}... [truncated ${text.length - maxChars} chars]`;
}
function safeJsonStringify(value: unknown): string {
try {
const serialized = JSON.stringify(value);
return serialized === undefined ? String(value) : serialized;
} catch {
return String(value);
}
}
function buildRequestOptions(
model: Model<"mistral-conversations">,
options?: MistralOptions,
): RequestOptions {
const requestOptions: RequestOptions = {};
if (options?.signal) requestOptions.signal = options.signal;
requestOptions.retries = { strategy: "none" };
const headers: Record<string, string> = {};
if (model.headers) Object.assign(headers, model.headers);
if (options?.headers) Object.assign(headers, options.headers);
// Mistral infrastructure uses `x-affinity` for KV-cache reuse (prefix caching).
// Respect explicit caller-provided header values.
if (options?.sessionId && !headers["x-affinity"]) {
headers["x-affinity"] = options.sessionId;
}
if (Object.keys(headers).length > 0) {
requestOptions.headers = headers;
}
return requestOptions;
}
function buildChatPayload(
model: Model<"mistral-conversations">,
context: Context,
messages: Message[],
options?: MistralOptions,
): ChatCompletionStreamRequest {
const payload: ChatCompletionStreamRequest = {
model: model.id,
stream: true,
messages: toChatMessages(messages, model.input.includes("image")),
};
if (context.tools?.length) payload.tools = toFunctionTools(context.tools);
if (options?.temperature !== undefined)
payload.temperature = options.temperature;
if (options?.maxTokens !== undefined) payload.maxTokens = options.maxTokens;
if (options?.toolChoice)
payload.toolChoice = mapToolChoice(options.toolChoice);
if (options?.promptMode) payload.promptMode = options.promptMode as any;
if (context.systemPrompt) {
payload.messages.unshift({
role: "system",
content: sanitizeSurrogates(context.systemPrompt),
});
}
return payload;
}
async function consumeChatStream(
model: Model<"mistral-conversations">,
output: AssistantMessage,
stream: AssistantMessageEventStream,
mistralStream: AsyncIterable<CompletionEvent>,
): Promise<void> {
let currentBlock: TextContent | ThinkingContent | null = null;
const blocks = output.content;
const blockIndex = () => blocks.length - 1;
const toolBlocksByKey = new Map<string, number>();
const finishCurrentBlock = (block?: typeof currentBlock) => {
if (!block) return;
if (block.type === "text") {
stream.push({
type: "text_end",
contentIndex: blockIndex(),
content: block.text,
partial: output,
});
return;
}
if (block.type === "thinking") {
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: block.thinking,
partial: output,
});
}
};
for await (const event of mistralStream) {
const chunk = event.data;
if (chunk.usage) {
output.usage.input = chunk.usage.promptTokens || 0;
output.usage.output = chunk.usage.completionTokens || 0;
output.usage.cacheRead = 0;
output.usage.cacheWrite = 0;
output.usage.totalTokens =
chunk.usage.totalTokens || output.usage.input + output.usage.output;
calculateCost(model, output.usage);
}
const choice = chunk.choices[0];
if (!choice) continue;
if (choice.finishReason) {
output.stopReason = mapChatStopReason(choice.finishReason);
}
const delta = choice.delta;
if (delta.content !== null && delta.content !== undefined) {
const contentItems =
typeof delta.content === "string" ? [delta.content] : delta.content;
for (const item of contentItems) {
if (typeof item === "string") {
const textDelta = sanitizeSurrogates(item);
if (!currentBlock || currentBlock.type !== "text") {
finishCurrentBlock(currentBlock);
currentBlock = { type: "text", text: "" };
output.content.push(currentBlock);
stream.push({
type: "text_start",
contentIndex: blockIndex(),
partial: output,
});
}
currentBlock.text += textDelta;
stream.push({
type: "text_delta",
contentIndex: blockIndex(),
delta: textDelta,
partial: output,
});
continue;
}
if (item.type === "thinking") {
const deltaText = item.thinking
.map((part) => ("text" in part ? part.text : ""))
.filter((text) => text.length > 0)
.join("");
const thinkingDelta = sanitizeSurrogates(deltaText);
if (!thinkingDelta) continue;
if (!currentBlock || currentBlock.type !== "thinking") {
finishCurrentBlock(currentBlock);
currentBlock = { type: "thinking", thinking: "" };
output.content.push(currentBlock);
stream.push({
type: "thinking_start",
contentIndex: blockIndex(),
partial: output,
});
}
currentBlock.thinking += thinkingDelta;
stream.push({
type: "thinking_delta",
contentIndex: blockIndex(),
delta: thinkingDelta,
partial: output,
});
continue;
}
if (item.type === "text") {
const textDelta = sanitizeSurrogates(item.text);
if (!currentBlock || currentBlock.type !== "text") {
finishCurrentBlock(currentBlock);
currentBlock = { type: "text", text: "" };
output.content.push(currentBlock);
stream.push({
type: "text_start",
contentIndex: blockIndex(),
partial: output,
});
}
currentBlock.text += textDelta;
stream.push({
type: "text_delta",
contentIndex: blockIndex(),
delta: textDelta,
partial: output,
});
}
}
}
const toolCalls = delta.toolCalls || [];
for (const toolCall of toolCalls) {
if (currentBlock) {
finishCurrentBlock(currentBlock);
currentBlock = null;
}
const callId =
toolCall.id && toolCall.id !== "null"
? toolCall.id
: deriveMistralToolCallId(`toolcall:${toolCall.index ?? 0}`, 0);
const key = `${callId}:${toolCall.index || 0}`;
const existingIndex = toolBlocksByKey.get(key);
let block: (ToolCall & { partialArgs?: string }) | undefined;
if (existingIndex !== undefined) {
const existing = output.content[existingIndex];
if (existing?.type === "toolCall") {
block = existing as ToolCall & { partialArgs?: string };
}
}
if (!block) {
block = {
type: "toolCall",
id: callId,
name: toolCall.function.name,
arguments: {},
partialArgs: "",
};
output.content.push(block);
toolBlocksByKey.set(key, output.content.length - 1);
stream.push({
type: "toolcall_start",
contentIndex: output.content.length - 1,
partial: output,
});
}
const argsDelta =
typeof toolCall.function.arguments === "string"
? toolCall.function.arguments
: JSON.stringify(toolCall.function.arguments || {});
block.partialArgs = (block.partialArgs || "") + argsDelta;
block.arguments = parseStreamingJson<Record<string, unknown>>(
block.partialArgs,
);
stream.push({
type: "toolcall_delta",
contentIndex: toolBlocksByKey.get(key)!,
delta: argsDelta,
partial: output,
});
}
}
finishCurrentBlock(currentBlock);
for (const index of toolBlocksByKey.values()) {
const block = output.content[index];
if (block.type !== "toolCall") continue;
const toolBlock = block as ToolCall & { partialArgs?: string };
toolBlock.arguments = parseStreamingJson<Record<string, unknown>>(
toolBlock.partialArgs,
);
delete toolBlock.partialArgs;
stream.push({
type: "toolcall_end",
contentIndex: index,
toolCall: toolBlock,
partial: output,
});
}
}
function toFunctionTools(
tools: Tool[],
): Array<FunctionTool & { type: "function" }> {
return tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.parameters as unknown as Record<string, unknown>,
strict: false,
},
}));
}
function toChatMessages(
messages: Message[],
supportsImages: boolean,
): ChatCompletionStreamRequestMessages[] {
const result: ChatCompletionStreamRequestMessages[] = [];
for (const msg of messages) {
if (msg.role === "user") {
if (typeof msg.content === "string") {
result.push({ role: "user", content: sanitizeSurrogates(msg.content) });
continue;
}
const hadImages = msg.content.some((item) => item.type === "image");
const content: ContentChunk[] = msg.content
.filter((item) => item.type === "text" || supportsImages)
.map((item) => {
if (item.type === "text")
return { type: "text", text: sanitizeSurrogates(item.text) };
return {
type: "image_url",
imageUrl: `data:${item.mimeType};base64,${item.data}`,
};
});
if (content.length > 0) {
result.push({ role: "user", content });
continue;
}
if (hadImages && !supportsImages) {
result.push({
role: "user",
content: "(image omitted: model does not support images)",
});
}
continue;
}
if (msg.role === "assistant") {
const contentParts: ContentChunk[] = [];
const toolCalls: Array<{
id: string;
type: "function";
function: { name: string; arguments: string };
}> = [];
for (const block of msg.content) {
if (block.type === "text") {
if (block.text.trim().length > 0) {
contentParts.push({
type: "text",
text: sanitizeSurrogates(block.text),
});
}
continue;
}
if (block.type === "thinking") {
if (block.thinking.trim().length > 0) {
contentParts.push({
type: "thinking",
thinking: [
{ type: "text", text: sanitizeSurrogates(block.thinking) },
],
});
}
continue;
}
toolCalls.push({
id: block.id,
type: "function",
function: {
name: block.name,
arguments: JSON.stringify(block.arguments || {}),
},
});
}
const assistantMessage: ChatCompletionStreamRequestMessages = {
role: "assistant",
};
if (contentParts.length > 0) assistantMessage.content = contentParts;
if (toolCalls.length > 0) assistantMessage.toolCalls = toolCalls;
if (contentParts.length > 0 || toolCalls.length > 0)
result.push(assistantMessage);
continue;
}
const toolContent: ContentChunk[] = [];
const textResult = msg.content
.filter((part) => part.type === "text")
.map((part) =>
part.type === "text" ? sanitizeSurrogates(part.text) : "",
)
.join("\n");
const hasImages = msg.content.some((part) => part.type === "image");
const toolText = buildToolResultText(
textResult,
hasImages,
supportsImages,
msg.isError,
);
toolContent.push({ type: "text", text: toolText });
for (const part of msg.content) {
if (!supportsImages) continue;
if (part.type !== "image") continue;
toolContent.push({
type: "image_url",
imageUrl: `data:${part.mimeType};base64,${part.data}`,
});
}
result.push({
role: "tool",
toolCallId: msg.toolCallId,
name: msg.toolName,
content: toolContent,
});
}
return result;
}
function buildToolResultText(
text: string,
hasImages: boolean,
supportsImages: boolean,
isError: boolean,
): string {
const trimmed = text.trim();
const errorPrefix = isError ? "[tool error] " : "";
if (trimmed.length > 0) {
const imageSuffix =
hasImages && !supportsImages
? "\n[tool image omitted: model does not support images]"
: "";
return `${errorPrefix}${trimmed}${imageSuffix}`;
}
if (hasImages) {
if (supportsImages) {
return isError
? "[tool error] (see attached image)"
: "(see attached image)";
}
return isError
? "[tool error] (image omitted: model does not support images)"
: "(image omitted: model does not support images)";
}
return isError ? "[tool error] (no tool output)" : "(no tool output)";
}
function mapToolChoice(
choice: MistralOptions["toolChoice"],
):
| "auto"
| "none"
| "any"
| "required"
| { type: "function"; function: { name: string } }
| undefined {
if (!choice) return undefined;
if (
choice === "auto" ||
choice === "none" ||
choice === "any" ||
choice === "required"
) {
return choice as any;
}
return {
type: "function",
function: { name: choice.function.name },
};
}
function mapChatStopReason(reason: string | null): StopReason {
if (reason === null) return "stop";
switch (reason) {
case "stop":
return "stop";
case "length":
case "model_length":
return "length";
case "tool_calls":
return "toolUse";
case "error":
return "error";
default:
return "stop";
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,949 @@
import OpenAI from "openai";
import type {
ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
ChatCompletionContentPart,
ChatCompletionContentPartImage,
ChatCompletionContentPartText,
ChatCompletionMessageParam,
ChatCompletionToolMessageParam,
} from "openai/resources/chat/completions.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { calculateCost, supportsXhigh } from "../models.js";
import type {
AssistantMessage,
Context,
Message,
Model,
OpenAICompletionsCompat,
SimpleStreamOptions,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
ToolResultMessage,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import {
buildCopilotDynamicHeaders,
hasCopilotVisionInput,
} from "./github-copilot-headers.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
import { transformMessages } from "./transform-messages.js";
/**
* Check if conversation messages contain tool calls or tool results.
* This is needed because Anthropic (via proxy) requires the tools param
* to be present when messages include tool_calls or tool role messages.
*/
function hasToolHistory(messages: Message[]): boolean {
for (const msg of messages) {
if (msg.role === "toolResult") {
return true;
}
if (msg.role === "assistant") {
if (msg.content.some((block) => block.type === "toolCall")) {
return true;
}
}
}
return false;
}
export interface OpenAICompletionsOptions extends StreamOptions {
toolChoice?:
| "auto"
| "none"
| "required"
| { type: "function"; function: { name: string } };
reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
}
export const streamOpenAICompletions: StreamFunction<
"openai-completions",
OpenAICompletionsOptions
> = (
model: Model<"openai-completions">,
context: Context,
options?: OpenAICompletionsOptions,
): AssistantMessageEventStream => {
const stream = new AssistantMessageEventStream();
(async () => {
const output: AssistantMessage = {
role: "assistant",
content: [],
api: model.api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
try {
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
const client = createClient(model, context, apiKey, options?.headers);
const params = buildParams(model, context, options);
options?.onPayload?.(params);
const openaiStream = await client.chat.completions.create(params, {
signal: options?.signal,
});
stream.push({ type: "start", partial: output });
let currentBlock:
| TextContent
| ThinkingContent
| (ToolCall & { partialArgs?: string })
| null = null;
const blocks = output.content;
const blockIndex = () => blocks.length - 1;
const finishCurrentBlock = (block?: typeof currentBlock) => {
if (block) {
if (block.type === "text") {
stream.push({
type: "text_end",
contentIndex: blockIndex(),
content: block.text,
partial: output,
});
} else if (block.type === "thinking") {
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: block.thinking,
partial: output,
});
} else if (block.type === "toolCall") {
block.arguments = parseStreamingJson(block.partialArgs);
delete block.partialArgs;
stream.push({
type: "toolcall_end",
contentIndex: blockIndex(),
toolCall: block,
partial: output,
});
}
}
};
for await (const chunk of openaiStream) {
if (chunk.usage) {
const cachedTokens =
chunk.usage.prompt_tokens_details?.cached_tokens || 0;
const reasoningTokens =
chunk.usage.completion_tokens_details?.reasoning_tokens || 0;
const input = (chunk.usage.prompt_tokens || 0) - cachedTokens;
const outputTokens =
(chunk.usage.completion_tokens || 0) + reasoningTokens;
output.usage = {
// OpenAI includes cached tokens in prompt_tokens, so subtract to get non-cached input
input,
output: outputTokens,
cacheRead: cachedTokens,
cacheWrite: 0,
// Compute totalTokens ourselves since we add reasoning_tokens to output
// and some providers (e.g., Groq) don't include them in total_tokens
totalTokens: input + outputTokens + cachedTokens,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
total: 0,
},
};
calculateCost(model, output.usage);
}
const choice = chunk.choices?.[0];
if (!choice) continue;
if (choice.finish_reason) {
output.stopReason = mapStopReason(choice.finish_reason);
}
if (choice.delta) {
if (
choice.delta.content !== null &&
choice.delta.content !== undefined &&
choice.delta.content.length > 0
) {
if (!currentBlock || currentBlock.type !== "text") {
finishCurrentBlock(currentBlock);
currentBlock = { type: "text", text: "" };
output.content.push(currentBlock);
stream.push({
type: "text_start",
contentIndex: blockIndex(),
partial: output,
});
}
if (currentBlock.type === "text") {
currentBlock.text += choice.delta.content;
stream.push({
type: "text_delta",
contentIndex: blockIndex(),
delta: choice.delta.content,
partial: output,
});
}
}
// Some endpoints return reasoning in reasoning_content (llama.cpp),
// or reasoning (other openai compatible endpoints)
// Use the first non-empty reasoning field to avoid duplication
// (e.g., chutes.ai returns both reasoning_content and reasoning with same content)
const reasoningFields = [
"reasoning_content",
"reasoning",
"reasoning_text",
];
let foundReasoningField: string | null = null;
for (const field of reasoningFields) {
if (
(choice.delta as any)[field] !== null &&
(choice.delta as any)[field] !== undefined &&
(choice.delta as any)[field].length > 0
) {
if (!foundReasoningField) {
foundReasoningField = field;
break;
}
}
}
if (foundReasoningField) {
if (!currentBlock || currentBlock.type !== "thinking") {
finishCurrentBlock(currentBlock);
currentBlock = {
type: "thinking",
thinking: "",
thinkingSignature: foundReasoningField,
};
output.content.push(currentBlock);
stream.push({
type: "thinking_start",
contentIndex: blockIndex(),
partial: output,
});
}
if (currentBlock.type === "thinking") {
const delta = (choice.delta as any)[foundReasoningField];
currentBlock.thinking += delta;
stream.push({
type: "thinking_delta",
contentIndex: blockIndex(),
delta,
partial: output,
});
}
}
if (choice?.delta?.tool_calls) {
for (const toolCall of choice.delta.tool_calls) {
if (
!currentBlock ||
currentBlock.type !== "toolCall" ||
(toolCall.id && currentBlock.id !== toolCall.id)
) {
finishCurrentBlock(currentBlock);
currentBlock = {
type: "toolCall",
id: toolCall.id || "",
name: toolCall.function?.name || "",
arguments: {},
partialArgs: "",
};
output.content.push(currentBlock);
stream.push({
type: "toolcall_start",
contentIndex: blockIndex(),
partial: output,
});
}
if (currentBlock.type === "toolCall") {
if (toolCall.id) currentBlock.id = toolCall.id;
if (toolCall.function?.name)
currentBlock.name = toolCall.function.name;
let delta = "";
if (toolCall.function?.arguments) {
delta = toolCall.function.arguments;
currentBlock.partialArgs += toolCall.function.arguments;
currentBlock.arguments = parseStreamingJson(
currentBlock.partialArgs,
);
}
stream.push({
type: "toolcall_delta",
contentIndex: blockIndex(),
delta,
partial: output,
});
}
}
}
const reasoningDetails = (choice.delta as any).reasoning_details;
if (reasoningDetails && Array.isArray(reasoningDetails)) {
for (const detail of reasoningDetails) {
if (
detail.type === "reasoning.encrypted" &&
detail.id &&
detail.data
) {
const matchingToolCall = output.content.find(
(b) => b.type === "toolCall" && b.id === detail.id,
) as ToolCall | undefined;
if (matchingToolCall) {
matchingToolCall.thoughtSignature = JSON.stringify(detail);
}
}
}
}
}
}
finishCurrentBlock(currentBlock);
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
if (output.stopReason === "aborted" || output.stopReason === "error") {
throw new Error("An unknown error occurred");
}
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
for (const block of output.content) delete (block as any).index;
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
output.errorMessage =
error instanceof Error ? error.message : JSON.stringify(error);
// Some providers via OpenRouter give additional information in this field.
const rawMetadata = (error as any)?.error?.metadata?.raw;
if (rawMetadata) output.errorMessage += `\n${rawMetadata}`;
stream.push({ type: "error", reason: output.stopReason, error: output });
stream.end();
}
})();
return stream;
};
export const streamSimpleOpenAICompletions: StreamFunction<
"openai-completions",
SimpleStreamOptions
> = (
model: Model<"openai-completions">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
}
const base = buildBaseOptions(model, options, apiKey);
const reasoningEffort = supportsXhigh(model)
? options?.reasoning
: clampReasoning(options?.reasoning);
const toolChoice = (options as OpenAICompletionsOptions | undefined)
?.toolChoice;
return streamOpenAICompletions(model, context, {
...base,
reasoningEffort,
toolChoice,
} satisfies OpenAICompletionsOptions);
};
function createClient(
model: Model<"openai-completions">,
context: Context,
apiKey?: string,
optionsHeaders?: Record<string, string>,
) {
if (!apiKey) {
if (!process.env.OPENAI_API_KEY) {
throw new Error(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
);
}
apiKey = process.env.OPENAI_API_KEY;
}
const headers = { ...model.headers };
if (model.provider === "github-copilot") {
const hasImages = hasCopilotVisionInput(context.messages);
const copilotHeaders = buildCopilotDynamicHeaders({
messages: context.messages,
hasImages,
});
Object.assign(headers, copilotHeaders);
}
// Merge options headers last so they can override defaults
if (optionsHeaders) {
Object.assign(headers, optionsHeaders);
}
return new OpenAI({
apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders: headers,
});
}
function buildParams(
model: Model<"openai-completions">,
context: Context,
options?: OpenAICompletionsOptions,
) {
const compat = getCompat(model);
const messages = convertMessages(model, context, compat);
maybeAddOpenRouterAnthropicCacheControl(model, messages);
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: model.id,
messages,
stream: true,
};
if (compat.supportsUsageInStreaming !== false) {
(params as any).stream_options = { include_usage: true };
}
if (compat.supportsStore) {
params.store = false;
}
if (options?.maxTokens) {
if (compat.maxTokensField === "max_tokens") {
(params as any).max_tokens = options.maxTokens;
} else {
params.max_completion_tokens = options.maxTokens;
}
}
if (options?.temperature !== undefined) {
params.temperature = options.temperature;
}
if (context.tools) {
params.tools = convertTools(context.tools, compat);
} else if (hasToolHistory(context.messages)) {
// Anthropic (via LiteLLM/proxy) requires tools param when conversation has tool_calls/tool_results
params.tools = [];
}
if (options?.toolChoice) {
params.tool_choice = options.toolChoice;
}
if (
(compat.thinkingFormat === "zai" || compat.thinkingFormat === "qwen") &&
model.reasoning
) {
// Both Z.ai and Qwen use enable_thinking: boolean
(params as any).enable_thinking = !!options?.reasoningEffort;
} else if (
options?.reasoningEffort &&
model.reasoning &&
compat.supportsReasoningEffort
) {
// OpenAI-style reasoning_effort
(params as any).reasoning_effort = mapReasoningEffort(
options.reasoningEffort,
compat.reasoningEffortMap,
);
}
// OpenRouter provider routing preferences
if (
model.baseUrl.includes("openrouter.ai") &&
model.compat?.openRouterRouting
) {
(params as any).provider = model.compat.openRouterRouting;
}
// Vercel AI Gateway provider routing preferences
if (
model.baseUrl.includes("ai-gateway.vercel.sh") &&
model.compat?.vercelGatewayRouting
) {
const routing = model.compat.vercelGatewayRouting;
if (routing.only || routing.order) {
const gatewayOptions: Record<string, string[]> = {};
if (routing.only) gatewayOptions.only = routing.only;
if (routing.order) gatewayOptions.order = routing.order;
(params as any).providerOptions = { gateway: gatewayOptions };
}
}
return params;
}
function mapReasoningEffort(
effort: NonNullable<OpenAICompletionsOptions["reasoningEffort"]>,
reasoningEffortMap: Partial<
Record<NonNullable<OpenAICompletionsOptions["reasoningEffort"]>, string>
>,
): string {
return reasoningEffortMap[effort] ?? effort;
}
function maybeAddOpenRouterAnthropicCacheControl(
model: Model<"openai-completions">,
messages: ChatCompletionMessageParam[],
): void {
if (model.provider !== "openrouter" || !model.id.startsWith("anthropic/"))
return;
// Anthropic-style caching requires cache_control on a text part. Add a breakpoint
// on the last user/assistant message (walking backwards until we find text content).
for (let i = messages.length - 1; i >= 0; i--) {
const msg = messages[i];
if (msg.role !== "user" && msg.role !== "assistant") continue;
const content = msg.content;
if (typeof content === "string") {
msg.content = [
Object.assign(
{ type: "text" as const, text: content },
{ cache_control: { type: "ephemeral" } },
),
];
return;
}
if (!Array.isArray(content)) continue;
// Find last text part and add cache_control
for (let j = content.length - 1; j >= 0; j--) {
const part = content[j];
if (part?.type === "text") {
Object.assign(part, { cache_control: { type: "ephemeral" } });
return;
}
}
}
}
export function convertMessages(
model: Model<"openai-completions">,
context: Context,
compat: Required<OpenAICompletionsCompat>,
): ChatCompletionMessageParam[] {
const params: ChatCompletionMessageParam[] = [];
const normalizeToolCallId = (id: string): string => {
// Handle pipe-separated IDs from OpenAI Responses API
// Format: {call_id}|{id} where {id} can be 400+ chars with special chars (+, /, =)
// These come from providers like github-copilot, openai-codex, opencode
// Extract just the call_id part and normalize it
if (id.includes("|")) {
const [callId] = id.split("|");
// Sanitize to allowed chars and truncate to 40 chars (OpenAI limit)
return callId.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 40);
}
if (model.provider === "openai")
return id.length > 40 ? id.slice(0, 40) : id;
return id;
};
const transformedMessages = transformMessages(context.messages, model, (id) =>
normalizeToolCallId(id),
);
if (context.systemPrompt) {
const useDeveloperRole = model.reasoning && compat.supportsDeveloperRole;
const role = useDeveloperRole ? "developer" : "system";
params.push({
role: role,
content: sanitizeSurrogates(context.systemPrompt),
});
}
let lastRole: string | null = null;
for (let i = 0; i < transformedMessages.length; i++) {
const msg = transformedMessages[i];
// Some providers don't allow user messages directly after tool results
// Insert a synthetic assistant message to bridge the gap
if (
compat.requiresAssistantAfterToolResult &&
lastRole === "toolResult" &&
msg.role === "user"
) {
params.push({
role: "assistant",
content: "I have processed the tool results.",
});
}
if (msg.role === "user") {
if (typeof msg.content === "string") {
params.push({
role: "user",
content: sanitizeSurrogates(msg.content),
});
} else {
const content: ChatCompletionContentPart[] = msg.content.map(
(item): ChatCompletionContentPart => {
if (item.type === "text") {
return {
type: "text",
text: sanitizeSurrogates(item.text),
} satisfies ChatCompletionContentPartText;
} else {
return {
type: "image_url",
image_url: {
url: `data:${item.mimeType};base64,${item.data}`,
},
} satisfies ChatCompletionContentPartImage;
}
},
);
const filteredContent = !model.input.includes("image")
? content.filter((c) => c.type !== "image_url")
: content;
if (filteredContent.length === 0) continue;
params.push({
role: "user",
content: filteredContent,
});
}
} else if (msg.role === "assistant") {
// Some providers don't accept null content, use empty string instead
const assistantMsg: ChatCompletionAssistantMessageParam = {
role: "assistant",
content: compat.requiresAssistantAfterToolResult ? "" : null,
};
const textBlocks = msg.content.filter(
(b) => b.type === "text",
) as TextContent[];
// Filter out empty text blocks to avoid API validation errors
const nonEmptyTextBlocks = textBlocks.filter(
(b) => b.text && b.text.trim().length > 0,
);
if (nonEmptyTextBlocks.length > 0) {
// GitHub Copilot requires assistant content as a string, not an array.
// Sending as array causes Claude models to re-answer all previous prompts.
if (model.provider === "github-copilot") {
assistantMsg.content = nonEmptyTextBlocks
.map((b) => sanitizeSurrogates(b.text))
.join("");
} else {
assistantMsg.content = nonEmptyTextBlocks.map((b) => {
return { type: "text", text: sanitizeSurrogates(b.text) };
});
}
}
// Handle thinking blocks
const thinkingBlocks = msg.content.filter(
(b) => b.type === "thinking",
) as ThinkingContent[];
// Filter out empty thinking blocks to avoid API validation errors
const nonEmptyThinkingBlocks = thinkingBlocks.filter(
(b) => b.thinking && b.thinking.trim().length > 0,
);
if (nonEmptyThinkingBlocks.length > 0) {
if (compat.requiresThinkingAsText) {
// Convert thinking blocks to plain text (no tags to avoid model mimicking them)
const thinkingText = nonEmptyThinkingBlocks
.map((b) => b.thinking)
.join("\n\n");
const textContent = assistantMsg.content as Array<{
type: "text";
text: string;
}> | null;
if (textContent) {
textContent.unshift({ type: "text", text: thinkingText });
} else {
assistantMsg.content = [{ type: "text", text: thinkingText }];
}
} else {
// Use the signature from the first thinking block if available (for llama.cpp server + gpt-oss)
const signature = nonEmptyThinkingBlocks[0].thinkingSignature;
if (signature && signature.length > 0) {
(assistantMsg as any)[signature] = nonEmptyThinkingBlocks
.map((b) => b.thinking)
.join("\n");
}
}
}
const toolCalls = msg.content.filter(
(b) => b.type === "toolCall",
) as ToolCall[];
if (toolCalls.length > 0) {
assistantMsg.tool_calls = toolCalls.map((tc) => ({
id: tc.id,
type: "function" as const,
function: {
name: tc.name,
arguments: JSON.stringify(tc.arguments),
},
}));
const reasoningDetails = toolCalls
.filter((tc) => tc.thoughtSignature)
.map((tc) => {
try {
return JSON.parse(tc.thoughtSignature!);
} catch {
return null;
}
})
.filter(Boolean);
if (reasoningDetails.length > 0) {
(assistantMsg as any).reasoning_details = reasoningDetails;
}
}
// Skip assistant messages that have no content and no tool calls.
// Some providers require "either content or tool_calls, but not none".
// Other providers also don't accept empty assistant messages.
// This handles aborted assistant responses that got no content.
const content = assistantMsg.content;
const hasContent =
content !== null &&
content !== undefined &&
(typeof content === "string" ? content.length > 0 : content.length > 0);
if (!hasContent && !assistantMsg.tool_calls) {
continue;
}
params.push(assistantMsg);
} else if (msg.role === "toolResult") {
const imageBlocks: Array<{
type: "image_url";
image_url: { url: string };
}> = [];
let j = i;
for (
;
j < transformedMessages.length &&
transformedMessages[j].role === "toolResult";
j++
) {
const toolMsg = transformedMessages[j] as ToolResultMessage;
// Extract text and image content
const textResult = toolMsg.content
.filter((c) => c.type === "text")
.map((c) => (c as any).text)
.join("\n");
const hasImages = toolMsg.content.some((c) => c.type === "image");
// Always send tool result with text (or placeholder if only images)
const hasText = textResult.length > 0;
// Some providers require the 'name' field in tool results
const toolResultMsg: ChatCompletionToolMessageParam = {
role: "tool",
content: sanitizeSurrogates(
hasText ? textResult : "(see attached image)",
),
tool_call_id: toolMsg.toolCallId,
};
if (compat.requiresToolResultName && toolMsg.toolName) {
(toolResultMsg as any).name = toolMsg.toolName;
}
params.push(toolResultMsg);
if (hasImages && model.input.includes("image")) {
for (const block of toolMsg.content) {
if (block.type === "image") {
imageBlocks.push({
type: "image_url",
image_url: {
url: `data:${(block as any).mimeType};base64,${(block as any).data}`,
},
});
}
}
}
}
i = j - 1;
if (imageBlocks.length > 0) {
if (compat.requiresAssistantAfterToolResult) {
params.push({
role: "assistant",
content: "I have processed the tool results.",
});
}
params.push({
role: "user",
content: [
{
type: "text",
text: "Attached image(s) from tool result:",
},
...imageBlocks,
],
});
lastRole = "user";
} else {
lastRole = "toolResult";
}
continue;
}
lastRole = msg.role;
}
return params;
}
function convertTools(
tools: Tool[],
compat: Required<OpenAICompletionsCompat>,
): OpenAI.Chat.Completions.ChatCompletionTool[] {
return tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.parameters as any, // TypeBox already generates JSON Schema
// Only include strict if provider supports it. Some reject unknown fields.
...(compat.supportsStrictMode !== false && { strict: false }),
},
}));
}
function mapStopReason(
reason: ChatCompletionChunk.Choice["finish_reason"],
): StopReason {
if (reason === null) return "stop";
switch (reason) {
case "stop":
return "stop";
case "length":
return "length";
case "function_call":
case "tool_calls":
return "toolUse";
case "content_filter":
return "error";
default: {
const _exhaustive: never = reason;
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
}
}
}
/**
* Detect compatibility settings from provider and baseUrl for known providers.
* Provider takes precedence over URL-based detection since it's explicitly configured.
* Returns a fully resolved OpenAICompletionsCompat object with all fields set.
*/
function detectCompat(
model: Model<"openai-completions">,
): Required<OpenAICompletionsCompat> {
const provider = model.provider;
const baseUrl = model.baseUrl;
const isZai = provider === "zai" || baseUrl.includes("api.z.ai");
const isNonStandard =
provider === "cerebras" ||
baseUrl.includes("cerebras.ai") ||
provider === "xai" ||
baseUrl.includes("api.x.ai") ||
baseUrl.includes("chutes.ai") ||
baseUrl.includes("deepseek.com") ||
isZai ||
provider === "opencode" ||
baseUrl.includes("opencode.ai");
const useMaxTokens = baseUrl.includes("chutes.ai");
const isGrok = provider === "xai" || baseUrl.includes("api.x.ai");
const isGroq = provider === "groq" || baseUrl.includes("groq.com");
const reasoningEffortMap =
isGroq && model.id === "qwen/qwen3-32b"
? {
minimal: "default",
low: "default",
medium: "default",
high: "default",
xhigh: "default",
}
: {};
return {
supportsStore: !isNonStandard,
supportsDeveloperRole: !isNonStandard,
supportsReasoningEffort: !isGrok && !isZai,
reasoningEffortMap,
supportsUsageInStreaming: true,
maxTokensField: useMaxTokens ? "max_tokens" : "max_completion_tokens",
requiresToolResultName: false,
requiresAssistantAfterToolResult: false,
requiresThinkingAsText: false,
thinkingFormat: isZai ? "zai" : "openai",
openRouterRouting: {},
vercelGatewayRouting: {},
supportsStrictMode: true,
};
}
/**
* Get resolved compatibility settings for a model.
* Uses explicit model.compat if provided, otherwise auto-detects from provider/URL.
*/
function getCompat(
model: Model<"openai-completions">,
): Required<OpenAICompletionsCompat> {
const detected = detectCompat(model);
if (!model.compat) return detected;
return {
supportsStore: model.compat.supportsStore ?? detected.supportsStore,
supportsDeveloperRole:
model.compat.supportsDeveloperRole ?? detected.supportsDeveloperRole,
supportsReasoningEffort:
model.compat.supportsReasoningEffort ?? detected.supportsReasoningEffort,
reasoningEffortMap:
model.compat.reasoningEffortMap ?? detected.reasoningEffortMap,
supportsUsageInStreaming:
model.compat.supportsUsageInStreaming ??
detected.supportsUsageInStreaming,
maxTokensField: model.compat.maxTokensField ?? detected.maxTokensField,
requiresToolResultName:
model.compat.requiresToolResultName ?? detected.requiresToolResultName,
requiresAssistantAfterToolResult:
model.compat.requiresAssistantAfterToolResult ??
detected.requiresAssistantAfterToolResult,
requiresThinkingAsText:
model.compat.requiresThinkingAsText ?? detected.requiresThinkingAsText,
thinkingFormat: model.compat.thinkingFormat ?? detected.thinkingFormat,
openRouterRouting: model.compat.openRouterRouting ?? {},
vercelGatewayRouting:
model.compat.vercelGatewayRouting ?? detected.vercelGatewayRouting,
supportsStrictMode:
model.compat.supportsStrictMode ?? detected.supportsStrictMode,
};
}

View file

@ -0,0 +1,583 @@
import type OpenAI from "openai";
import type {
Tool as OpenAITool,
ResponseCreateParamsStreaming,
ResponseFunctionToolCall,
ResponseInput,
ResponseInputContent,
ResponseInputImage,
ResponseInputText,
ResponseOutputMessage,
ResponseReasoningItem,
ResponseStreamEvent,
} from "openai/resources/responses/responses.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
ImageContent,
Model,
StopReason,
TextContent,
TextSignatureV1,
ThinkingContent,
Tool,
ToolCall,
Usage,
} from "../types.js";
import type { AssistantMessageEventStream } from "../utils/event-stream.js";
import { shortHash } from "../utils/hash.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { transformMessages } from "./transform-messages.js";
// =============================================================================
// Utilities
// =============================================================================
function encodeTextSignatureV1(
id: string,
phase?: TextSignatureV1["phase"],
): string {
const payload: TextSignatureV1 = { v: 1, id };
if (phase) payload.phase = phase;
return JSON.stringify(payload);
}
function parseTextSignature(
signature: string | undefined,
): { id: string; phase?: TextSignatureV1["phase"] } | undefined {
if (!signature) return undefined;
if (signature.startsWith("{")) {
try {
const parsed = JSON.parse(signature) as Partial<TextSignatureV1>;
if (parsed.v === 1 && typeof parsed.id === "string") {
if (parsed.phase === "commentary" || parsed.phase === "final_answer") {
return { id: parsed.id, phase: parsed.phase };
}
return { id: parsed.id };
}
} catch {
// Fall through to legacy plain-string handling.
}
}
return { id: signature };
}
export interface OpenAIResponsesStreamOptions {
serviceTier?: ResponseCreateParamsStreaming["service_tier"];
applyServiceTierPricing?: (
usage: Usage,
serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
) => void;
}
export interface ConvertResponsesMessagesOptions {
includeSystemPrompt?: boolean;
}
export interface ConvertResponsesToolsOptions {
strict?: boolean | null;
}
// =============================================================================
// Message conversion
// =============================================================================
export function convertResponsesMessages<TApi extends Api>(
model: Model<TApi>,
context: Context,
allowedToolCallProviders: ReadonlySet<string>,
options?: ConvertResponsesMessagesOptions,
): ResponseInput {
const messages: ResponseInput = [];
const normalizeToolCallId = (id: string): string => {
if (!allowedToolCallProviders.has(model.provider)) return id;
if (!id.includes("|")) return id;
const [callId, itemId] = id.split("|");
const sanitizedCallId = callId.replace(/[^a-zA-Z0-9_-]/g, "_");
let sanitizedItemId = itemId.replace(/[^a-zA-Z0-9_-]/g, "_");
// OpenAI Responses API requires item id to start with "fc"
if (!sanitizedItemId.startsWith("fc")) {
sanitizedItemId = `fc_${sanitizedItemId}`;
}
// Truncate to 64 chars and strip trailing underscores (OpenAI Codex rejects them)
let normalizedCallId =
sanitizedCallId.length > 64
? sanitizedCallId.slice(0, 64)
: sanitizedCallId;
let normalizedItemId =
sanitizedItemId.length > 64
? sanitizedItemId.slice(0, 64)
: sanitizedItemId;
normalizedCallId = normalizedCallId.replace(/_+$/, "");
normalizedItemId = normalizedItemId.replace(/_+$/, "");
return `${normalizedCallId}|${normalizedItemId}`;
};
const transformedMessages = transformMessages(
context.messages,
model,
normalizeToolCallId,
);
const includeSystemPrompt = options?.includeSystemPrompt ?? true;
if (includeSystemPrompt && context.systemPrompt) {
const role = model.reasoning ? "developer" : "system";
messages.push({
role,
content: sanitizeSurrogates(context.systemPrompt),
});
}
let msgIndex = 0;
for (const msg of transformedMessages) {
if (msg.role === "user") {
if (typeof msg.content === "string") {
messages.push({
role: "user",
content: [
{ type: "input_text", text: sanitizeSurrogates(msg.content) },
],
});
} else {
const content: ResponseInputContent[] = msg.content.map(
(item): ResponseInputContent => {
if (item.type === "text") {
return {
type: "input_text",
text: sanitizeSurrogates(item.text),
} satisfies ResponseInputText;
}
return {
type: "input_image",
detail: "auto",
image_url: `data:${item.mimeType};base64,${item.data}`,
} satisfies ResponseInputImage;
},
);
const filteredContent = !model.input.includes("image")
? content.filter((c) => c.type !== "input_image")
: content;
if (filteredContent.length === 0) continue;
messages.push({
role: "user",
content: filteredContent,
});
}
} else if (msg.role === "assistant") {
const output: ResponseInput = [];
const assistantMsg = msg as AssistantMessage;
const isDifferentModel =
assistantMsg.model !== model.id &&
assistantMsg.provider === model.provider &&
assistantMsg.api === model.api;
for (const block of msg.content) {
if (block.type === "thinking") {
if (block.thinking.trim().length === 0) continue;
if (block.thinkingSignature) {
const reasoningItem = JSON.parse(
block.thinkingSignature,
) as ResponseReasoningItem;
output.push(reasoningItem);
}
} else if (block.type === "text") {
const textBlock = block as TextContent;
const parsedSignature = parseTextSignature(textBlock.textSignature);
// OpenAI requires id to be max 64 characters
let msgId = parsedSignature?.id;
if (!msgId) {
msgId = `msg_${msgIndex}`;
} else if (msgId.length > 64) {
msgId = `msg_${shortHash(msgId)}`;
}
output.push({
type: "message",
role: "assistant",
content: [
{
type: "output_text",
text: sanitizeSurrogates(textBlock.text),
annotations: [],
},
],
status: "completed",
id: msgId,
phase: parsedSignature?.phase,
} satisfies ResponseOutputMessage);
} else if (block.type === "toolCall") {
const toolCall = block as ToolCall;
const [callId, itemIdRaw] = toolCall.id.split("|");
let itemId: string | undefined = itemIdRaw;
// For different-model messages, set id to undefined to avoid pairing validation.
// OpenAI tracks which fc_xxx IDs were paired with rs_xxx reasoning items.
// By omitting the id, we avoid triggering that validation (like cross-provider does).
if (isDifferentModel && itemId?.startsWith("fc_")) {
itemId = undefined;
}
output.push({
type: "function_call",
id: itemId,
call_id: callId,
name: toolCall.name,
arguments: JSON.stringify(toolCall.arguments),
});
}
}
if (output.length === 0) continue;
messages.push(...output);
} else if (msg.role === "toolResult") {
// Extract text and image content
const textResult = msg.content
.filter((c): c is TextContent => c.type === "text")
.map((c) => c.text)
.join("\n");
const hasImages = msg.content.some(
(c): c is ImageContent => c.type === "image",
);
// Always send function_call_output with text (or placeholder if only images)
const hasText = textResult.length > 0;
const [callId] = msg.toolCallId.split("|");
messages.push({
type: "function_call_output",
call_id: callId,
output: sanitizeSurrogates(
hasText ? textResult : "(see attached image)",
),
});
// If there are images and model supports them, send a follow-up user message with images
if (hasImages && model.input.includes("image")) {
const contentParts: ResponseInputContent[] = [];
// Add text prefix
contentParts.push({
type: "input_text",
text: "Attached image(s) from tool result:",
} satisfies ResponseInputText);
// Add images
for (const block of msg.content) {
if (block.type === "image") {
contentParts.push({
type: "input_image",
detail: "auto",
image_url: `data:${block.mimeType};base64,${block.data}`,
} satisfies ResponseInputImage);
}
}
messages.push({
role: "user",
content: contentParts,
});
}
}
msgIndex++;
}
return messages;
}
// =============================================================================
// Tool conversion
// =============================================================================
export function convertResponsesTools(
tools: Tool[],
options?: ConvertResponsesToolsOptions,
): OpenAITool[] {
const strict = options?.strict === undefined ? false : options.strict;
return tools.map((tool) => ({
type: "function",
name: tool.name,
description: tool.description,
parameters: tool.parameters as any, // TypeBox already generates JSON Schema
strict,
}));
}
// =============================================================================
// Stream processing
// =============================================================================
export async function processResponsesStream<TApi extends Api>(
openaiStream: AsyncIterable<ResponseStreamEvent>,
output: AssistantMessage,
stream: AssistantMessageEventStream,
model: Model<TApi>,
options?: OpenAIResponsesStreamOptions,
): Promise<void> {
let currentItem:
| ResponseReasoningItem
| ResponseOutputMessage
| ResponseFunctionToolCall
| null = null;
let currentBlock:
| ThinkingContent
| TextContent
| (ToolCall & { partialJson: string })
| null = null;
const blocks = output.content;
const blockIndex = () => blocks.length - 1;
for await (const event of openaiStream) {
if (event.type === "response.output_item.added") {
const item = event.item;
if (item.type === "reasoning") {
currentItem = item;
currentBlock = { type: "thinking", thinking: "" };
output.content.push(currentBlock);
stream.push({
type: "thinking_start",
contentIndex: blockIndex(),
partial: output,
});
} else if (item.type === "message") {
currentItem = item;
currentBlock = { type: "text", text: "" };
output.content.push(currentBlock);
stream.push({
type: "text_start",
contentIndex: blockIndex(),
partial: output,
});
} else if (item.type === "function_call") {
currentItem = item;
currentBlock = {
type: "toolCall",
id: `${item.call_id}|${item.id}`,
name: item.name,
arguments: {},
partialJson: item.arguments || "",
};
output.content.push(currentBlock);
stream.push({
type: "toolcall_start",
contentIndex: blockIndex(),
partial: output,
});
}
} else if (event.type === "response.reasoning_summary_part.added") {
if (currentItem && currentItem.type === "reasoning") {
currentItem.summary = currentItem.summary || [];
currentItem.summary.push(event.part);
}
} else if (event.type === "response.reasoning_summary_text.delta") {
if (
currentItem?.type === "reasoning" &&
currentBlock?.type === "thinking"
) {
currentItem.summary = currentItem.summary || [];
const lastPart = currentItem.summary[currentItem.summary.length - 1];
if (lastPart) {
currentBlock.thinking += event.delta;
lastPart.text += event.delta;
stream.push({
type: "thinking_delta",
contentIndex: blockIndex(),
delta: event.delta,
partial: output,
});
}
}
} else if (event.type === "response.reasoning_summary_part.done") {
if (
currentItem?.type === "reasoning" &&
currentBlock?.type === "thinking"
) {
currentItem.summary = currentItem.summary || [];
const lastPart = currentItem.summary[currentItem.summary.length - 1];
if (lastPart) {
currentBlock.thinking += "\n\n";
lastPart.text += "\n\n";
stream.push({
type: "thinking_delta",
contentIndex: blockIndex(),
delta: "\n\n",
partial: output,
});
}
}
} else if (event.type === "response.content_part.added") {
if (currentItem?.type === "message") {
currentItem.content = currentItem.content || [];
// Filter out ReasoningText, only accept output_text and refusal
if (
event.part.type === "output_text" ||
event.part.type === "refusal"
) {
currentItem.content.push(event.part);
}
}
} else if (event.type === "response.output_text.delta") {
if (currentItem?.type === "message" && currentBlock?.type === "text") {
if (!currentItem.content || currentItem.content.length === 0) {
continue;
}
const lastPart = currentItem.content[currentItem.content.length - 1];
if (lastPart?.type === "output_text") {
currentBlock.text += event.delta;
lastPart.text += event.delta;
stream.push({
type: "text_delta",
contentIndex: blockIndex(),
delta: event.delta,
partial: output,
});
}
}
} else if (event.type === "response.refusal.delta") {
if (currentItem?.type === "message" && currentBlock?.type === "text") {
if (!currentItem.content || currentItem.content.length === 0) {
continue;
}
const lastPart = currentItem.content[currentItem.content.length - 1];
if (lastPart?.type === "refusal") {
currentBlock.text += event.delta;
lastPart.refusal += event.delta;
stream.push({
type: "text_delta",
contentIndex: blockIndex(),
delta: event.delta,
partial: output,
});
}
}
} else if (event.type === "response.function_call_arguments.delta") {
if (
currentItem?.type === "function_call" &&
currentBlock?.type === "toolCall"
) {
currentBlock.partialJson += event.delta;
currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
stream.push({
type: "toolcall_delta",
contentIndex: blockIndex(),
delta: event.delta,
partial: output,
});
}
} else if (event.type === "response.function_call_arguments.done") {
if (
currentItem?.type === "function_call" &&
currentBlock?.type === "toolCall"
) {
currentBlock.partialJson = event.arguments;
currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
}
} else if (event.type === "response.output_item.done") {
const item = event.item;
if (item.type === "reasoning" && currentBlock?.type === "thinking") {
currentBlock.thinking =
item.summary?.map((s) => s.text).join("\n\n") || "";
currentBlock.thinkingSignature = JSON.stringify(item);
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: currentBlock.thinking,
partial: output,
});
currentBlock = null;
} else if (item.type === "message" && currentBlock?.type === "text") {
currentBlock.text = item.content
.map((c) => (c.type === "output_text" ? c.text : c.refusal))
.join("");
currentBlock.textSignature = encodeTextSignatureV1(
item.id,
item.phase ?? undefined,
);
stream.push({
type: "text_end",
contentIndex: blockIndex(),
content: currentBlock.text,
partial: output,
});
currentBlock = null;
} else if (item.type === "function_call") {
const args =
currentBlock?.type === "toolCall" && currentBlock.partialJson
? parseStreamingJson(currentBlock.partialJson)
: parseStreamingJson(item.arguments || "{}");
const toolCall: ToolCall = {
type: "toolCall",
id: `${item.call_id}|${item.id}`,
name: item.name,
arguments: args,
};
currentBlock = null;
stream.push({
type: "toolcall_end",
contentIndex: blockIndex(),
toolCall,
partial: output,
});
}
} else if (event.type === "response.completed") {
const response = event.response;
if (response?.usage) {
const cachedTokens =
response.usage.input_tokens_details?.cached_tokens || 0;
output.usage = {
// OpenAI includes cached tokens in input_tokens, so subtract to get non-cached input
input: (response.usage.input_tokens || 0) - cachedTokens,
output: response.usage.output_tokens || 0,
cacheRead: cachedTokens,
cacheWrite: 0,
totalTokens: response.usage.total_tokens || 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
};
}
calculateCost(model, output.usage);
if (options?.applyServiceTierPricing) {
const serviceTier = response?.service_tier ?? options.serviceTier;
options.applyServiceTierPricing(output.usage, serviceTier);
}
// Map status to stop reason
output.stopReason = mapStopReason(response?.status);
if (
output.content.some((b) => b.type === "toolCall") &&
output.stopReason === "stop"
) {
output.stopReason = "toolUse";
}
} else if (event.type === "error") {
throw new Error(
`Error Code ${event.code}: ${event.message}` || "Unknown error",
);
} else if (event.type === "response.failed") {
throw new Error("Unknown error");
}
}
}
function mapStopReason(
status: OpenAI.Responses.ResponseStatus | undefined,
): StopReason {
if (!status) return "stop";
switch (status) {
case "completed":
return "stop";
case "incomplete":
return "length";
case "failed":
case "cancelled":
return "error";
// These two are wonky ...
case "in_progress":
case "queued":
return "stop";
default: {
const _exhaustive: never = status;
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
}
}
}

View file

@ -0,0 +1,309 @@
import OpenAI from "openai";
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { supportsXhigh } from "../models.js";
import type {
Api,
AssistantMessage,
CacheRetention,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
Usage,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import {
buildCopilotDynamicHeaders,
hasCopilotVisionInput,
} from "./github-copilot-headers.js";
import {
convertResponsesMessages,
convertResponsesTools,
processResponsesStream,
} from "./openai-responses-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
const OPENAI_TOOL_CALL_PROVIDERS = new Set([
"openai",
"openai-codex",
"opencode",
]);
/**
* Resolve cache retention preference.
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
*/
function resolveCacheRetention(
cacheRetention?: CacheRetention,
): CacheRetention {
if (cacheRetention) {
return cacheRetention;
}
if (
typeof process !== "undefined" &&
process.env.PI_CACHE_RETENTION === "long"
) {
return "long";
}
return "short";
}
/**
* Get prompt cache retention based on cacheRetention and base URL.
* Only applies to direct OpenAI API calls (api.openai.com).
*/
function getPromptCacheRetention(
baseUrl: string,
cacheRetention: CacheRetention,
): "24h" | undefined {
if (cacheRetention !== "long") {
return undefined;
}
if (baseUrl.includes("api.openai.com")) {
return "24h";
}
return undefined;
}
// OpenAI Responses-specific options
export interface OpenAIResponsesOptions extends StreamOptions {
reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
reasoningSummary?: "auto" | "detailed" | "concise" | null;
serviceTier?: ResponseCreateParamsStreaming["service_tier"];
}
/**
* Generate function for OpenAI Responses API
*/
export const streamOpenAIResponses: StreamFunction<
"openai-responses",
OpenAIResponsesOptions
> = (
model: Model<"openai-responses">,
context: Context,
options?: OpenAIResponsesOptions,
): AssistantMessageEventStream => {
const stream = new AssistantMessageEventStream();
// Start async processing
(async () => {
const output: AssistantMessage = {
role: "assistant",
content: [],
api: model.api as Api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
try {
// Create OpenAI client
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
const client = createClient(model, context, apiKey, options?.headers);
const params = buildParams(model, context, options);
options?.onPayload?.(params);
const openaiStream = await client.responses.create(
params,
options?.signal ? { signal: options.signal } : undefined,
);
stream.push({ type: "start", partial: output });
await processResponsesStream(openaiStream, output, stream, model, {
serviceTier: options?.serviceTier,
applyServiceTierPricing,
});
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
if (output.stopReason === "aborted" || output.stopReason === "error") {
throw new Error("An unknown error occurred");
}
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
for (const block of output.content)
delete (block as { index?: number }).index;
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
output.errorMessage =
error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", reason: output.stopReason, error: output });
stream.end();
}
})();
return stream;
};
export const streamSimpleOpenAIResponses: StreamFunction<
"openai-responses",
SimpleStreamOptions
> = (
model: Model<"openai-responses">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
}
const base = buildBaseOptions(model, options, apiKey);
const reasoningEffort = supportsXhigh(model)
? options?.reasoning
: clampReasoning(options?.reasoning);
return streamOpenAIResponses(model, context, {
...base,
reasoningEffort,
} satisfies OpenAIResponsesOptions);
};
function createClient(
model: Model<"openai-responses">,
context: Context,
apiKey?: string,
optionsHeaders?: Record<string, string>,
) {
if (!apiKey) {
if (!process.env.OPENAI_API_KEY) {
throw new Error(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
);
}
apiKey = process.env.OPENAI_API_KEY;
}
const headers = { ...model.headers };
if (model.provider === "github-copilot") {
const hasImages = hasCopilotVisionInput(context.messages);
const copilotHeaders = buildCopilotDynamicHeaders({
messages: context.messages,
hasImages,
});
Object.assign(headers, copilotHeaders);
}
// Merge options headers last so they can override defaults
if (optionsHeaders) {
Object.assign(headers, optionsHeaders);
}
return new OpenAI({
apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders: headers,
});
}
function buildParams(
model: Model<"openai-responses">,
context: Context,
options?: OpenAIResponsesOptions,
) {
const messages = convertResponsesMessages(
model,
context,
OPENAI_TOOL_CALL_PROVIDERS,
);
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
const params: ResponseCreateParamsStreaming = {
model: model.id,
input: messages,
stream: true,
prompt_cache_key:
cacheRetention === "none" ? undefined : options?.sessionId,
prompt_cache_retention: getPromptCacheRetention(
model.baseUrl,
cacheRetention,
),
store: false,
};
if (options?.maxTokens) {
params.max_output_tokens = options?.maxTokens;
}
if (options?.temperature !== undefined) {
params.temperature = options?.temperature;
}
if (options?.serviceTier !== undefined) {
params.service_tier = options.serviceTier;
}
if (context.tools) {
params.tools = convertResponsesTools(context.tools);
}
if (model.reasoning) {
if (options?.reasoningEffort || options?.reasoningSummary) {
params.reasoning = {
effort: options?.reasoningEffort || "medium",
summary: options?.reasoningSummary || "auto",
};
params.include = ["reasoning.encrypted_content"];
} else {
if (model.name.startsWith("gpt-5")) {
// Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
messages.push({
role: "developer",
content: [
{
type: "input_text",
text: "# Juice: 0 !important",
},
],
});
}
}
}
return params;
}
function getServiceTierCostMultiplier(
serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
): number {
switch (serviceTier) {
case "flex":
return 0.5;
case "priority":
return 2;
default:
return 1;
}
}
function applyServiceTierPricing(
usage: Usage,
serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
) {
const multiplier = getServiceTierCostMultiplier(serviceTier);
if (multiplier === 1) return;
usage.cost.input *= multiplier;
usage.cost.output *= multiplier;
usage.cost.cacheRead *= multiplier;
usage.cost.cacheWrite *= multiplier;
usage.cost.total =
usage.cost.input +
usage.cost.output +
usage.cost.cacheRead +
usage.cost.cacheWrite;
}

View file

@ -0,0 +1,216 @@
import { clearApiProviders, registerApiProvider } from "../api-registry.js";
import type {
AssistantMessage,
AssistantMessageEvent,
Context,
Model,
SimpleStreamOptions,
StreamOptions,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { streamAnthropic, streamSimpleAnthropic } from "./anthropic.js";
import {
streamAzureOpenAIResponses,
streamSimpleAzureOpenAIResponses,
} from "./azure-openai-responses.js";
import { streamGoogle, streamSimpleGoogle } from "./google.js";
import {
streamGoogleGeminiCli,
streamSimpleGoogleGeminiCli,
} from "./google-gemini-cli.js";
import {
streamGoogleVertex,
streamSimpleGoogleVertex,
} from "./google-vertex.js";
import { streamMistral, streamSimpleMistral } from "./mistral.js";
import {
streamOpenAICodexResponses,
streamSimpleOpenAICodexResponses,
} from "./openai-codex-responses.js";
import {
streamOpenAICompletions,
streamSimpleOpenAICompletions,
} from "./openai-completions.js";
import {
streamOpenAIResponses,
streamSimpleOpenAIResponses,
} from "./openai-responses.js";
interface BedrockProviderModule {
streamBedrock: (
model: Model<"bedrock-converse-stream">,
context: Context,
options?: StreamOptions,
) => AsyncIterable<AssistantMessageEvent>;
streamSimpleBedrock: (
model: Model<"bedrock-converse-stream">,
context: Context,
options?: SimpleStreamOptions,
) => AsyncIterable<AssistantMessageEvent>;
}
type DynamicImport = (specifier: string) => Promise<unknown>;
const dynamicImport: DynamicImport = (specifier) => import(specifier);
const BEDROCK_PROVIDER_SPECIFIER = "./amazon-" + "bedrock.js";
let bedrockProviderModuleOverride: BedrockProviderModule | undefined;
export function setBedrockProviderModule(module: BedrockProviderModule): void {
bedrockProviderModuleOverride = module;
}
async function loadBedrockProviderModule(): Promise<BedrockProviderModule> {
if (bedrockProviderModuleOverride) {
return bedrockProviderModuleOverride;
}
const module = await dynamicImport(BEDROCK_PROVIDER_SPECIFIER);
return module as BedrockProviderModule;
}
function forwardStream(
target: AssistantMessageEventStream,
source: AsyncIterable<AssistantMessageEvent>,
): void {
(async () => {
for await (const event of source) {
target.push(event);
}
target.end();
})();
}
function createLazyLoadErrorMessage(
model: Model<"bedrock-converse-stream">,
error: unknown,
): AssistantMessage {
return {
role: "assistant",
content: [],
api: "bedrock-converse-stream",
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "error",
errorMessage: error instanceof Error ? error.message : String(error),
timestamp: Date.now(),
};
}
function streamBedrockLazy(
model: Model<"bedrock-converse-stream">,
context: Context,
options?: StreamOptions,
): AssistantMessageEventStream {
const outer = new AssistantMessageEventStream();
loadBedrockProviderModule()
.then((module) => {
const inner = module.streamBedrock(model, context, options);
forwardStream(outer, inner);
})
.catch((error) => {
const message = createLazyLoadErrorMessage(model, error);
outer.push({ type: "error", reason: "error", error: message });
outer.end(message);
});
return outer;
}
function streamSimpleBedrockLazy(
model: Model<"bedrock-converse-stream">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream {
const outer = new AssistantMessageEventStream();
loadBedrockProviderModule()
.then((module) => {
const inner = module.streamSimpleBedrock(model, context, options);
forwardStream(outer, inner);
})
.catch((error) => {
const message = createLazyLoadErrorMessage(model, error);
outer.push({ type: "error", reason: "error", error: message });
outer.end(message);
});
return outer;
}
export function registerBuiltInApiProviders(): void {
registerApiProvider({
api: "anthropic-messages",
stream: streamAnthropic,
streamSimple: streamSimpleAnthropic,
});
registerApiProvider({
api: "openai-completions",
stream: streamOpenAICompletions,
streamSimple: streamSimpleOpenAICompletions,
});
registerApiProvider({
api: "mistral-conversations",
stream: streamMistral,
streamSimple: streamSimpleMistral,
});
registerApiProvider({
api: "openai-responses",
stream: streamOpenAIResponses,
streamSimple: streamSimpleOpenAIResponses,
});
registerApiProvider({
api: "azure-openai-responses",
stream: streamAzureOpenAIResponses,
streamSimple: streamSimpleAzureOpenAIResponses,
});
registerApiProvider({
api: "openai-codex-responses",
stream: streamOpenAICodexResponses,
streamSimple: streamSimpleOpenAICodexResponses,
});
registerApiProvider({
api: "google-generative-ai",
stream: streamGoogle,
streamSimple: streamSimpleGoogle,
});
registerApiProvider({
api: "google-gemini-cli",
stream: streamGoogleGeminiCli,
streamSimple: streamSimpleGoogleGeminiCli,
});
registerApiProvider({
api: "google-vertex",
stream: streamGoogleVertex,
streamSimple: streamSimpleGoogleVertex,
});
registerApiProvider({
api: "bedrock-converse-stream",
stream: streamBedrockLazy,
streamSimple: streamSimpleBedrockLazy,
});
}
export function resetApiProviders(): void {
clearApiProviders();
registerBuiltInApiProviders();
}
registerBuiltInApiProviders();

View file

@ -0,0 +1,59 @@
import type {
Api,
Model,
SimpleStreamOptions,
StreamOptions,
ThinkingBudgets,
ThinkingLevel,
} from "../types.js";
export function buildBaseOptions(
model: Model<Api>,
options?: SimpleStreamOptions,
apiKey?: string,
): StreamOptions {
return {
temperature: options?.temperature,
maxTokens: options?.maxTokens || Math.min(model.maxTokens, 32000),
signal: options?.signal,
apiKey: apiKey || options?.apiKey,
cacheRetention: options?.cacheRetention,
sessionId: options?.sessionId,
headers: options?.headers,
onPayload: options?.onPayload,
maxRetryDelayMs: options?.maxRetryDelayMs,
metadata: options?.metadata,
};
}
export function clampReasoning(
effort: ThinkingLevel | undefined,
): Exclude<ThinkingLevel, "xhigh"> | undefined {
return effort === "xhigh" ? "high" : effort;
}
export function adjustMaxTokensForThinking(
baseMaxTokens: number,
modelMaxTokens: number,
reasoningLevel: ThinkingLevel,
customBudgets?: ThinkingBudgets,
): { maxTokens: number; thinkingBudget: number } {
const defaultBudgets: ThinkingBudgets = {
minimal: 1024,
low: 2048,
medium: 8192,
high: 16384,
};
const budgets = { ...defaultBudgets, ...customBudgets };
const minOutputTokens = 1024;
const level = clampReasoning(reasoningLevel)!;
let thinkingBudget = budgets[level]!;
const maxTokens = Math.min(baseMaxTokens + thinkingBudget, modelMaxTokens);
if (maxTokens <= thinkingBudget) {
thinkingBudget = Math.max(0, maxTokens - minOutputTokens);
}
return { maxTokens, thinkingBudget };
}

View file

@ -0,0 +1,193 @@
import type {
Api,
AssistantMessage,
Message,
Model,
ToolCall,
ToolResultMessage,
} from "../types.js";
/**
* Normalize tool call ID for cross-provider compatibility.
* OpenAI Responses API generates IDs that are 450+ chars with special characters like `|`.
* Anthropic APIs require IDs matching ^[a-zA-Z0-9_-]+$ (max 64 chars).
*/
export function transformMessages<TApi extends Api>(
messages: Message[],
model: Model<TApi>,
normalizeToolCallId?: (
id: string,
model: Model<TApi>,
source: AssistantMessage,
) => string,
): Message[] {
// Build a map of original tool call IDs to normalized IDs
const toolCallIdMap = new Map<string, string>();
// First pass: transform messages (thinking blocks, tool call ID normalization)
const transformed = messages.map((msg) => {
// User messages pass through unchanged
if (msg.role === "user") {
return msg;
}
// Handle toolResult messages - normalize toolCallId if we have a mapping
if (msg.role === "toolResult") {
const normalizedId = toolCallIdMap.get(msg.toolCallId);
if (normalizedId && normalizedId !== msg.toolCallId) {
return { ...msg, toolCallId: normalizedId };
}
return msg;
}
// Assistant messages need transformation check
if (msg.role === "assistant") {
const assistantMsg = msg as AssistantMessage;
const isSameModel =
assistantMsg.provider === model.provider &&
assistantMsg.api === model.api &&
assistantMsg.model === model.id;
const transformedContent = assistantMsg.content.flatMap((block) => {
if (block.type === "thinking") {
// Redacted thinking is opaque encrypted content, only valid for the same model.
// Drop it for cross-model to avoid API errors.
if (block.redacted) {
return isSameModel ? block : [];
}
// For same model: keep thinking blocks with signatures (needed for replay)
// even if the thinking text is empty (OpenAI encrypted reasoning)
if (isSameModel && block.thinkingSignature) return block;
// Skip empty thinking blocks, convert others to plain text
if (!block.thinking || block.thinking.trim() === "") return [];
if (isSameModel) return block;
return {
type: "text" as const,
text: block.thinking,
};
}
if (block.type === "text") {
if (isSameModel) return block;
return {
type: "text" as const,
text: block.text,
};
}
if (block.type === "toolCall") {
const toolCall = block as ToolCall;
let normalizedToolCall: ToolCall = toolCall;
if (!isSameModel && toolCall.thoughtSignature) {
normalizedToolCall = { ...toolCall };
delete (normalizedToolCall as { thoughtSignature?: string })
.thoughtSignature;
}
if (!isSameModel && normalizeToolCallId) {
const normalizedId = normalizeToolCallId(
toolCall.id,
model,
assistantMsg,
);
if (normalizedId !== toolCall.id) {
toolCallIdMap.set(toolCall.id, normalizedId);
normalizedToolCall = { ...normalizedToolCall, id: normalizedId };
}
}
return normalizedToolCall;
}
return block;
});
return {
...assistantMsg,
content: transformedContent,
};
}
return msg;
});
// Second pass: insert synthetic empty tool results for orphaned tool calls
// This preserves thinking signatures and satisfies API requirements
const result: Message[] = [];
let pendingToolCalls: ToolCall[] = [];
let existingToolResultIds = new Set<string>();
for (let i = 0; i < transformed.length; i++) {
const msg = transformed[i];
if (msg.role === "assistant") {
// If we have pending orphaned tool calls from a previous assistant, insert synthetic results now
if (pendingToolCalls.length > 0) {
for (const tc of pendingToolCalls) {
if (!existingToolResultIds.has(tc.id)) {
result.push({
role: "toolResult",
toolCallId: tc.id,
toolName: tc.name,
content: [{ type: "text", text: "No result provided" }],
isError: true,
timestamp: Date.now(),
} as ToolResultMessage);
}
}
pendingToolCalls = [];
existingToolResultIds = new Set();
}
// Skip errored/aborted assistant messages entirely.
// These are incomplete turns that shouldn't be replayed:
// - May have partial content (reasoning without message, incomplete tool calls)
// - Replaying them can cause API errors (e.g., OpenAI "reasoning without following item")
// - The model should retry from the last valid state
const assistantMsg = msg as AssistantMessage;
if (
assistantMsg.stopReason === "error" ||
assistantMsg.stopReason === "aborted"
) {
continue;
}
// Track tool calls from this assistant message
const toolCalls = assistantMsg.content.filter(
(b) => b.type === "toolCall",
) as ToolCall[];
if (toolCalls.length > 0) {
pendingToolCalls = toolCalls;
existingToolResultIds = new Set();
}
result.push(msg);
} else if (msg.role === "toolResult") {
existingToolResultIds.add(msg.toolCallId);
result.push(msg);
} else if (msg.role === "user") {
// User message interrupts tool flow - insert synthetic results for orphaned calls
if (pendingToolCalls.length > 0) {
for (const tc of pendingToolCalls) {
if (!existingToolResultIds.has(tc.id)) {
result.push({
role: "toolResult",
toolCallId: tc.id,
toolName: tc.name,
content: [{ type: "text", text: "No result provided" }],
isError: true,
timestamp: Date.now(),
} as ToolResultMessage);
}
}
pendingToolCalls = [];
existingToolResultIds = new Set();
}
result.push(msg);
} else {
result.push(msg);
}
}
return result;
}

59
packages/ai/src/stream.ts Normal file
View file

@ -0,0 +1,59 @@
import "./providers/register-builtins.js";
import { getApiProvider } from "./api-registry.js";
import type {
Api,
AssistantMessage,
AssistantMessageEventStream,
Context,
Model,
ProviderStreamOptions,
SimpleStreamOptions,
StreamOptions,
} from "./types.js";
export { getEnvApiKey } from "./env-api-keys.js";
function resolveApiProvider(api: Api) {
const provider = getApiProvider(api);
if (!provider) {
throw new Error(`No API provider registered for api: ${api}`);
}
return provider;
}
export function stream<TApi extends Api>(
model: Model<TApi>,
context: Context,
options?: ProviderStreamOptions,
): AssistantMessageEventStream {
const provider = resolveApiProvider(model.api);
return provider.stream(model, context, options as StreamOptions);
}
export async function complete<TApi extends Api>(
model: Model<TApi>,
context: Context,
options?: ProviderStreamOptions,
): Promise<AssistantMessage> {
const s = stream(model, context, options);
return s.result();
}
export function streamSimple<TApi extends Api>(
model: Model<TApi>,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream {
const provider = resolveApiProvider(model.api);
return provider.streamSimple(model, context, options);
}
export async function completeSimple<TApi extends Api>(
model: Model<TApi>,
context: Context,
options?: SimpleStreamOptions,
): Promise<AssistantMessage> {
const s = streamSimple(model, context, options);
return s.result();
}

361
packages/ai/src/types.ts Normal file
View file

@ -0,0 +1,361 @@
import type { AssistantMessageEventStream } from "./utils/event-stream.js";
export type { AssistantMessageEventStream } from "./utils/event-stream.js";
export type KnownApi =
| "openai-completions"
| "mistral-conversations"
| "openai-responses"
| "azure-openai-responses"
| "openai-codex-responses"
| "anthropic-messages"
| "bedrock-converse-stream"
| "google-generative-ai"
| "google-gemini-cli"
| "google-vertex";
export type Api = KnownApi | (string & {});
export type KnownProvider =
| "amazon-bedrock"
| "anthropic"
| "google"
| "google-gemini-cli"
| "google-antigravity"
| "google-vertex"
| "openai"
| "azure-openai-responses"
| "openai-codex"
| "github-copilot"
| "xai"
| "groq"
| "cerebras"
| "openrouter"
| "vercel-ai-gateway"
| "zai"
| "mistral"
| "minimax"
| "minimax-cn"
| "huggingface"
| "opencode"
| "opencode-go"
| "kimi-coding";
export type Provider = KnownProvider | string;
export type ThinkingLevel = "minimal" | "low" | "medium" | "high" | "xhigh";
/** Token budgets for each thinking level (token-based providers only) */
export interface ThinkingBudgets {
minimal?: number;
low?: number;
medium?: number;
high?: number;
}
// Base options all providers share
export type CacheRetention = "none" | "short" | "long";
export type Transport = "sse" | "websocket" | "auto";
export interface StreamOptions {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
apiKey?: string;
/**
* Preferred transport for providers that support multiple transports.
* Providers that do not support this option ignore it.
*/
transport?: Transport;
/**
* Prompt cache retention preference. Providers map this to their supported values.
* Default: "short".
*/
cacheRetention?: CacheRetention;
/**
* Optional session identifier for providers that support session-based caching.
* Providers can use this to enable prompt caching, request routing, or other
* session-aware features. Ignored by providers that don't support it.
*/
sessionId?: string;
/**
* Optional callback for inspecting provider payloads before sending.
*/
onPayload?: (payload: unknown) => void;
/**
* Optional custom HTTP headers to include in API requests.
* Merged with provider defaults; can override default headers.
* Not supported by all providers (e.g., AWS Bedrock uses SDK auth).
*/
headers?: Record<string, string>;
/**
* Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
* If the server's requested delay exceeds this value, the request fails immediately
* with an error containing the requested delay, allowing higher-level retry logic
* to handle it with user visibility.
* Default: 60000 (60 seconds). Set to 0 to disable the cap.
*/
maxRetryDelayMs?: number;
/**
* Optional metadata to include in API requests.
* Providers extract the fields they understand and ignore the rest.
* For example, Anthropic uses `user_id` for abuse tracking and rate limiting.
*/
metadata?: Record<string, unknown>;
}
export type ProviderStreamOptions = StreamOptions & Record<string, unknown>;
// Unified options with reasoning passed to streamSimple() and completeSimple()
export interface SimpleStreamOptions extends StreamOptions {
reasoning?: ThinkingLevel;
/** Custom token budgets for thinking levels (token-based providers only) */
thinkingBudgets?: ThinkingBudgets;
}
// Generic StreamFunction with typed options
export type StreamFunction<
TApi extends Api = Api,
TOptions extends StreamOptions = StreamOptions,
> = (
model: Model<TApi>,
context: Context,
options?: TOptions,
) => AssistantMessageEventStream;
export interface TextSignatureV1 {
v: 1;
id: string;
phase?: "commentary" | "final_answer";
}
export interface TextContent {
type: "text";
text: string;
textSignature?: string; // e.g., for OpenAI responses, message metadata (legacy id string or TextSignatureV1 JSON)
}
export interface ThinkingContent {
type: "thinking";
thinking: string;
thinkingSignature?: string; // e.g., for OpenAI responses, the reasoning item ID
/** When true, the thinking content was redacted by safety filters. The opaque
* encrypted payload is stored in `thinkingSignature` so it can be passed back
* to the API for multi-turn continuity. */
redacted?: boolean;
}
export interface ImageContent {
type: "image";
data: string; // base64 encoded image data
mimeType: string; // e.g., "image/jpeg", "image/png"
}
export interface ToolCall {
type: "toolCall";
id: string;
name: string;
arguments: Record<string, any>;
thoughtSignature?: string; // Google-specific: opaque signature for reusing thought context
}
export interface Usage {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
totalTokens: number;
cost: {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
total: number;
};
}
export type StopReason = "stop" | "length" | "toolUse" | "error" | "aborted";
export interface UserMessage {
role: "user";
content: string | (TextContent | ImageContent)[];
timestamp: number; // Unix timestamp in milliseconds
}
export interface AssistantMessage {
role: "assistant";
content: (TextContent | ThinkingContent | ToolCall)[];
api: Api;
provider: Provider;
model: string;
usage: Usage;
stopReason: StopReason;
errorMessage?: string;
timestamp: number; // Unix timestamp in milliseconds
}
export interface ToolResultMessage<TDetails = any> {
role: "toolResult";
toolCallId: string;
toolName: string;
content: (TextContent | ImageContent)[]; // Supports text and images
details?: TDetails;
isError: boolean;
timestamp: number; // Unix timestamp in milliseconds
}
export type Message = UserMessage | AssistantMessage | ToolResultMessage;
import type { TSchema } from "@sinclair/typebox";
export interface Tool<TParameters extends TSchema = TSchema> {
name: string;
description: string;
parameters: TParameters;
}
export interface Context {
systemPrompt?: string;
messages: Message[];
tools?: Tool[];
}
export type AssistantMessageEvent =
| { type: "start"; partial: AssistantMessage }
| { type: "text_start"; contentIndex: number; partial: AssistantMessage }
| {
type: "text_delta";
contentIndex: number;
delta: string;
partial: AssistantMessage;
}
| {
type: "text_end";
contentIndex: number;
content: string;
partial: AssistantMessage;
}
| { type: "thinking_start"; contentIndex: number; partial: AssistantMessage }
| {
type: "thinking_delta";
contentIndex: number;
delta: string;
partial: AssistantMessage;
}
| {
type: "thinking_end";
contentIndex: number;
content: string;
partial: AssistantMessage;
}
| { type: "toolcall_start"; contentIndex: number; partial: AssistantMessage }
| {
type: "toolcall_delta";
contentIndex: number;
delta: string;
partial: AssistantMessage;
}
| {
type: "toolcall_end";
contentIndex: number;
toolCall: ToolCall;
partial: AssistantMessage;
}
| {
type: "done";
reason: Extract<StopReason, "stop" | "length" | "toolUse">;
message: AssistantMessage;
}
| {
type: "error";
reason: Extract<StopReason, "aborted" | "error">;
error: AssistantMessage;
};
/**
* Compatibility settings for OpenAI-compatible completions APIs.
* Use this to override URL-based auto-detection for custom providers.
*/
export interface OpenAICompletionsCompat {
/** Whether the provider supports the `store` field. Default: auto-detected from URL. */
supportsStore?: boolean;
/** Whether the provider supports the `developer` role (vs `system`). Default: auto-detected from URL. */
supportsDeveloperRole?: boolean;
/** Whether the provider supports `reasoning_effort`. Default: auto-detected from URL. */
supportsReasoningEffort?: boolean;
/** Optional mapping from pi-ai reasoning levels to provider/model-specific `reasoning_effort` values. */
reasoningEffortMap?: Partial<Record<ThinkingLevel, string>>;
/** Whether the provider supports `stream_options: { include_usage: true }` for token usage in streaming responses. Default: true. */
supportsUsageInStreaming?: boolean;
/** Which field to use for max tokens. Default: auto-detected from URL. */
maxTokensField?: "max_completion_tokens" | "max_tokens";
/** Whether tool results require the `name` field. Default: auto-detected from URL. */
requiresToolResultName?: boolean;
/** Whether a user message after tool results requires an assistant message in between. Default: auto-detected from URL. */
requiresAssistantAfterToolResult?: boolean;
/** Whether thinking blocks must be converted to text blocks with <thinking> delimiters. Default: auto-detected from URL. */
requiresThinkingAsText?: boolean;
/** Format for reasoning/thinking parameter. "openai" uses reasoning_effort, "zai" uses thinking: { type: "enabled" }, "qwen" uses enable_thinking: boolean. Default: "openai". */
thinkingFormat?: "openai" | "zai" | "qwen";
/** OpenRouter-specific routing preferences. Only used when baseUrl points to OpenRouter. */
openRouterRouting?: OpenRouterRouting;
/** Vercel AI Gateway routing preferences. Only used when baseUrl points to Vercel AI Gateway. */
vercelGatewayRouting?: VercelGatewayRouting;
/** Whether the provider supports the `strict` field in tool definitions. Default: true. */
supportsStrictMode?: boolean;
}
/** Compatibility settings for OpenAI Responses APIs. */
export interface OpenAIResponsesCompat {
// Reserved for future use
}
/**
* OpenRouter provider routing preferences.
* Controls which upstream providers OpenRouter routes requests to.
* @see https://openrouter.ai/docs/provider-routing
*/
export interface OpenRouterRouting {
/** List of provider slugs to exclusively use for this request (e.g., ["amazon-bedrock", "anthropic"]). */
only?: string[];
/** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
order?: string[];
}
/**
* Vercel AI Gateway routing preferences.
* Controls which upstream providers the gateway routes requests to.
* @see https://vercel.com/docs/ai-gateway/models-and-providers/provider-options
*/
export interface VercelGatewayRouting {
/** List of provider slugs to exclusively use for this request (e.g., ["bedrock", "anthropic"]). */
only?: string[];
/** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
order?: string[];
}
// Model interface for the unified model system
export interface Model<TApi extends Api> {
id: string;
name: string;
api: TApi;
provider: Provider;
baseUrl: string;
reasoning: boolean;
input: ("text" | "image")[];
cost: {
input: number; // $/million tokens
output: number; // $/million tokens
cacheRead: number; // $/million tokens
cacheWrite: number; // $/million tokens
};
contextWindow: number;
maxTokens: number;
headers?: Record<string, string>;
/** Compatibility overrides for OpenAI-compatible APIs. If not set, auto-detected from baseUrl. */
compat?: TApi extends "openai-completions"
? OpenAICompletionsCompat
: TApi extends "openai-responses"
? OpenAIResponsesCompat
: never;
}

View file

@ -0,0 +1,92 @@
import type { AssistantMessage, AssistantMessageEvent } from "../types.js";
// Generic event stream class for async iteration
export class EventStream<T, R = T> implements AsyncIterable<T> {
private queue: T[] = [];
private waiting: ((value: IteratorResult<T>) => void)[] = [];
private done = false;
private finalResultPromise: Promise<R>;
private resolveFinalResult!: (result: R) => void;
constructor(
private isComplete: (event: T) => boolean,
private extractResult: (event: T) => R,
) {
this.finalResultPromise = new Promise((resolve) => {
this.resolveFinalResult = resolve;
});
}
push(event: T): void {
if (this.done) return;
if (this.isComplete(event)) {
this.done = true;
this.resolveFinalResult(this.extractResult(event));
}
// Deliver to waiting consumer or queue it
const waiter = this.waiting.shift();
if (waiter) {
waiter({ value: event, done: false });
} else {
this.queue.push(event);
}
}
end(result?: R): void {
this.done = true;
if (result !== undefined) {
this.resolveFinalResult(result);
}
// Notify all waiting consumers that we're done
while (this.waiting.length > 0) {
const waiter = this.waiting.shift()!;
waiter({ value: undefined as any, done: true });
}
}
async *[Symbol.asyncIterator](): AsyncIterator<T> {
while (true) {
if (this.queue.length > 0) {
yield this.queue.shift()!;
} else if (this.done) {
return;
} else {
const result = await new Promise<IteratorResult<T>>((resolve) =>
this.waiting.push(resolve),
);
if (result.done) return;
yield result.value;
}
}
}
result(): Promise<R> {
return this.finalResultPromise;
}
}
export class AssistantMessageEventStream extends EventStream<
AssistantMessageEvent,
AssistantMessage
> {
constructor() {
super(
(event) => event.type === "done" || event.type === "error",
(event) => {
if (event.type === "done") {
return event.message;
} else if (event.type === "error") {
return event.error;
}
throw new Error("Unexpected event type for final result");
},
);
}
}
/** Factory function for AssistantMessageEventStream (for use in extensions) */
export function createAssistantMessageEventStream(): AssistantMessageEventStream {
return new AssistantMessageEventStream();
}

View file

@ -0,0 +1,17 @@
/** Fast deterministic hash to shorten long strings */
export function shortHash(str: string): string {
let h1 = 0xdeadbeef;
let h2 = 0x41c6ce57;
for (let i = 0; i < str.length; i++) {
const ch = str.charCodeAt(i);
h1 = Math.imul(h1 ^ ch, 2654435761);
h2 = Math.imul(h2 ^ ch, 1597334677);
}
h1 =
Math.imul(h1 ^ (h1 >>> 16), 2246822507) ^
Math.imul(h2 ^ (h2 >>> 13), 3266489909);
h2 =
Math.imul(h2 ^ (h2 >>> 16), 2246822507) ^
Math.imul(h1 ^ (h1 >>> 13), 3266489909);
return (h2 >>> 0).toString(36) + (h1 >>> 0).toString(36);
}

View file

@ -0,0 +1,30 @@
import { parse as partialParse } from "partial-json";
/**
* Attempts to parse potentially incomplete JSON during streaming.
* Always returns a valid object, even if the JSON is incomplete.
*
* @param partialJson The partial JSON string from streaming
* @returns Parsed object or empty object if parsing fails
*/
export function parseStreamingJson<T = any>(
partialJson: string | undefined,
): T {
if (!partialJson || partialJson.trim() === "") {
return {} as T;
}
// Try standard parsing first (fastest for complete JSON)
try {
return JSON.parse(partialJson) as T;
} catch {
// Try partial-json for incomplete JSON
try {
const result = partialParse(partialJson);
return (result ?? {}) as T;
} catch {
// If all parsing fails, return empty object
return {} as T;
}
}
}

View file

@ -0,0 +1,144 @@
/**
* Anthropic OAuth flow (Claude Pro/Max)
*/
import { generatePKCE } from "./pkce.js";
import type {
OAuthCredentials,
OAuthLoginCallbacks,
OAuthProviderInterface,
} from "./types.js";
const decode = (s: string) => atob(s);
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
const TOKEN_URL = "https://console.anthropic.com/v1/oauth/token";
const REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback";
const SCOPES = "org:create_api_key user:profile user:inference";
/**
* Login with Anthropic OAuth (device code flow)
*
* @param onAuthUrl - Callback to handle the authorization URL (e.g., open browser)
* @param onPromptCode - Callback to prompt user for the authorization code
*/
export async function loginAnthropic(
onAuthUrl: (url: string) => void,
onPromptCode: () => Promise<string>,
): Promise<OAuthCredentials> {
const { verifier, challenge } = await generatePKCE();
// Build authorization URL
const authParams = new URLSearchParams({
code: "true",
client_id: CLIENT_ID,
response_type: "code",
redirect_uri: REDIRECT_URI,
scope: SCOPES,
code_challenge: challenge,
code_challenge_method: "S256",
state: verifier,
});
const authUrl = `${AUTHORIZE_URL}?${authParams.toString()}`;
// Notify caller with URL to open
onAuthUrl(authUrl);
// Wait for user to paste authorization code (format: code#state)
const authCode = await onPromptCode();
const splits = authCode.split("#");
const code = splits[0];
const state = splits[1];
// Exchange code for tokens
const tokenResponse = await fetch(TOKEN_URL, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
grant_type: "authorization_code",
client_id: CLIENT_ID,
code: code,
state: state,
redirect_uri: REDIRECT_URI,
code_verifier: verifier,
}),
});
if (!tokenResponse.ok) {
const error = await tokenResponse.text();
throw new Error(`Token exchange failed: ${error}`);
}
const tokenData = (await tokenResponse.json()) as {
access_token: string;
refresh_token: string;
expires_in: number;
};
// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
// Save credentials
return {
refresh: tokenData.refresh_token,
access: tokenData.access_token,
expires: expiresAt,
};
}
/**
* Refresh Anthropic OAuth token
*/
export async function refreshAnthropicToken(
refreshToken: string,
): Promise<OAuthCredentials> {
const response = await fetch(TOKEN_URL, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
grant_type: "refresh_token",
client_id: CLIENT_ID,
refresh_token: refreshToken,
}),
});
if (!response.ok) {
const error = await response.text();
throw new Error(`Anthropic token refresh failed: ${error}`);
}
const data = (await response.json()) as {
access_token: string;
refresh_token: string;
expires_in: number;
};
return {
refresh: data.refresh_token,
access: data.access_token,
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
};
}
export const anthropicOAuthProvider: OAuthProviderInterface = {
id: "anthropic",
name: "Anthropic (Claude Pro/Max)",
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
return loginAnthropic(
(url) => callbacks.onAuth({ url }),
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
);
},
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
return refreshAnthropicToken(credentials.refresh);
},
getApiKey(credentials: OAuthCredentials): string {
return credentials.access;
},
};

View file

@ -0,0 +1,423 @@
/**
* GitHub Copilot OAuth flow
*/
import { getModels } from "../../models.js";
import type { Api, Model } from "../../types.js";
import type {
OAuthCredentials,
OAuthLoginCallbacks,
OAuthProviderInterface,
} from "./types.js";
type CopilotCredentials = OAuthCredentials & {
enterpriseUrl?: string;
};
const decode = (s: string) => atob(s);
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");
const COPILOT_HEADERS = {
"User-Agent": "GitHubCopilotChat/0.35.0",
"Editor-Version": "vscode/1.107.0",
"Editor-Plugin-Version": "copilot-chat/0.35.0",
"Copilot-Integration-Id": "vscode-chat",
} as const;
type DeviceCodeResponse = {
device_code: string;
user_code: string;
verification_uri: string;
interval: number;
expires_in: number;
};
type DeviceTokenSuccessResponse = {
access_token: string;
token_type?: string;
scope?: string;
};
type DeviceTokenErrorResponse = {
error: string;
error_description?: string;
interval?: number;
};
export function normalizeDomain(input: string): string | null {
const trimmed = input.trim();
if (!trimmed) return null;
try {
const url = trimmed.includes("://")
? new URL(trimmed)
: new URL(`https://${trimmed}`);
return url.hostname;
} catch {
return null;
}
}
function getUrls(domain: string): {
deviceCodeUrl: string;
accessTokenUrl: string;
copilotTokenUrl: string;
} {
return {
deviceCodeUrl: `https://${domain}/login/device/code`,
accessTokenUrl: `https://${domain}/login/oauth/access_token`,
copilotTokenUrl: `https://api.${domain}/copilot_internal/v2/token`,
};
}
/**
* Parse the proxy-ep from a Copilot token and convert to API base URL.
* Token format: tid=...;exp=...;proxy-ep=proxy.individual.githubcopilot.com;...
* Returns API URL like https://api.individual.githubcopilot.com
*/
function getBaseUrlFromToken(token: string): string | null {
const match = token.match(/proxy-ep=([^;]+)/);
if (!match) return null;
const proxyHost = match[1];
// Convert proxy.xxx to api.xxx
const apiHost = proxyHost.replace(/^proxy\./, "api.");
return `https://${apiHost}`;
}
export function getGitHubCopilotBaseUrl(
token?: string,
enterpriseDomain?: string,
): string {
// If we have a token, extract the base URL from proxy-ep
if (token) {
const urlFromToken = getBaseUrlFromToken(token);
if (urlFromToken) return urlFromToken;
}
// Fallback for enterprise or if token parsing fails
if (enterpriseDomain) return `https://copilot-api.${enterpriseDomain}`;
return "https://api.individual.githubcopilot.com";
}
async function fetchJson(url: string, init: RequestInit): Promise<unknown> {
const response = await fetch(url, init);
if (!response.ok) {
const text = await response.text();
throw new Error(`${response.status} ${response.statusText}: ${text}`);
}
return response.json();
}
async function startDeviceFlow(domain: string): Promise<DeviceCodeResponse> {
const urls = getUrls(domain);
const data = await fetchJson(urls.deviceCodeUrl, {
method: "POST",
headers: {
Accept: "application/json",
"Content-Type": "application/json",
"User-Agent": "GitHubCopilotChat/0.35.0",
},
body: JSON.stringify({
client_id: CLIENT_ID,
scope: "read:user",
}),
});
if (!data || typeof data !== "object") {
throw new Error("Invalid device code response");
}
const deviceCode = (data as Record<string, unknown>).device_code;
const userCode = (data as Record<string, unknown>).user_code;
const verificationUri = (data as Record<string, unknown>).verification_uri;
const interval = (data as Record<string, unknown>).interval;
const expiresIn = (data as Record<string, unknown>).expires_in;
if (
typeof deviceCode !== "string" ||
typeof userCode !== "string" ||
typeof verificationUri !== "string" ||
typeof interval !== "number" ||
typeof expiresIn !== "number"
) {
throw new Error("Invalid device code response fields");
}
return {
device_code: deviceCode,
user_code: userCode,
verification_uri: verificationUri,
interval,
expires_in: expiresIn,
};
}
/**
* Sleep that can be interrupted by an AbortSignal
*/
function abortableSleep(ms: number, signal?: AbortSignal): Promise<void> {
return new Promise((resolve, reject) => {
if (signal?.aborted) {
reject(new Error("Login cancelled"));
return;
}
const timeout = setTimeout(resolve, ms);
signal?.addEventListener(
"abort",
() => {
clearTimeout(timeout);
reject(new Error("Login cancelled"));
},
{ once: true },
);
});
}
async function pollForGitHubAccessToken(
domain: string,
deviceCode: string,
intervalSeconds: number,
expiresIn: number,
signal?: AbortSignal,
) {
const urls = getUrls(domain);
const deadline = Date.now() + expiresIn * 1000;
let intervalMs = Math.max(1000, Math.floor(intervalSeconds * 1000));
while (Date.now() < deadline) {
if (signal?.aborted) {
throw new Error("Login cancelled");
}
const raw = await fetchJson(urls.accessTokenUrl, {
method: "POST",
headers: {
Accept: "application/json",
"Content-Type": "application/json",
"User-Agent": "GitHubCopilotChat/0.35.0",
},
body: JSON.stringify({
client_id: CLIENT_ID,
device_code: deviceCode,
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
}),
});
if (
raw &&
typeof raw === "object" &&
typeof (raw as DeviceTokenSuccessResponse).access_token === "string"
) {
return (raw as DeviceTokenSuccessResponse).access_token;
}
if (
raw &&
typeof raw === "object" &&
typeof (raw as DeviceTokenErrorResponse).error === "string"
) {
const err = (raw as DeviceTokenErrorResponse).error;
if (err === "authorization_pending") {
await abortableSleep(intervalMs, signal);
continue;
}
if (err === "slow_down") {
intervalMs += 5000;
await abortableSleep(intervalMs, signal);
continue;
}
throw new Error(`Device flow failed: ${err}`);
}
await abortableSleep(intervalMs, signal);
}
throw new Error("Device flow timed out");
}
/**
* Refresh GitHub Copilot token
*/
export async function refreshGitHubCopilotToken(
refreshToken: string,
enterpriseDomain?: string,
): Promise<OAuthCredentials> {
const domain = enterpriseDomain || "github.com";
const urls = getUrls(domain);
const raw = await fetchJson(urls.copilotTokenUrl, {
headers: {
Accept: "application/json",
Authorization: `Bearer ${refreshToken}`,
...COPILOT_HEADERS,
},
});
if (!raw || typeof raw !== "object") {
throw new Error("Invalid Copilot token response");
}
const token = (raw as Record<string, unknown>).token;
const expiresAt = (raw as Record<string, unknown>).expires_at;
if (typeof token !== "string" || typeof expiresAt !== "number") {
throw new Error("Invalid Copilot token response fields");
}
return {
refresh: refreshToken,
access: token,
expires: expiresAt * 1000 - 5 * 60 * 1000,
enterpriseUrl: enterpriseDomain,
};
}
/**
* Enable a model for the user's GitHub Copilot account.
* This is required for some models (like Claude, Grok) before they can be used.
*/
async function enableGitHubCopilotModel(
token: string,
modelId: string,
enterpriseDomain?: string,
): Promise<boolean> {
const baseUrl = getGitHubCopilotBaseUrl(token, enterpriseDomain);
const url = `${baseUrl}/models/${modelId}/policy`;
try {
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${token}`,
...COPILOT_HEADERS,
"openai-intent": "chat-policy",
"x-interaction-type": "chat-policy",
},
body: JSON.stringify({ state: "enabled" }),
});
return response.ok;
} catch {
return false;
}
}
/**
* Enable all known GitHub Copilot models that may require policy acceptance.
* Called after successful login to ensure all models are available.
*/
async function enableAllGitHubCopilotModels(
token: string,
enterpriseDomain?: string,
onProgress?: (model: string, success: boolean) => void,
): Promise<void> {
const models = getModels("github-copilot");
await Promise.all(
models.map(async (model) => {
const success = await enableGitHubCopilotModel(
token,
model.id,
enterpriseDomain,
);
onProgress?.(model.id, success);
}),
);
}
/**
* Login with GitHub Copilot OAuth (device code flow)
*
* @param options.onAuth - Callback with URL and optional instructions (user code)
* @param options.onPrompt - Callback to prompt user for input
* @param options.onProgress - Optional progress callback
* @param options.signal - Optional AbortSignal for cancellation
*/
export async function loginGitHubCopilot(options: {
onAuth: (url: string, instructions?: string) => void;
onPrompt: (prompt: {
message: string;
placeholder?: string;
allowEmpty?: boolean;
}) => Promise<string>;
onProgress?: (message: string) => void;
signal?: AbortSignal;
}): Promise<OAuthCredentials> {
const input = await options.onPrompt({
message: "GitHub Enterprise URL/domain (blank for github.com)",
placeholder: "company.ghe.com",
allowEmpty: true,
});
if (options.signal?.aborted) {
throw new Error("Login cancelled");
}
const trimmed = input.trim();
const enterpriseDomain = normalizeDomain(input);
if (trimmed && !enterpriseDomain) {
throw new Error("Invalid GitHub Enterprise URL/domain");
}
const domain = enterpriseDomain || "github.com";
const device = await startDeviceFlow(domain);
options.onAuth(device.verification_uri, `Enter code: ${device.user_code}`);
const githubAccessToken = await pollForGitHubAccessToken(
domain,
device.device_code,
device.interval,
device.expires_in,
options.signal,
);
const credentials = await refreshGitHubCopilotToken(
githubAccessToken,
enterpriseDomain ?? undefined,
);
// Enable all models after successful login
options.onProgress?.("Enabling models...");
await enableAllGitHubCopilotModels(
credentials.access,
enterpriseDomain ?? undefined,
);
return credentials;
}
export const githubCopilotOAuthProvider: OAuthProviderInterface = {
id: "github-copilot",
name: "GitHub Copilot",
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
return loginGitHubCopilot({
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
signal: callbacks.signal,
});
},
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
const creds = credentials as CopilotCredentials;
return refreshGitHubCopilotToken(creds.refresh, creds.enterpriseUrl);
},
getApiKey(credentials: OAuthCredentials): string {
return credentials.access;
},
modifyModels(
models: Model<Api>[],
credentials: OAuthCredentials,
): Model<Api>[] {
const creds = credentials as CopilotCredentials;
const domain = creds.enterpriseUrl
? (normalizeDomain(creds.enterpriseUrl) ?? undefined)
: undefined;
const baseUrl = getGitHubCopilotBaseUrl(creds.access, domain);
return models.map((m) =>
m.provider === "github-copilot" ? { ...m, baseUrl } : m,
);
},
};

View file

@ -0,0 +1,492 @@
/**
* Antigravity OAuth flow (Gemini 3, Claude, GPT-OSS via Google Cloud)
* Uses different OAuth credentials than google-gemini-cli for access to additional models.
*
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
* It is only intended for CLI use, not browser environments.
*/
import type { Server } from "node:http";
import { generatePKCE } from "./pkce.js";
import type {
OAuthCredentials,
OAuthLoginCallbacks,
OAuthProviderInterface,
} from "./types.js";
type AntigravityCredentials = OAuthCredentials & {
projectId: string;
};
let _createServer: typeof import("node:http").createServer | null = null;
let _httpImportPromise: Promise<void> | null = null;
if (
typeof process !== "undefined" &&
(process.versions?.node || process.versions?.bun)
) {
_httpImportPromise = import("node:http").then((m) => {
_createServer = m.createServer;
});
}
// Antigravity OAuth credentials (different from Gemini CLI)
const decode = (s: string) => atob(s);
const CLIENT_ID = decode(
"MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==",
);
const CLIENT_SECRET = decode(
"R09DU1BYLUs1OEZXUjQ4NkxkTEoxbUxCOHNYQzR6NnFEQWY=",
);
const REDIRECT_URI = "http://localhost:51121/oauth-callback";
// Antigravity requires additional scopes
const SCOPES = [
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
];
const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
const TOKEN_URL = "https://oauth2.googleapis.com/token";
// Fallback project ID when discovery fails
const DEFAULT_PROJECT_ID = "rising-fact-p41fc";
type CallbackServerInfo = {
server: Server;
cancelWait: () => void;
waitForCode: () => Promise<{ code: string; state: string } | null>;
};
/**
* Start a local HTTP server to receive the OAuth callback
*/
async function getNodeCreateServer(): Promise<
typeof import("node:http").createServer
> {
if (_createServer) return _createServer;
if (_httpImportPromise) {
await _httpImportPromise;
}
if (_createServer) return _createServer;
throw new Error(
"Antigravity OAuth is only available in Node.js environments",
);
}
async function startCallbackServer(): Promise<CallbackServerInfo> {
const createServer = await getNodeCreateServer();
return new Promise((resolve, reject) => {
let result: { code: string; state: string } | null = null;
let cancelled = false;
const server = createServer((req, res) => {
const url = new URL(req.url || "", `http://localhost:51121`);
if (url.pathname === "/oauth-callback") {
const code = url.searchParams.get("code");
const state = url.searchParams.get("state");
const error = url.searchParams.get("error");
if (error) {
res.writeHead(400, { "Content-Type": "text/html" });
res.end(
`<html><body><h1>Authentication Failed</h1><p>Error: ${error}</p><p>You can close this window.</p></body></html>`,
);
return;
}
if (code && state) {
res.writeHead(200, { "Content-Type": "text/html" });
res.end(
`<html><body><h1>Authentication Successful</h1><p>You can close this window and return to the terminal.</p></body></html>`,
);
result = { code, state };
} else {
res.writeHead(400, { "Content-Type": "text/html" });
res.end(
`<html><body><h1>Authentication Failed</h1><p>Missing code or state parameter.</p></body></html>`,
);
}
} else {
res.writeHead(404);
res.end();
}
});
server.on("error", (err) => {
reject(err);
});
server.listen(51121, "127.0.0.1", () => {
resolve({
server,
cancelWait: () => {
cancelled = true;
},
waitForCode: async () => {
const sleep = () => new Promise((r) => setTimeout(r, 100));
while (!result && !cancelled) {
await sleep();
}
return result;
},
});
});
});
}
/**
* Parse redirect URL to extract code and state
*/
function parseRedirectUrl(input: string): { code?: string; state?: string } {
const value = input.trim();
if (!value) return {};
try {
const url = new URL(value);
return {
code: url.searchParams.get("code") ?? undefined,
state: url.searchParams.get("state") ?? undefined,
};
} catch {
// Not a URL, return empty
return {};
}
}
interface LoadCodeAssistPayload {
cloudaicompanionProject?: string | { id?: string };
currentTier?: { id?: string };
allowedTiers?: Array<{ id?: string; isDefault?: boolean }>;
}
/**
* Discover or provision a project for the user
*/
async function discoverProject(
accessToken: string,
onProgress?: (message: string) => void,
): Promise<string> {
const headers = {
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
"User-Agent": "google-api-nodejs-client/9.15.1",
"X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
"Client-Metadata": JSON.stringify({
ideType: "IDE_UNSPECIFIED",
platform: "PLATFORM_UNSPECIFIED",
pluginType: "GEMINI",
}),
};
// Try endpoints in order: prod first, then sandbox
const endpoints = [
"https://cloudcode-pa.googleapis.com",
"https://daily-cloudcode-pa.sandbox.googleapis.com",
];
onProgress?.("Checking for existing project...");
for (const endpoint of endpoints) {
try {
const loadResponse = await fetch(
`${endpoint}/v1internal:loadCodeAssist`,
{
method: "POST",
headers,
body: JSON.stringify({
metadata: {
ideType: "IDE_UNSPECIFIED",
platform: "PLATFORM_UNSPECIFIED",
pluginType: "GEMINI",
},
}),
},
);
if (loadResponse.ok) {
const data = (await loadResponse.json()) as LoadCodeAssistPayload;
// Handle both string and object formats
if (
typeof data.cloudaicompanionProject === "string" &&
data.cloudaicompanionProject
) {
return data.cloudaicompanionProject;
}
if (
data.cloudaicompanionProject &&
typeof data.cloudaicompanionProject === "object" &&
data.cloudaicompanionProject.id
) {
return data.cloudaicompanionProject.id;
}
}
} catch {
// Try next endpoint
}
}
// Use fallback project ID
onProgress?.("Using default project...");
return DEFAULT_PROJECT_ID;
}
/**
* Get user email from the access token
*/
async function getUserEmail(accessToken: string): Promise<string | undefined> {
try {
const response = await fetch(
"https://www.googleapis.com/oauth2/v1/userinfo?alt=json",
{
headers: {
Authorization: `Bearer ${accessToken}`,
},
},
);
if (response.ok) {
const data = (await response.json()) as { email?: string };
return data.email;
}
} catch {
// Ignore errors, email is optional
}
return undefined;
}
/**
* Refresh Antigravity token
*/
export async function refreshAntigravityToken(
refreshToken: string,
projectId: string,
): Promise<OAuthCredentials> {
const response = await fetch(TOKEN_URL, {
method: "POST",
headers: { "Content-Type": "application/x-www-form-urlencoded" },
body: new URLSearchParams({
client_id: CLIENT_ID,
client_secret: CLIENT_SECRET,
refresh_token: refreshToken,
grant_type: "refresh_token",
}),
});
if (!response.ok) {
const error = await response.text();
throw new Error(`Antigravity token refresh failed: ${error}`);
}
const data = (await response.json()) as {
access_token: string;
expires_in: number;
refresh_token?: string;
};
return {
refresh: data.refresh_token || refreshToken,
access: data.access_token,
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
projectId,
};
}
/**
* Login with Antigravity OAuth
*
* @param onAuth - Callback with URL and optional instructions
* @param onProgress - Optional progress callback
* @param onManualCodeInput - Optional promise that resolves with user-pasted redirect URL.
* Races with browser callback - whichever completes first wins.
*/
export async function loginAntigravity(
onAuth: (info: { url: string; instructions?: string }) => void,
onProgress?: (message: string) => void,
onManualCodeInput?: () => Promise<string>,
): Promise<OAuthCredentials> {
const { verifier, challenge } = await generatePKCE();
// Start local server for callback
onProgress?.("Starting local server for OAuth callback...");
const server = await startCallbackServer();
let code: string | undefined;
try {
// Build authorization URL
const authParams = new URLSearchParams({
client_id: CLIENT_ID,
response_type: "code",
redirect_uri: REDIRECT_URI,
scope: SCOPES.join(" "),
code_challenge: challenge,
code_challenge_method: "S256",
state: verifier,
access_type: "offline",
prompt: "consent",
});
const authUrl = `${AUTH_URL}?${authParams.toString()}`;
// Notify caller with URL to open
onAuth({
url: authUrl,
instructions: "Complete the sign-in in your browser.",
});
// Wait for the callback, racing with manual input if provided
onProgress?.("Waiting for OAuth callback...");
if (onManualCodeInput) {
// Race between browser callback and manual input
let manualInput: string | undefined;
let manualError: Error | undefined;
const manualPromise = onManualCodeInput()
.then((input) => {
manualInput = input;
server.cancelWait();
})
.catch((err) => {
manualError = err instanceof Error ? err : new Error(String(err));
server.cancelWait();
});
const result = await server.waitForCode();
// If manual input was cancelled, throw that error
if (manualError) {
throw manualError;
}
if (result?.code) {
// Browser callback won - verify state
if (result.state !== verifier) {
throw new Error("OAuth state mismatch - possible CSRF attack");
}
code = result.code;
} else if (manualInput) {
// Manual input won
const parsed = parseRedirectUrl(manualInput);
if (parsed.state && parsed.state !== verifier) {
throw new Error("OAuth state mismatch - possible CSRF attack");
}
code = parsed.code;
}
// If still no code, wait for manual promise and try that
if (!code) {
await manualPromise;
if (manualError) {
throw manualError;
}
if (manualInput) {
const parsed = parseRedirectUrl(manualInput);
if (parsed.state && parsed.state !== verifier) {
throw new Error("OAuth state mismatch - possible CSRF attack");
}
code = parsed.code;
}
}
} else {
// Original flow: just wait for callback
const result = await server.waitForCode();
if (result?.code) {
if (result.state !== verifier) {
throw new Error("OAuth state mismatch - possible CSRF attack");
}
code = result.code;
}
}
if (!code) {
throw new Error("No authorization code received");
}
// Exchange code for tokens
onProgress?.("Exchanging authorization code for tokens...");
const tokenResponse = await fetch(TOKEN_URL, {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
body: new URLSearchParams({
client_id: CLIENT_ID,
client_secret: CLIENT_SECRET,
code,
grant_type: "authorization_code",
redirect_uri: REDIRECT_URI,
code_verifier: verifier,
}),
});
if (!tokenResponse.ok) {
const error = await tokenResponse.text();
throw new Error(`Token exchange failed: ${error}`);
}
const tokenData = (await tokenResponse.json()) as {
access_token: string;
refresh_token: string;
expires_in: number;
};
if (!tokenData.refresh_token) {
throw new Error("No refresh token received. Please try again.");
}
// Get user email
onProgress?.("Getting user info...");
const email = await getUserEmail(tokenData.access_token);
// Discover project
const projectId = await discoverProject(tokenData.access_token, onProgress);
// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
const credentials: OAuthCredentials = {
refresh: tokenData.refresh_token,
access: tokenData.access_token,
expires: expiresAt,
projectId,
email,
};
return credentials;
} finally {
server.server.close();
}
}
export const antigravityOAuthProvider: OAuthProviderInterface = {
id: "google-antigravity",
name: "Antigravity (Gemini 3, Claude, GPT-OSS)",
usesCallbackServer: true,
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
return loginAntigravity(
callbacks.onAuth,
callbacks.onProgress,
callbacks.onManualCodeInput,
);
},
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
const creds = credentials as AntigravityCredentials;
if (!creds.projectId) {
throw new Error("Antigravity credentials missing projectId");
}
return refreshAntigravityToken(creds.refresh, creds.projectId);
},
getApiKey(credentials: OAuthCredentials): string {
const creds = credentials as AntigravityCredentials;
return JSON.stringify({ token: creds.access, projectId: creds.projectId });
},
};

View file

@ -0,0 +1,648 @@
/**
* Gemini CLI OAuth flow (Google Cloud Code Assist)
* Standard Gemini models only (gemini-2.0-flash, gemini-2.5-*)
*
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
* It is only intended for CLI use, not browser environments.
*/
import type { Server } from "node:http";
import { generatePKCE } from "./pkce.js";
import type {
OAuthCredentials,
OAuthLoginCallbacks,
OAuthProviderInterface,
} from "./types.js";
type GeminiCredentials = OAuthCredentials & {
projectId: string;
};
let _createServer: typeof import("node:http").createServer | null = null;
let _httpImportPromise: Promise<void> | null = null;
if (
typeof process !== "undefined" &&
(process.versions?.node || process.versions?.bun)
) {
_httpImportPromise = import("node:http").then((m) => {
_createServer = m.createServer;
});
}
const decode = (s: string) => atob(s);
const CLIENT_ID = decode(
"NjgxMjU1ODA5Mzk1LW9vOGZ0Mm9wcmRybnA5ZTNhcWY2YXYzaG1kaWIxMzVqLmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29t",
);
const CLIENT_SECRET = decode(
"R09DU1BYLTR1SGdNUG0tMW83U2stZ2VWNkN1NWNsWEZzeGw=",
);
const REDIRECT_URI = "http://localhost:8085/oauth2callback";
const SCOPES = [
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
];
const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
const TOKEN_URL = "https://oauth2.googleapis.com/token";
const CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com";
type CallbackServerInfo = {
server: Server;
cancelWait: () => void;
waitForCode: () => Promise<{ code: string; state: string } | null>;
};
/**
* Start a local HTTP server to receive the OAuth callback
*/
async function getNodeCreateServer(): Promise<
typeof import("node:http").createServer
> {
if (_createServer) return _createServer;
if (_httpImportPromise) {
await _httpImportPromise;
}
if (_createServer) return _createServer;
throw new Error("Gemini CLI OAuth is only available in Node.js environments");
}
async function startCallbackServer(): Promise<CallbackServerInfo> {
const createServer = await getNodeCreateServer();
return new Promise((resolve, reject) => {
let result: { code: string; state: string } | null = null;
let cancelled = false;
const server = createServer((req, res) => {
const url = new URL(req.url || "", `http://localhost:8085`);
if (url.pathname === "/oauth2callback") {
const code = url.searchParams.get("code");
const state = url.searchParams.get("state");
const error = url.searchParams.get("error");
if (error) {
res.writeHead(400, { "Content-Type": "text/html" });
res.end(
`<html><body><h1>Authentication Failed</h1><p>Error: ${error}</p><p>You can close this window.</p></body></html>`,
);
return;
}
if (code && state) {
res.writeHead(200, { "Content-Type": "text/html" });
res.end(
`<html><body><h1>Authentication Successful</h1><p>You can close this window and return to the terminal.</p></body></html>`,
);
result = { code, state };
} else {
res.writeHead(400, { "Content-Type": "text/html" });
res.end(
`<html><body><h1>Authentication Failed</h1><p>Missing code or state parameter.</p></body></html>`,
);
}
} else {
res.writeHead(404);
res.end();
}
});
server.on("error", (err) => {
reject(err);
});
server.listen(8085, "127.0.0.1", () => {
resolve({
server,
cancelWait: () => {
cancelled = true;
},
waitForCode: async () => {
const sleep = () => new Promise((r) => setTimeout(r, 100));
while (!result && !cancelled) {
await sleep();
}
return result;
},
});
});
});
}
/**
* Parse redirect URL to extract code and state
*/
function parseRedirectUrl(input: string): { code?: string; state?: string } {
const value = input.trim();
if (!value) return {};
try {
const url = new URL(value);
return {
code: url.searchParams.get("code") ?? undefined,
state: url.searchParams.get("state") ?? undefined,
};
} catch {
// Not a URL, return empty
return {};
}
}
interface LoadCodeAssistPayload {
cloudaicompanionProject?: string;
currentTier?: { id?: string };
allowedTiers?: Array<{ id?: string; isDefault?: boolean }>;
}
/**
* Long-running operation response from onboardUser
*/
interface LongRunningOperationResponse {
name?: string;
done?: boolean;
response?: {
cloudaicompanionProject?: { id?: string };
};
}
// Tier IDs as used by the Cloud Code API
const TIER_FREE = "free-tier";
const TIER_LEGACY = "legacy-tier";
const TIER_STANDARD = "standard-tier";
interface GoogleRpcErrorResponse {
error?: {
details?: Array<{ reason?: string }>;
};
}
/**
* Wait helper for onboarding retries
*/
function wait(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
/**
* Get default tier from allowed tiers
*/
function getDefaultTier(
allowedTiers?: Array<{ id?: string; isDefault?: boolean }>,
): { id?: string } {
if (!allowedTiers || allowedTiers.length === 0) return { id: TIER_LEGACY };
const defaultTier = allowedTiers.find((t) => t.isDefault);
return defaultTier ?? { id: TIER_LEGACY };
}
function isVpcScAffectedUser(payload: unknown): boolean {
if (!payload || typeof payload !== "object") return false;
if (!("error" in payload)) return false;
const error = (payload as GoogleRpcErrorResponse).error;
if (!error?.details || !Array.isArray(error.details)) return false;
return error.details.some(
(detail) => detail.reason === "SECURITY_POLICY_VIOLATED",
);
}
/**
* Poll a long-running operation until completion
*/
async function pollOperation(
operationName: string,
headers: Record<string, string>,
onProgress?: (message: string) => void,
): Promise<LongRunningOperationResponse> {
let attempt = 0;
while (true) {
if (attempt > 0) {
onProgress?.(
`Waiting for project provisioning (attempt ${attempt + 1})...`,
);
await wait(5000);
}
const response = await fetch(
`${CODE_ASSIST_ENDPOINT}/v1internal/${operationName}`,
{
method: "GET",
headers,
},
);
if (!response.ok) {
throw new Error(
`Failed to poll operation: ${response.status} ${response.statusText}`,
);
}
const data = (await response.json()) as LongRunningOperationResponse;
if (data.done) {
return data;
}
attempt += 1;
}
}
/**
* Discover or provision a Google Cloud project for the user
*/
async function discoverProject(
accessToken: string,
onProgress?: (message: string) => void,
): Promise<string> {
// Check for user-provided project ID via environment variable
const envProjectId =
process.env.GOOGLE_CLOUD_PROJECT || process.env.GOOGLE_CLOUD_PROJECT_ID;
const headers = {
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
"User-Agent": "google-api-nodejs-client/9.15.1",
"X-Goog-Api-Client": "gl-node/22.17.0",
};
// Try to load existing project via loadCodeAssist
onProgress?.("Checking for existing Cloud Code Assist project...");
const loadResponse = await fetch(
`${CODE_ASSIST_ENDPOINT}/v1internal:loadCodeAssist`,
{
method: "POST",
headers,
body: JSON.stringify({
cloudaicompanionProject: envProjectId,
metadata: {
ideType: "IDE_UNSPECIFIED",
platform: "PLATFORM_UNSPECIFIED",
pluginType: "GEMINI",
duetProject: envProjectId,
},
}),
},
);
let data: LoadCodeAssistPayload;
if (!loadResponse.ok) {
let errorPayload: unknown;
try {
errorPayload = await loadResponse.clone().json();
} catch {
errorPayload = undefined;
}
if (isVpcScAffectedUser(errorPayload)) {
data = { currentTier: { id: TIER_STANDARD } };
} else {
const errorText = await loadResponse.text();
throw new Error(
`loadCodeAssist failed: ${loadResponse.status} ${loadResponse.statusText}: ${errorText}`,
);
}
} else {
data = (await loadResponse.json()) as LoadCodeAssistPayload;
}
// If user already has a current tier and project, use it
if (data.currentTier) {
if (data.cloudaicompanionProject) {
return data.cloudaicompanionProject;
}
// User has a tier but no managed project - they need to provide one via env var
if (envProjectId) {
return envProjectId;
}
throw new Error(
"This account requires setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
);
}
// User needs to be onboarded - get the default tier
const tier = getDefaultTier(data.allowedTiers);
const tierId = tier?.id ?? TIER_FREE;
if (tierId !== TIER_FREE && !envProjectId) {
throw new Error(
"This account requires setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
);
}
onProgress?.(
"Provisioning Cloud Code Assist project (this may take a moment)...",
);
// Build onboard request - for free tier, don't include project ID (Google provisions one)
// For other tiers, include the user's project ID if available
const onboardBody: Record<string, unknown> = {
tierId,
metadata: {
ideType: "IDE_UNSPECIFIED",
platform: "PLATFORM_UNSPECIFIED",
pluginType: "GEMINI",
},
};
if (tierId !== TIER_FREE && envProjectId) {
onboardBody.cloudaicompanionProject = envProjectId;
(onboardBody.metadata as Record<string, unknown>).duetProject =
envProjectId;
}
// Start onboarding - this returns a long-running operation
const onboardResponse = await fetch(
`${CODE_ASSIST_ENDPOINT}/v1internal:onboardUser`,
{
method: "POST",
headers,
body: JSON.stringify(onboardBody),
},
);
if (!onboardResponse.ok) {
const errorText = await onboardResponse.text();
throw new Error(
`onboardUser failed: ${onboardResponse.status} ${onboardResponse.statusText}: ${errorText}`,
);
}
let lroData = (await onboardResponse.json()) as LongRunningOperationResponse;
// If the operation isn't done yet, poll until completion
if (!lroData.done && lroData.name) {
lroData = await pollOperation(lroData.name, headers, onProgress);
}
// Try to get project ID from the response
const projectId = lroData.response?.cloudaicompanionProject?.id;
if (projectId) {
return projectId;
}
// If no project ID from onboarding, fall back to env var
if (envProjectId) {
return envProjectId;
}
throw new Error(
"Could not discover or provision a Google Cloud project. " +
"Try setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
);
}
/**
* Get user email from the access token
*/
async function getUserEmail(accessToken: string): Promise<string | undefined> {
try {
const response = await fetch(
"https://www.googleapis.com/oauth2/v1/userinfo?alt=json",
{
headers: {
Authorization: `Bearer ${accessToken}`,
},
},
);
if (response.ok) {
const data = (await response.json()) as { email?: string };
return data.email;
}
} catch {
// Ignore errors, email is optional
}
return undefined;
}
/**
* Refresh Google Cloud Code Assist token
*/
export async function refreshGoogleCloudToken(
refreshToken: string,
projectId: string,
): Promise<OAuthCredentials> {
const response = await fetch(TOKEN_URL, {
method: "POST",
headers: { "Content-Type": "application/x-www-form-urlencoded" },
body: new URLSearchParams({
client_id: CLIENT_ID,
client_secret: CLIENT_SECRET,
refresh_token: refreshToken,
grant_type: "refresh_token",
}),
});
if (!response.ok) {
const error = await response.text();
throw new Error(`Google Cloud token refresh failed: ${error}`);
}
const data = (await response.json()) as {
access_token: string;
expires_in: number;
refresh_token?: string;
};
return {
refresh: data.refresh_token || refreshToken,
access: data.access_token,
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
projectId,
};
}
/**
* Login with Gemini CLI (Google Cloud Code Assist) OAuth
*
* @param onAuth - Callback with URL and optional instructions
* @param onProgress - Optional progress callback
* @param onManualCodeInput - Optional promise that resolves with user-pasted redirect URL.
* Races with browser callback - whichever completes first wins.
*/
export async function loginGeminiCli(
onAuth: (info: { url: string; instructions?: string }) => void,
onProgress?: (message: string) => void,
onManualCodeInput?: () => Promise<string>,
): Promise<OAuthCredentials> {
const { verifier, challenge } = await generatePKCE();
// Start local server for callback
onProgress?.("Starting local server for OAuth callback...");
const server = await startCallbackServer();
let code: string | undefined;
try {
// Build authorization URL
const authParams = new URLSearchParams({
client_id: CLIENT_ID,
response_type: "code",
redirect_uri: REDIRECT_URI,
scope: SCOPES.join(" "),
code_challenge: challenge,
code_challenge_method: "S256",
state: verifier,
access_type: "offline",
prompt: "consent",
});
const authUrl = `${AUTH_URL}?${authParams.toString()}`;
// Notify caller with URL to open
onAuth({
url: authUrl,
instructions: "Complete the sign-in in your browser.",
});
// Wait for the callback, racing with manual input if provided
onProgress?.("Waiting for OAuth callback...");
if (onManualCodeInput) {
// Race between browser callback and manual input
let manualInput: string | undefined;
let manualError: Error | undefined;
const manualPromise = onManualCodeInput()
.then((input) => {
manualInput = input;
server.cancelWait();
})
.catch((err) => {
manualError = err instanceof Error ? err : new Error(String(err));
server.cancelWait();
});
const result = await server.waitForCode();
// If manual input was cancelled, throw that error
if (manualError) {
throw manualError;
}
if (result?.code) {
// Browser callback won - verify state
if (result.state !== verifier) {
throw new Error("OAuth state mismatch - possible CSRF attack");
}
code = result.code;
} else if (manualInput) {
// Manual input won
const parsed = parseRedirectUrl(manualInput);
if (parsed.state && parsed.state !== verifier) {
throw new Error("OAuth state mismatch - possible CSRF attack");
}
code = parsed.code;
}
// If still no code, wait for manual promise and try that
if (!code) {
await manualPromise;
if (manualError) {
throw manualError;
}
if (manualInput) {
const parsed = parseRedirectUrl(manualInput);
if (parsed.state && parsed.state !== verifier) {
throw new Error("OAuth state mismatch - possible CSRF attack");
}
code = parsed.code;
}
}
} else {
// Original flow: just wait for callback
const result = await server.waitForCode();
if (result?.code) {
if (result.state !== verifier) {
throw new Error("OAuth state mismatch - possible CSRF attack");
}
code = result.code;
}
}
if (!code) {
throw new Error("No authorization code received");
}
// Exchange code for tokens
onProgress?.("Exchanging authorization code for tokens...");
const tokenResponse = await fetch(TOKEN_URL, {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
body: new URLSearchParams({
client_id: CLIENT_ID,
client_secret: CLIENT_SECRET,
code,
grant_type: "authorization_code",
redirect_uri: REDIRECT_URI,
code_verifier: verifier,
}),
});
if (!tokenResponse.ok) {
const error = await tokenResponse.text();
throw new Error(`Token exchange failed: ${error}`);
}
const tokenData = (await tokenResponse.json()) as {
access_token: string;
refresh_token: string;
expires_in: number;
};
if (!tokenData.refresh_token) {
throw new Error("No refresh token received. Please try again.");
}
// Get user email
onProgress?.("Getting user info...");
const email = await getUserEmail(tokenData.access_token);
// Discover project
const projectId = await discoverProject(tokenData.access_token, onProgress);
// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
const credentials: OAuthCredentials = {
refresh: tokenData.refresh_token,
access: tokenData.access_token,
expires: expiresAt,
projectId,
email,
};
return credentials;
} finally {
server.server.close();
}
}
export const geminiCliOAuthProvider: OAuthProviderInterface = {
id: "google-gemini-cli",
name: "Google Cloud Code Assist (Gemini CLI)",
usesCallbackServer: true,
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
return loginGeminiCli(
callbacks.onAuth,
callbacks.onProgress,
callbacks.onManualCodeInput,
);
},
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
const creds = credentials as GeminiCredentials;
if (!creds.projectId) {
throw new Error("Google Cloud credentials missing projectId");
}
return refreshGoogleCloudToken(creds.refresh, creds.projectId);
},
getApiKey(credentials: OAuthCredentials): string {
const creds = credentials as GeminiCredentials;
return JSON.stringify({ token: creds.access, projectId: creds.projectId });
},
};

View file

@ -0,0 +1,187 @@
/**
* OAuth credential management for AI providers.
*
* This module handles login, token refresh, and credential storage
* for OAuth-based providers:
* - Anthropic (Claude Pro/Max)
* - GitHub Copilot
* - Google Cloud Code Assist (Gemini CLI)
* - Antigravity (Gemini 3, Claude, GPT-OSS via Google Cloud)
*/
// Anthropic
export {
anthropicOAuthProvider,
loginAnthropic,
refreshAnthropicToken,
} from "./anthropic.js";
// GitHub Copilot
export {
getGitHubCopilotBaseUrl,
githubCopilotOAuthProvider,
loginGitHubCopilot,
normalizeDomain,
refreshGitHubCopilotToken,
} from "./github-copilot.js";
// Google Antigravity
export {
antigravityOAuthProvider,
loginAntigravity,
refreshAntigravityToken,
} from "./google-antigravity.js";
// Google Gemini CLI
export {
geminiCliOAuthProvider,
loginGeminiCli,
refreshGoogleCloudToken,
} from "./google-gemini-cli.js";
// OpenAI Codex (ChatGPT OAuth)
export {
loginOpenAICodex,
openaiCodexOAuthProvider,
refreshOpenAICodexToken,
} from "./openai-codex.js";
export * from "./types.js";
// ============================================================================
// Provider Registry
// ============================================================================
import { anthropicOAuthProvider } from "./anthropic.js";
import { githubCopilotOAuthProvider } from "./github-copilot.js";
import { antigravityOAuthProvider } from "./google-antigravity.js";
import { geminiCliOAuthProvider } from "./google-gemini-cli.js";
import { openaiCodexOAuthProvider } from "./openai-codex.js";
import type {
OAuthCredentials,
OAuthProviderId,
OAuthProviderInfo,
OAuthProviderInterface,
} from "./types.js";
const BUILT_IN_OAUTH_PROVIDERS: OAuthProviderInterface[] = [
anthropicOAuthProvider,
githubCopilotOAuthProvider,
geminiCliOAuthProvider,
antigravityOAuthProvider,
openaiCodexOAuthProvider,
];
const oauthProviderRegistry = new Map<string, OAuthProviderInterface>(
BUILT_IN_OAUTH_PROVIDERS.map((provider) => [provider.id, provider]),
);
/**
* Get an OAuth provider by ID
*/
export function getOAuthProvider(
id: OAuthProviderId,
): OAuthProviderInterface | undefined {
return oauthProviderRegistry.get(id);
}
/**
* Register a custom OAuth provider
*/
export function registerOAuthProvider(provider: OAuthProviderInterface): void {
oauthProviderRegistry.set(provider.id, provider);
}
/**
* Unregister an OAuth provider.
*
* If the provider is built-in, restores the built-in implementation.
* Custom providers are removed completely.
*/
export function unregisterOAuthProvider(id: string): void {
const builtInProvider = BUILT_IN_OAUTH_PROVIDERS.find(
(provider) => provider.id === id,
);
if (builtInProvider) {
oauthProviderRegistry.set(id, builtInProvider);
return;
}
oauthProviderRegistry.delete(id);
}
/**
* Reset OAuth providers to built-ins.
*/
export function resetOAuthProviders(): void {
oauthProviderRegistry.clear();
for (const provider of BUILT_IN_OAUTH_PROVIDERS) {
oauthProviderRegistry.set(provider.id, provider);
}
}
/**
* Get all registered OAuth providers
*/
export function getOAuthProviders(): OAuthProviderInterface[] {
return Array.from(oauthProviderRegistry.values());
}
/**
* @deprecated Use getOAuthProviders() which returns OAuthProviderInterface[]
*/
export function getOAuthProviderInfoList(): OAuthProviderInfo[] {
return getOAuthProviders().map((p) => ({
id: p.id,
name: p.name,
available: true,
}));
}
// ============================================================================
// High-level API (uses provider registry)
// ============================================================================
/**
* Refresh token for any OAuth provider.
* @deprecated Use getOAuthProvider(id).refreshToken() instead
*/
export async function refreshOAuthToken(
providerId: OAuthProviderId,
credentials: OAuthCredentials,
): Promise<OAuthCredentials> {
const provider = getOAuthProvider(providerId);
if (!provider) {
throw new Error(`Unknown OAuth provider: ${providerId}`);
}
return provider.refreshToken(credentials);
}
/**
* Get API key for a provider from OAuth credentials.
* Automatically refreshes expired tokens.
*
* @returns API key string and updated credentials, or null if no credentials
* @throws Error if refresh fails
*/
export async function getOAuthApiKey(
providerId: OAuthProviderId,
credentials: Record<string, OAuthCredentials>,
): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> {
const provider = getOAuthProvider(providerId);
if (!provider) {
throw new Error(`Unknown OAuth provider: ${providerId}`);
}
let creds = credentials[providerId];
if (!creds) {
return null;
}
// Refresh if expired
if (Date.now() >= creds.expires) {
try {
creds = await provider.refreshToken(creds);
} catch (_error) {
throw new Error(`Failed to refresh OAuth token for ${providerId}`);
}
}
const apiKey = provider.getApiKey(creds);
return { newCredentials: creds, apiKey };
}

View file

@ -0,0 +1,499 @@
/**
* OpenAI Codex (ChatGPT OAuth) flow
*
* NOTE: This module uses Node.js crypto and http for the OAuth callback.
* It is only intended for CLI use, not browser environments.
*/
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
let _randomBytes: typeof import("node:crypto").randomBytes | null = null;
let _http: typeof import("node:http") | null = null;
if (
typeof process !== "undefined" &&
(process.versions?.node || process.versions?.bun)
) {
import("node:crypto").then((m) => {
_randomBytes = m.randomBytes;
});
import("node:http").then((m) => {
_http = m;
});
}
import { generatePKCE } from "./pkce.js";
import type {
OAuthCredentials,
OAuthLoginCallbacks,
OAuthPrompt,
OAuthProviderInterface,
} from "./types.js";
const CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann";
const AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize";
const TOKEN_URL = "https://auth.openai.com/oauth/token";
const REDIRECT_URI = "http://localhost:1455/auth/callback";
const SCOPE = "openid profile email offline_access";
const JWT_CLAIM_PATH = "https://api.openai.com/auth";
const SUCCESS_HTML = `<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Authentication successful</title>
</head>
<body>
<p>Authentication successful. Return to your terminal to continue.</p>
</body>
</html>`;
type TokenSuccess = {
type: "success";
access: string;
refresh: string;
expires: number;
};
type TokenFailure = { type: "failed" };
type TokenResult = TokenSuccess | TokenFailure;
type JwtPayload = {
[JWT_CLAIM_PATH]?: {
chatgpt_account_id?: string;
};
[key: string]: unknown;
};
function createState(): string {
if (!_randomBytes) {
throw new Error(
"OpenAI Codex OAuth is only available in Node.js environments",
);
}
return _randomBytes(16).toString("hex");
}
function parseAuthorizationInput(input: string): {
code?: string;
state?: string;
} {
const value = input.trim();
if (!value) return {};
try {
const url = new URL(value);
return {
code: url.searchParams.get("code") ?? undefined,
state: url.searchParams.get("state") ?? undefined,
};
} catch {
// not a URL
}
if (value.includes("#")) {
const [code, state] = value.split("#", 2);
return { code, state };
}
if (value.includes("code=")) {
const params = new URLSearchParams(value);
return {
code: params.get("code") ?? undefined,
state: params.get("state") ?? undefined,
};
}
return { code: value };
}
function decodeJwt(token: string): JwtPayload | null {
try {
const parts = token.split(".");
if (parts.length !== 3) return null;
const payload = parts[1] ?? "";
const decoded = atob(payload);
return JSON.parse(decoded) as JwtPayload;
} catch {
return null;
}
}
async function exchangeAuthorizationCode(
code: string,
verifier: string,
redirectUri: string = REDIRECT_URI,
): Promise<TokenResult> {
const response = await fetch(TOKEN_URL, {
method: "POST",
headers: { "Content-Type": "application/x-www-form-urlencoded" },
body: new URLSearchParams({
grant_type: "authorization_code",
client_id: CLIENT_ID,
code,
code_verifier: verifier,
redirect_uri: redirectUri,
}),
});
if (!response.ok) {
const text = await response.text().catch(() => "");
console.error("[openai-codex] code->token failed:", response.status, text);
return { type: "failed" };
}
const json = (await response.json()) as {
access_token?: string;
refresh_token?: string;
expires_in?: number;
};
if (
!json.access_token ||
!json.refresh_token ||
typeof json.expires_in !== "number"
) {
console.error("[openai-codex] token response missing fields:", json);
return { type: "failed" };
}
return {
type: "success",
access: json.access_token,
refresh: json.refresh_token,
expires: Date.now() + json.expires_in * 1000,
};
}
async function refreshAccessToken(refreshToken: string): Promise<TokenResult> {
try {
const response = await fetch(TOKEN_URL, {
method: "POST",
headers: { "Content-Type": "application/x-www-form-urlencoded" },
body: new URLSearchParams({
grant_type: "refresh_token",
refresh_token: refreshToken,
client_id: CLIENT_ID,
}),
});
if (!response.ok) {
const text = await response.text().catch(() => "");
console.error(
"[openai-codex] Token refresh failed:",
response.status,
text,
);
return { type: "failed" };
}
const json = (await response.json()) as {
access_token?: string;
refresh_token?: string;
expires_in?: number;
};
if (
!json.access_token ||
!json.refresh_token ||
typeof json.expires_in !== "number"
) {
console.error(
"[openai-codex] Token refresh response missing fields:",
json,
);
return { type: "failed" };
}
return {
type: "success",
access: json.access_token,
refresh: json.refresh_token,
expires: Date.now() + json.expires_in * 1000,
};
} catch (error) {
console.error("[openai-codex] Token refresh error:", error);
return { type: "failed" };
}
}
async function createAuthorizationFlow(
originator: string = "pi",
): Promise<{ verifier: string; state: string; url: string }> {
const { verifier, challenge } = await generatePKCE();
const state = createState();
const url = new URL(AUTHORIZE_URL);
url.searchParams.set("response_type", "code");
url.searchParams.set("client_id", CLIENT_ID);
url.searchParams.set("redirect_uri", REDIRECT_URI);
url.searchParams.set("scope", SCOPE);
url.searchParams.set("code_challenge", challenge);
url.searchParams.set("code_challenge_method", "S256");
url.searchParams.set("state", state);
url.searchParams.set("id_token_add_organizations", "true");
url.searchParams.set("codex_cli_simplified_flow", "true");
url.searchParams.set("originator", originator);
return { verifier, state, url: url.toString() };
}
type OAuthServerInfo = {
close: () => void;
cancelWait: () => void;
waitForCode: () => Promise<{ code: string } | null>;
};
function startLocalOAuthServer(state: string): Promise<OAuthServerInfo> {
if (!_http) {
throw new Error(
"OpenAI Codex OAuth is only available in Node.js environments",
);
}
let lastCode: string | null = null;
let cancelled = false;
const server = _http.createServer((req, res) => {
try {
const url = new URL(req.url || "", "http://localhost");
if (url.pathname !== "/auth/callback") {
res.statusCode = 404;
res.end("Not found");
return;
}
if (url.searchParams.get("state") !== state) {
res.statusCode = 400;
res.end("State mismatch");
return;
}
const code = url.searchParams.get("code");
if (!code) {
res.statusCode = 400;
res.end("Missing authorization code");
return;
}
res.statusCode = 200;
res.setHeader("Content-Type", "text/html; charset=utf-8");
res.end(SUCCESS_HTML);
lastCode = code;
} catch {
res.statusCode = 500;
res.end("Internal error");
}
});
return new Promise((resolve) => {
server
.listen(1455, "127.0.0.1", () => {
resolve({
close: () => server.close(),
cancelWait: () => {
cancelled = true;
},
waitForCode: async () => {
const sleep = () => new Promise((r) => setTimeout(r, 100));
for (let i = 0; i < 600; i += 1) {
if (lastCode) return { code: lastCode };
if (cancelled) return null;
await sleep();
}
return null;
},
});
})
.on("error", (err: NodeJS.ErrnoException) => {
console.error(
"[openai-codex] Failed to bind http://127.0.0.1:1455 (",
err.code,
") Falling back to manual paste.",
);
resolve({
close: () => {
try {
server.close();
} catch {
// ignore
}
},
cancelWait: () => {},
waitForCode: async () => null,
});
});
});
}
function getAccountId(accessToken: string): string | null {
const payload = decodeJwt(accessToken);
const auth = payload?.[JWT_CLAIM_PATH];
const accountId = auth?.chatgpt_account_id;
return typeof accountId === "string" && accountId.length > 0
? accountId
: null;
}
/**
* Login with OpenAI Codex OAuth
*
* @param options.onAuth - Called with URL and instructions when auth starts
* @param options.onPrompt - Called to prompt user for manual code paste (fallback if no onManualCodeInput)
* @param options.onProgress - Optional progress messages
* @param options.onManualCodeInput - Optional promise that resolves with user-pasted code.
* Races with browser callback - whichever completes first wins.
* Useful for showing paste input immediately alongside browser flow.
* @param options.originator - OAuth originator parameter (defaults to "pi")
*/
export async function loginOpenAICodex(options: {
onAuth: (info: { url: string; instructions?: string }) => void;
onPrompt: (prompt: OAuthPrompt) => Promise<string>;
onProgress?: (message: string) => void;
onManualCodeInput?: () => Promise<string>;
originator?: string;
}): Promise<OAuthCredentials> {
const { verifier, state, url } = await createAuthorizationFlow(
options.originator,
);
const server = await startLocalOAuthServer(state);
options.onAuth({
url,
instructions: "A browser window should open. Complete login to finish.",
});
let code: string | undefined;
try {
if (options.onManualCodeInput) {
// Race between browser callback and manual input
let manualCode: string | undefined;
let manualError: Error | undefined;
const manualPromise = options
.onManualCodeInput()
.then((input) => {
manualCode = input;
server.cancelWait();
})
.catch((err) => {
manualError = err instanceof Error ? err : new Error(String(err));
server.cancelWait();
});
const result = await server.waitForCode();
// If manual input was cancelled, throw that error
if (manualError) {
throw manualError;
}
if (result?.code) {
// Browser callback won
code = result.code;
} else if (manualCode) {
// Manual input won (or callback timed out and user had entered code)
const parsed = parseAuthorizationInput(manualCode);
if (parsed.state && parsed.state !== state) {
throw new Error("State mismatch");
}
code = parsed.code;
}
// If still no code, wait for manual promise to complete and try that
if (!code) {
await manualPromise;
if (manualError) {
throw manualError;
}
if (manualCode) {
const parsed = parseAuthorizationInput(manualCode);
if (parsed.state && parsed.state !== state) {
throw new Error("State mismatch");
}
code = parsed.code;
}
}
} else {
// Original flow: wait for callback, then prompt if needed
const result = await server.waitForCode();
if (result?.code) {
code = result.code;
}
}
// Fallback to onPrompt if still no code
if (!code) {
const input = await options.onPrompt({
message: "Paste the authorization code (or full redirect URL):",
});
const parsed = parseAuthorizationInput(input);
if (parsed.state && parsed.state !== state) {
throw new Error("State mismatch");
}
code = parsed.code;
}
if (!code) {
throw new Error("Missing authorization code");
}
const tokenResult = await exchangeAuthorizationCode(code, verifier);
if (tokenResult.type !== "success") {
throw new Error("Token exchange failed");
}
const accountId = getAccountId(tokenResult.access);
if (!accountId) {
throw new Error("Failed to extract accountId from token");
}
return {
access: tokenResult.access,
refresh: tokenResult.refresh,
expires: tokenResult.expires,
accountId,
};
} finally {
server.close();
}
}
/**
* Refresh OpenAI Codex OAuth token
*/
export async function refreshOpenAICodexToken(
refreshToken: string,
): Promise<OAuthCredentials> {
const result = await refreshAccessToken(refreshToken);
if (result.type !== "success") {
throw new Error("Failed to refresh OpenAI Codex token");
}
const accountId = getAccountId(result.access);
if (!accountId) {
throw new Error("Failed to extract accountId from token");
}
return {
access: result.access,
refresh: result.refresh,
expires: result.expires,
accountId,
};
}
export const openaiCodexOAuthProvider: OAuthProviderInterface = {
id: "openai-codex",
name: "ChatGPT Plus/Pro (Codex Subscription)",
usesCallbackServer: true,
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
return loginOpenAICodex({
onAuth: callbacks.onAuth,
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
onManualCodeInput: callbacks.onManualCodeInput,
});
},
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
return refreshOpenAICodexToken(credentials.refresh);
},
getApiKey(credentials: OAuthCredentials): string {
return credentials.access;
},
};

View file

@ -0,0 +1,37 @@
/**
* PKCE utilities using Web Crypto API.
* Works in both Node.js 20+ and browsers.
*/
/**
* Encode bytes as base64url string.
*/
function base64urlEncode(bytes: Uint8Array): string {
let binary = "";
for (const byte of bytes) {
binary += String.fromCharCode(byte);
}
return btoa(binary).replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, "");
}
/**
* Generate PKCE code verifier and challenge.
* Uses Web Crypto API for cross-platform compatibility.
*/
export async function generatePKCE(): Promise<{
verifier: string;
challenge: string;
}> {
// Generate random verifier
const verifierBytes = new Uint8Array(32);
crypto.getRandomValues(verifierBytes);
const verifier = base64urlEncode(verifierBytes);
// Compute SHA-256 challenge
const encoder = new TextEncoder();
const data = encoder.encode(verifier);
const hashBuffer = await crypto.subtle.digest("SHA-256", data);
const challenge = base64urlEncode(new Uint8Array(hashBuffer));
return { verifier, challenge };
}

View file

@ -0,0 +1,62 @@
import type { Api, Model } from "../../types.js";
export type OAuthCredentials = {
refresh: string;
access: string;
expires: number;
[key: string]: unknown;
};
export type OAuthProviderId = string;
/** @deprecated Use OAuthProviderId instead */
export type OAuthProvider = OAuthProviderId;
export type OAuthPrompt = {
message: string;
placeholder?: string;
allowEmpty?: boolean;
};
export type OAuthAuthInfo = {
url: string;
instructions?: string;
};
export interface OAuthLoginCallbacks {
onAuth: (info: OAuthAuthInfo) => void;
onPrompt: (prompt: OAuthPrompt) => Promise<string>;
onProgress?: (message: string) => void;
onManualCodeInput?: () => Promise<string>;
signal?: AbortSignal;
}
export interface OAuthProviderInterface {
readonly id: OAuthProviderId;
readonly name: string;
/** Run the login flow, return credentials to persist */
login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
/** Whether login uses a local callback server and supports manual code input. */
usesCallbackServer?: boolean;
/** Refresh expired credentials, return updated credentials to persist */
refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
/** Convert credentials to API key string for the provider */
getApiKey(credentials: OAuthCredentials): string;
/** Optional: modify models for this provider (e.g., update baseUrl) */
modifyModels?(
models: Model<Api>[],
credentials: OAuthCredentials,
): Model<Api>[];
}
/** @deprecated Use OAuthProviderInterface instead */
export interface OAuthProviderInfo {
id: OAuthProviderId;
name: string;
available: boolean;
}

View file

@ -0,0 +1,127 @@
import type { AssistantMessage } from "../types.js";
/**
* Regex patterns to detect context overflow errors from different providers.
*
* These patterns match error messages returned when the input exceeds
* the model's context window.
*
* Provider-specific patterns (with example error messages):
*
* - Anthropic: "prompt is too long: 213462 tokens > 200000 maximum"
* - OpenAI: "Your input exceeds the context window of this model"
* - Google: "The input token count (1196265) exceeds the maximum number of tokens allowed (1048575)"
* - xAI: "This model's maximum prompt length is 131072 but the request contains 537812 tokens"
* - Groq: "Please reduce the length of the messages or completion"
* - OpenRouter: "This endpoint's maximum context length is X tokens. However, you requested about Y tokens"
* - llama.cpp: "the request exceeds the available context size, try increasing it"
* - LM Studio: "tokens to keep from the initial prompt is greater than the context length"
* - GitHub Copilot: "prompt token count of X exceeds the limit of Y"
* - MiniMax: "invalid params, context window exceeds limit"
* - Kimi For Coding: "Your request exceeded model token limit: X (requested: Y)"
* - Cerebras: Returns "400/413 status code (no body)" - handled separately below
* - Mistral: "Prompt contains X tokens ... too large for model with Y maximum context length"
* - z.ai: Does NOT error, accepts overflow silently - handled via usage.input > contextWindow
* - Ollama: Silently truncates input - not detectable via error message
*/
const OVERFLOW_PATTERNS = [
/prompt is too long/i, // Anthropic
/input is too long for requested model/i, // Amazon Bedrock
/exceeds the context window/i, // OpenAI (Completions & Responses API)
/input token count.*exceeds the maximum/i, // Google (Gemini)
/maximum prompt length is \d+/i, // xAI (Grok)
/reduce the length of the messages/i, // Groq
/maximum context length is \d+ tokens/i, // OpenRouter (all backends)
/exceeds the limit of \d+/i, // GitHub Copilot
/exceeds the available context size/i, // llama.cpp server
/greater than the context length/i, // LM Studio
/context window exceeds limit/i, // MiniMax
/exceeded model token limit/i, // Kimi For Coding
/too large for model with \d+ maximum context length/i, // Mistral
/context[_ ]length[_ ]exceeded/i, // Generic fallback
/too many tokens/i, // Generic fallback
/token limit exceeded/i, // Generic fallback
];
/**
* Check if an assistant message represents a context overflow error.
*
* This handles two cases:
* 1. Error-based overflow: Most providers return stopReason "error" with a
* specific error message pattern.
* 2. Silent overflow: Some providers accept overflow requests and return
* successfully. For these, we check if usage.input exceeds the context window.
*
* ## Reliability by Provider
*
* **Reliable detection (returns error with detectable message):**
* - Anthropic: "prompt is too long: X tokens > Y maximum"
* - OpenAI (Completions & Responses): "exceeds the context window"
* - Google Gemini: "input token count exceeds the maximum"
* - xAI (Grok): "maximum prompt length is X but request contains Y"
* - Groq: "reduce the length of the messages"
* - Cerebras: 400/413 status code (no body)
* - Mistral: "Prompt contains X tokens ... too large for model with Y maximum context length"
* - OpenRouter (all backends): "maximum context length is X tokens"
* - llama.cpp: "exceeds the available context size"
* - LM Studio: "greater than the context length"
* - Kimi For Coding: "exceeded model token limit: X (requested: Y)"
*
* **Unreliable detection:**
* - z.ai: Sometimes accepts overflow silently (detectable via usage.input > contextWindow),
* sometimes returns rate limit errors. Pass contextWindow param to detect silent overflow.
* - Ollama: Silently truncates input without error. Cannot be detected via this function.
* The response will have usage.input < expected, but we don't know the expected value.
*
* ## Custom Providers
*
* If you've added custom models via settings.json, this function may not detect
* overflow errors from those providers. To add support:
*
* 1. Send a request that exceeds the model's context window
* 2. Check the errorMessage in the response
* 3. Create a regex pattern that matches the error
* 4. The pattern should be added to OVERFLOW_PATTERNS in this file, or
* check the errorMessage yourself before calling this function
*
* @param message - The assistant message to check
* @param contextWindow - Optional context window size for detecting silent overflow (z.ai)
* @returns true if the message indicates a context overflow
*/
export function isContextOverflow(
message: AssistantMessage,
contextWindow?: number,
): boolean {
// Case 1: Check error message patterns
if (message.stopReason === "error" && message.errorMessage) {
// Check known patterns
if (OVERFLOW_PATTERNS.some((p) => p.test(message.errorMessage!))) {
return true;
}
// Cerebras returns 400/413 with no body for context overflow
// Note: 429 is rate limiting (requests/tokens per time), NOT context overflow
if (
/^4(00|13)\s*(status code)?\s*\(no body\)/i.test(message.errorMessage)
) {
return true;
}
}
// Case 2: Silent overflow (z.ai style) - successful but usage exceeds context
if (contextWindow && message.stopReason === "stop") {
const inputTokens = message.usage.input + message.usage.cacheRead;
if (inputTokens > contextWindow) {
return true;
}
}
return false;
}
/**
* Get the overflow patterns for testing purposes.
*/
export function getOverflowPatterns(): RegExp[] {
return [...OVERFLOW_PATTERNS];
}

View file

@ -0,0 +1,28 @@
/**
* Removes unpaired Unicode surrogate characters from a string.
*
* Unpaired surrogates (high surrogates 0xD800-0xDBFF without matching low surrogates 0xDC00-0xDFFF,
* or vice versa) cause JSON serialization errors in many API providers.
*
* Valid emoji and other characters outside the Basic Multilingual Plane use properly paired
* surrogates and will NOT be affected by this function.
*
* @param text - The text to sanitize
* @returns The sanitized text with unpaired surrogates removed
*
* @example
* // Valid emoji (properly paired surrogates) are preserved
* sanitizeSurrogates("Hello 🙈 World") // => "Hello 🙈 World"
*
* // Unpaired high surrogate is removed
* const unpaired = String.fromCharCode(0xD83D); // high surrogate without low
* sanitizeSurrogates(`Text ${unpaired} here`) // => "Text here"
*/
export function sanitizeSurrogates(text: string): string {
// Replace unpaired high surrogates (0xD800-0xDBFF not followed by low surrogate)
// Replace unpaired low surrogates (0xDC00-0xDFFF not preceded by high surrogate)
return text.replace(
/[\uD800-\uDBFF](?![\uDC00-\uDFFF])|(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]/g,
"",
);
}

View file

@ -0,0 +1,24 @@
import { type TUnsafe, Type } from "@sinclair/typebox";
/**
* Creates a string enum schema compatible with Google's API and other providers
* that don't support anyOf/const patterns.
*
* @example
* const OperationSchema = StringEnum(["add", "subtract", "multiply", "divide"], {
* description: "The operation to perform"
* });
*
* type Operation = Static<typeof OperationSchema>; // "add" | "subtract" | "multiply" | "divide"
*/
export function StringEnum<T extends readonly string[]>(
values: T,
options?: { description?: string; default?: T[number] },
): TUnsafe<T[number]> {
return Type.Unsafe<T[number]>({
type: "string",
enum: values as any,
...(options?.description && { description: options.description }),
...(options?.default && { default: options.default }),
});
}

View file

@ -0,0 +1,88 @@
import AjvModule from "ajv";
import addFormatsModule from "ajv-formats";
// Handle both default and named exports
const Ajv = (AjvModule as any).default || AjvModule;
const addFormats = (addFormatsModule as any).default || addFormatsModule;
import type { Tool, ToolCall } from "../types.js";
// Detect if we're in a browser extension environment with strict CSP
// Chrome extensions with Manifest V3 don't allow eval/Function constructor
const isBrowserExtension =
typeof globalThis !== "undefined" &&
(globalThis as any).chrome?.runtime?.id !== undefined;
// Create a singleton AJV instance with formats (only if not in browser extension)
// AJV requires 'unsafe-eval' CSP which is not allowed in Manifest V3
let ajv: any = null;
if (!isBrowserExtension) {
try {
ajv = new Ajv({
allErrors: true,
strict: false,
coerceTypes: true,
});
addFormats(ajv);
} catch (_e) {
// AJV initialization failed (likely CSP restriction)
console.warn("AJV validation disabled due to CSP restrictions");
}
}
/**
* Finds a tool by name and validates the tool call arguments against its TypeBox schema
* @param tools Array of tool definitions
* @param toolCall The tool call from the LLM
* @returns The validated arguments
* @throws Error if tool is not found or validation fails
*/
export function validateToolCall(tools: Tool[], toolCall: ToolCall): any {
const tool = tools.find((t) => t.name === toolCall.name);
if (!tool) {
throw new Error(`Tool "${toolCall.name}" not found`);
}
return validateToolArguments(tool, toolCall);
}
/**
* Validates tool call arguments against the tool's TypeBox schema
* @param tool The tool definition with TypeBox schema
* @param toolCall The tool call from the LLM
* @returns The validated (and potentially coerced) arguments
* @throws Error with formatted message if validation fails
*/
export function validateToolArguments(tool: Tool, toolCall: ToolCall): any {
// Skip validation in browser extension environment (CSP restrictions prevent AJV from working)
if (!ajv || isBrowserExtension) {
// Trust the LLM's output without validation
// Browser extensions can't use AJV due to Manifest V3 CSP restrictions
return toolCall.arguments;
}
// Compile the schema
const validate = ajv.compile(tool.parameters);
// Clone arguments so AJV can safely mutate for type coercion
const args = structuredClone(toolCall.arguments);
// Validate the arguments (AJV mutates args in-place for type coercion)
if (validate(args)) {
return args;
}
// Format validation errors nicely
const errors =
validate.errors
?.map((err: any) => {
const path = err.instancePath
? err.instancePath.substring(1)
: err.params.missingProperty || "root";
return ` - ${path}: ${err.message}`;
})
.join("\n") || "Unknown validation error";
const errorMessage = `Validation failed for tool "${toolCall.name}":\n${errors}\n\nReceived arguments:\n${JSON.stringify(toolCall.arguments, null, 2)}`;
throw new Error(errorMessage);
}

View file

@ -0,0 +1,339 @@
import { describe, expect, it } from "vitest";
import { getModel } from "../src/models.js";
import { complete, stream } from "../src/stream.js";
import type { Api, Context, Model, StreamOptions } from "../src/types.js";
type StreamOptionsWithExtras = StreamOptions & Record<string, unknown>;
import {
hasAzureOpenAICredentials,
resolveAzureDeploymentName,
} from "./azure-utils.js";
import { hasBedrockCredentials } from "./bedrock-utils.js";
import { resolveApiKey } from "./oauth.js";
// Resolve OAuth tokens at module level (async, runs before tests)
const [geminiCliToken, openaiCodexToken] = await Promise.all([
resolveApiKey("google-gemini-cli"),
resolveApiKey("openai-codex"),
]);
async function testAbortSignal<TApi extends Api>(
llm: Model<TApi>,
options: StreamOptionsWithExtras = {},
) {
const context: Context = {
messages: [
{
role: "user",
content:
"What is 15 + 27? Think step by step. Then list 50 first names.",
timestamp: Date.now(),
},
],
systemPrompt: "You are a helpful assistant.",
};
let abortFired = false;
let text = "";
const controller = new AbortController();
const response = await stream(llm, context, {
...options,
signal: controller.signal,
});
for await (const event of response) {
if (abortFired) return;
if (event.type === "text_delta" || event.type === "thinking_delta") {
text += event.delta;
}
if (text.length >= 50) {
controller.abort();
abortFired = true;
}
}
const msg = await response.result();
// If we get here without throwing, the abort didn't work
expect(msg.stopReason).toBe("aborted");
expect(msg.content.length).toBeGreaterThan(0);
context.messages.push(msg);
context.messages.push({
role: "user",
content: "Please continue, but only generate 5 names.",
timestamp: Date.now(),
});
const followUp = await complete(llm, context, options);
expect(followUp.stopReason).toBe("stop");
expect(followUp.content.length).toBeGreaterThan(0);
}
async function testImmediateAbort<TApi extends Api>(
llm: Model<TApi>,
options: StreamOptionsWithExtras = {},
) {
const controller = new AbortController();
controller.abort();
const context: Context = {
messages: [{ role: "user", content: "Hello", timestamp: Date.now() }],
};
const response = await complete(llm, context, {
...options,
signal: controller.signal,
});
expect(response.stopReason).toBe("aborted");
}
async function testAbortThenNewMessage<TApi extends Api>(
llm: Model<TApi>,
options: StreamOptionsWithExtras = {},
) {
// First request: abort immediately before any response content arrives
const controller = new AbortController();
controller.abort();
const context: Context = {
messages: [
{ role: "user", content: "Hello, how are you?", timestamp: Date.now() },
],
};
const abortedResponse = await complete(llm, context, {
...options,
signal: controller.signal,
});
expect(abortedResponse.stopReason).toBe("aborted");
// The aborted message has empty content since we aborted before anything arrived
expect(abortedResponse.content.length).toBe(0);
// Add the aborted assistant message to context (this is what happens in the real coding agent)
context.messages.push(abortedResponse);
// Second request: send a new message - this should work even with the aborted message in context
context.messages.push({
role: "user",
content: "What is 2 + 2?",
timestamp: Date.now(),
});
const followUp = await complete(llm, context, options);
expect(followUp.stopReason).toBe("stop");
expect(followUp.content.length).toBeGreaterThan(0);
}
describe("AI Providers Abort Tests", () => {
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Abort", () => {
const llm = getModel("google", "gemini-2.5-flash");
it("should abort mid-stream", { retry: 3 }, async () => {
await testAbortSignal(llm, { thinking: { enabled: true } });
});
it("should handle immediate abort", { retry: 3 }, async () => {
await testImmediateAbort(llm, { thinking: { enabled: true } });
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)(
"OpenAI Completions Provider Abort",
() => {
const { compat: _compat, ...baseModel } = getModel(
"openai",
"gpt-4o-mini",
)!;
void _compat;
const llm: Model<"openai-completions"> = {
...baseModel,
api: "openai-completions",
};
it("should abort mid-stream", { retry: 3 }, async () => {
await testAbortSignal(llm);
});
it("should handle immediate abort", { retry: 3 }, async () => {
await testImmediateAbort(llm);
});
},
);
describe.skipIf(!process.env.OPENAI_API_KEY)(
"OpenAI Responses Provider Abort",
() => {
const llm = getModel("openai", "gpt-5-mini");
it("should abort mid-stream", { retry: 3 }, async () => {
await testAbortSignal(llm);
});
it("should handle immediate abort", { retry: 3 }, async () => {
await testImmediateAbort(llm);
});
},
);
describe.skipIf(!hasAzureOpenAICredentials())(
"Azure OpenAI Responses Provider Abort",
() => {
const llm = getModel("azure-openai-responses", "gpt-4o-mini");
const azureDeploymentName = resolveAzureDeploymentName(llm.id);
const azureOptions = azureDeploymentName ? { azureDeploymentName } : {};
it("should abort mid-stream", { retry: 3 }, async () => {
await testAbortSignal(llm, azureOptions);
});
it("should handle immediate abort", { retry: 3 }, async () => {
await testImmediateAbort(llm, azureOptions);
});
},
);
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)(
"Anthropic Provider Abort",
() => {
const llm = getModel("anthropic", "claude-opus-4-1-20250805");
it("should abort mid-stream", { retry: 3 }, async () => {
await testAbortSignal(llm, {
thinkingEnabled: true,
thinkingBudgetTokens: 2048,
});
});
it("should handle immediate abort", { retry: 3 }, async () => {
await testImmediateAbort(llm, {
thinkingEnabled: true,
thinkingBudgetTokens: 2048,
});
});
},
);
describe.skipIf(!process.env.MISTRAL_API_KEY)(
"Mistral Provider Abort",
() => {
const llm = getModel("mistral", "devstral-medium-latest");
it("should abort mid-stream", { retry: 3 }, async () => {
await testAbortSignal(llm);
});
it("should handle immediate abort", { retry: 3 }, async () => {
await testImmediateAbort(llm);
});
},
);
describe.skipIf(!process.env.MINIMAX_API_KEY)(
"MiniMax Provider Abort",
() => {
const llm = getModel("minimax", "MiniMax-M2.1");
it("should abort mid-stream", { retry: 3 }, async () => {
await testAbortSignal(llm);
});
it("should handle immediate abort", { retry: 3 }, async () => {
await testImmediateAbort(llm);
});
},
);
describe.skipIf(!process.env.KIMI_API_KEY)(
"Kimi For Coding Provider Abort",
() => {
const llm = getModel("kimi-coding", "kimi-k2-thinking");
it("should abort mid-stream", { retry: 3 }, async () => {
await testAbortSignal(llm);
});
it("should handle immediate abort", { retry: 3 }, async () => {
await testImmediateAbort(llm);
});
},
);
describe.skipIf(!process.env.AI_GATEWAY_API_KEY)(
"Vercel AI Gateway Provider Abort",
() => {
const llm = getModel("vercel-ai-gateway", "google/gemini-2.5-flash");
it("should abort mid-stream", { retry: 3 }, async () => {
await testAbortSignal(llm);
});
it("should handle immediate abort", { retry: 3 }, async () => {
await testImmediateAbort(llm);
});
},
);
// Google Gemini CLI / Antigravity share the same provider, so one test covers both
describe("Google Gemini CLI Provider Abort", () => {
it.skipIf(!geminiCliToken)(
"should abort mid-stream",
{ retry: 3 },
async () => {
const llm = getModel("google-gemini-cli", "gemini-2.5-flash");
await testAbortSignal(llm, { apiKey: geminiCliToken });
},
);
it.skipIf(!geminiCliToken)(
"should handle immediate abort",
{ retry: 3 },
async () => {
const llm = getModel("google-gemini-cli", "gemini-2.5-flash");
await testImmediateAbort(llm, { apiKey: geminiCliToken });
},
);
});
describe("OpenAI Codex Provider Abort", () => {
it.skipIf(!openaiCodexToken)(
"should abort mid-stream",
{ retry: 3 },
async () => {
const llm = getModel("openai-codex", "gpt-5.2-codex");
await testAbortSignal(llm, { apiKey: openaiCodexToken });
},
);
it.skipIf(!openaiCodexToken)(
"should handle immediate abort",
{ retry: 3 },
async () => {
const llm = getModel("openai-codex", "gpt-5.2-codex");
await testImmediateAbort(llm, { apiKey: openaiCodexToken });
},
);
});
describe.skipIf(!hasBedrockCredentials())(
"Amazon Bedrock Provider Abort",
() => {
const llm = getModel(
"amazon-bedrock",
"global.anthropic.claude-sonnet-4-5-20250929-v1:0",
);
it("should abort mid-stream", { retry: 3 }, async () => {
await testAbortSignal(llm, { reasoning: "medium" });
});
it("should handle immediate abort", { retry: 3 }, async () => {
await testImmediateAbort(llm);
});
it("should handle abort then new message", { retry: 3 }, async () => {
await testAbortThenNewMessage(llm);
});
},
);
});

View file

@ -0,0 +1,217 @@
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import { getModel } from "../src/models.js";
import { stream } from "../src/stream.js";
import type { Context, Tool } from "../src/types.js";
import { resolveApiKey } from "./oauth.js";
const oauthToken = await resolveApiKey("anthropic");
/**
* Tests for Anthropic OAuth tool name normalization.
*
* When using Claude Code OAuth, tool names must match CC's canonical casing.
* The normalization should:
* 1. Convert tool names that match CC tools (case-insensitive) to CC casing on outbound
* 2. Convert tool names back to the original casing on inbound
*
* This is a simple case-insensitive lookup, NOT a mapping of different names.
* e.g., "todowrite" -> "TodoWrite" -> "todowrite" (round-trip works)
*
* The old `find -> Glob` mapping was WRONG because:
* - Outbound: "find" -> "Glob"
* - Inbound: "Glob" -> ??? (no tool named "glob" in context.tools, only "find")
* - Result: tool call has name "Glob" but no tool exists with that name
*/
describe.skipIf(!oauthToken)("Anthropic OAuth tool name normalization", () => {
const model = getModel("anthropic", "claude-sonnet-4-20250514");
it("should normalize user-defined tool matching CC name (todowrite -> TodoWrite -> todowrite)", async () => {
// User defines a tool named "todowrite" (lowercase)
// CC has "TodoWrite" - this should round-trip correctly
const todoTool: Tool = {
name: "todowrite",
description: "Write a todo item",
parameters: Type.Object({
task: Type.String({ description: "The task to add" }),
}),
};
const context: Context = {
systemPrompt:
"You are a helpful assistant. Use the todowrite tool when asked to add todos.",
messages: [
{
role: "user",
content: "Add a todo: buy milk. Use the todowrite tool.",
timestamp: Date.now(),
},
],
tools: [todoTool],
};
const s = stream(model, context, { apiKey: oauthToken });
let toolCallName: string | undefined;
for await (const event of s) {
if (event.type === "toolcall_end") {
const toolCall = event.partial.content[event.contentIndex];
if (toolCall.type === "toolCall") {
toolCallName = toolCall.name;
}
}
}
const response = await s.result();
expect(response.stopReason, `Error: ${response.errorMessage}`).toBe(
"toolUse",
);
// The tool call should come back with the ORIGINAL name "todowrite", not "TodoWrite"
expect(toolCallName).toBe("todowrite");
});
it("should handle pi's built-in tools (read, write, edit, bash)", async () => {
// Pi's tools use lowercase names, CC uses PascalCase
const readTool: Tool = {
name: "read",
description: "Read a file",
parameters: Type.Object({
path: Type.String({ description: "File path" }),
}),
};
const context: Context = {
systemPrompt:
"You are a helpful assistant. Use the read tool to read files.",
messages: [
{
role: "user",
content: "Read the file /tmp/test.txt using the read tool.",
timestamp: Date.now(),
},
],
tools: [readTool],
};
const s = stream(model, context, { apiKey: oauthToken });
let toolCallName: string | undefined;
for await (const event of s) {
if (event.type === "toolcall_end") {
const toolCall = event.partial.content[event.contentIndex];
if (toolCall.type === "toolCall") {
toolCallName = toolCall.name;
}
}
}
const response = await s.result();
expect(response.stopReason, `Error: ${response.errorMessage}`).toBe(
"toolUse",
);
// The tool call should come back with the ORIGINAL name "read", not "Read"
expect(toolCallName).toBe("read");
});
it("should NOT map find to Glob - find is not a CC tool name", async () => {
// Pi has a "find" tool, CC has "Glob" - these are DIFFERENT tools
// The old code incorrectly mapped find -> Glob, which broke the round-trip
// because there's no tool named "glob" in context.tools
const findTool: Tool = {
name: "find",
description: "Find files by pattern",
parameters: Type.Object({
pattern: Type.String({ description: "Glob pattern" }),
}),
};
const context: Context = {
systemPrompt:
"You are a helpful assistant. Use the find tool to search for files.",
messages: [
{
role: "user",
content: "Find all .ts files using the find tool.",
timestamp: Date.now(),
},
],
tools: [findTool],
};
const s = stream(model, context, { apiKey: oauthToken });
let toolCallName: string | undefined;
for await (const event of s) {
if (event.type === "toolcall_end") {
const toolCall = event.partial.content[event.contentIndex];
if (toolCall.type === "toolCall") {
toolCallName = toolCall.name;
}
}
}
const response = await s.result();
expect(response.stopReason, `Error: ${response.errorMessage}`).toBe(
"toolUse",
);
// With the BROKEN find -> Glob mapping:
// - Sent as "Glob" to Anthropic
// - Received back as "Glob"
// - fromClaudeCodeName("Glob", tools) looks for tool.name.toLowerCase() === "glob"
// - No match (tool is named "find"), returns "Glob"
// - Test fails: toolCallName is "Glob" instead of "find"
//
// With the CORRECT implementation (no find->Glob mapping):
// - Sent as "find" to Anthropic (no CC tool named "Find")
// - Received back as "find"
// - Test passes: toolCallName is "find"
expect(toolCallName).toBe("find");
});
it("should handle custom tools that don't match any CC tool names", async () => {
// A completely custom tool should pass through unchanged
const customTool: Tool = {
name: "my_custom_tool",
description: "A custom tool",
parameters: Type.Object({
input: Type.String({ description: "Input value" }),
}),
};
const context: Context = {
systemPrompt:
"You are a helpful assistant. Use my_custom_tool when asked.",
messages: [
{
role: "user",
content: "Use my_custom_tool with input 'hello'.",
timestamp: Date.now(),
},
],
tools: [customTool],
};
const s = stream(model, context, { apiKey: oauthToken });
let toolCallName: string | undefined;
for await (const event of s) {
if (event.type === "toolcall_end") {
const toolCall = event.partial.content[event.contentIndex];
if (toolCall.type === "toolCall") {
toolCallName = toolCall.name;
}
}
}
const response = await s.result();
expect(response.stopReason, `Error: ${response.errorMessage}`).toBe(
"toolUse",
);
// Custom tool names should pass through unchanged
expect(toolCallName).toBe("my_custom_tool");
});
});

View file

@ -0,0 +1,34 @@
/**
* Utility functions for Azure OpenAI tests
*/
function parseDeploymentNameMap(
value: string | undefined,
): Map<string, string> {
const map = new Map<string, string>();
if (!value) return map;
for (const entry of value.split(",")) {
const trimmed = entry.trim();
if (!trimmed) continue;
const [modelId, deploymentName] = trimmed.split("=", 2);
if (!modelId || !deploymentName) continue;
map.set(modelId.trim(), deploymentName.trim());
}
return map;
}
export function hasAzureOpenAICredentials(): boolean {
const hasKey = !!process.env.AZURE_OPENAI_API_KEY;
const hasBaseUrl = !!(
process.env.AZURE_OPENAI_BASE_URL || process.env.AZURE_OPENAI_RESOURCE_NAME
);
return hasKey && hasBaseUrl;
}
export function resolveAzureDeploymentName(
modelId: string,
): string | undefined {
const mapValue = process.env.AZURE_OPENAI_DEPLOYMENT_NAME_MAP;
if (!mapValue) return undefined;
return parseDeploymentNameMap(mapValue).get(modelId);
}

View file

@ -0,0 +1,72 @@
/**
* A test suite to ensure all configured Amazon Bedrock models are usable.
*
* This is here to make sure we got correct model identifiers from models.dev and other sources.
* Because Amazon Bedrock requires cross-region inference in some models,
* plain model identifiers are not always usable and it requires tweaking of model identifiers to use cross-region inference.
* See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html#inference-profiles-support-system for more details.
*
* This test suite is not enabled by default unless AWS credentials and `BEDROCK_EXTENSIVE_MODEL_TEST` environment variables are set.
* This test suite takes ~2 minutes to run. Because not all models are available in all regions,
* it's recommended to use `us-west-2` region for best coverage for running this test suite.
*
* You can run this test suite with:
* ```bash
* $ AWS_REGION=us-west-2 BEDROCK_EXTENSIVE_MODEL_TEST=1 AWS_PROFILE=... npm test -- ./test/bedrock-models.test.ts
* ```
*/
import { describe, expect, it } from "vitest";
import { getModels } from "../src/models.js";
import { complete } from "../src/stream.js";
import type { Context } from "../src/types.js";
import { hasBedrockCredentials } from "./bedrock-utils.js";
describe("Amazon Bedrock Models", () => {
const models = getModels("amazon-bedrock");
it("should get all available Bedrock models", () => {
expect(models.length).toBeGreaterThan(0);
console.log(`Found ${models.length} Bedrock models`);
});
if (hasBedrockCredentials() && process.env.BEDROCK_EXTENSIVE_MODEL_TEST) {
for (const model of models) {
it(
`should make a simple request with ${model.id}`,
{ timeout: 10_000 },
async () => {
const context: Context = {
systemPrompt: "You are a helpful assistant. Be extremely concise.",
messages: [
{
role: "user",
content: "Reply with exactly: 'OK'",
timestamp: Date.now(),
},
],
};
const response = await complete(model, context);
expect(response.role).toBe("assistant");
expect(response.content).toBeTruthy();
expect(response.content.length).toBeGreaterThan(0);
expect(
response.usage.input + response.usage.cacheRead,
).toBeGreaterThan(0);
expect(response.usage.output).toBeGreaterThan(0);
expect(response.errorMessage).toBeFalsy();
const textContent = response.content
.filter((b) => b.type === "text")
.map((b) => (b.type === "text" ? b.text : ""))
.join("")
.trim();
expect(textContent).toBeTruthy();
console.log(`${model.id}: ${textContent.substring(0, 100)}`);
},
);
}
}
});

View file

@ -0,0 +1,18 @@
/**
* Utility functions for Amazon Bedrock tests
*/
/**
* Check if any valid AWS credentials are configured for Bedrock.
* Returns true if any of the following are set:
* - AWS_PROFILE (named profile from ~/.aws/credentials)
* - AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY (IAM keys)
* - AWS_BEARER_TOKEN_BEDROCK (Bedrock API key)
*/
export function hasBedrockCredentials(): boolean {
return !!(
process.env.AWS_PROFILE ||
(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
process.env.AWS_BEARER_TOKEN_BEDROCK
);
}

View file

@ -0,0 +1,352 @@
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import { getModel } from "../src/models.js";
import { stream } from "../src/stream.js";
import type { Context } from "../src/types.js";
describe("Cache Retention (PI_CACHE_RETENTION)", () => {
const originalEnv = process.env.PI_CACHE_RETENTION;
beforeEach(() => {
delete process.env.PI_CACHE_RETENTION;
});
afterEach(() => {
if (originalEnv !== undefined) {
process.env.PI_CACHE_RETENTION = originalEnv;
} else {
delete process.env.PI_CACHE_RETENTION;
}
});
const context: Context = {
systemPrompt: "You are a helpful assistant.",
messages: [{ role: "user", content: "Hello", timestamp: Date.now() }],
};
describe("Anthropic Provider", () => {
it.skipIf(!process.env.ANTHROPIC_API_KEY)(
"should use default cache TTL (no ttl field) when PI_CACHE_RETENTION is not set",
async () => {
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
let capturedPayload: any = null;
const s = stream(model, context, {
onPayload: (payload) => {
capturedPayload = payload;
},
});
// Consume the stream to trigger the request
for await (const _ of s) {
// Just consume
}
expect(capturedPayload).not.toBeNull();
// System prompt should have cache_control without ttl
expect(capturedPayload.system).toBeDefined();
expect(capturedPayload.system[0].cache_control).toEqual({
type: "ephemeral",
});
},
);
it.skipIf(!process.env.ANTHROPIC_API_KEY)(
"should use 1h cache TTL when PI_CACHE_RETENTION=long",
async () => {
process.env.PI_CACHE_RETENTION = "long";
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
let capturedPayload: any = null;
const s = stream(model, context, {
onPayload: (payload) => {
capturedPayload = payload;
},
});
// Consume the stream to trigger the request
for await (const _ of s) {
// Just consume
}
expect(capturedPayload).not.toBeNull();
// System prompt should have cache_control with ttl: "1h"
expect(capturedPayload.system).toBeDefined();
expect(capturedPayload.system[0].cache_control).toEqual({
type: "ephemeral",
ttl: "1h",
});
},
);
it("should not add ttl when baseUrl is not api.anthropic.com", async () => {
process.env.PI_CACHE_RETENTION = "long";
// Create a model with a different baseUrl (simulating a proxy)
const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022");
const proxyModel = {
...baseModel,
baseUrl: "https://my-proxy.example.com/v1",
};
let capturedPayload: any = null;
// We can't actually make the request (no proxy), but we can verify the payload
// by using a mock or checking the logic directly
// For this test, we'll import the helper directly
// Since we can't easily test this without mocking, we'll skip the actual API call
// and just verify the helper logic works correctly
const { streamAnthropic } = await import("../src/providers/anthropic.js");
try {
const s = streamAnthropic(proxyModel, context, {
apiKey: "fake-key",
onPayload: (payload) => {
capturedPayload = payload;
},
});
// This will fail since we're using a fake key and fake proxy, but the payload should be captured
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
// The payload should have been captured before the error
if (capturedPayload) {
// System prompt should have cache_control WITHOUT ttl (proxy URL)
expect(capturedPayload.system[0].cache_control).toEqual({
type: "ephemeral",
});
}
});
it("should omit cache_control when cacheRetention is none", async () => {
const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022");
let capturedPayload: any = null;
const { streamAnthropic } = await import("../src/providers/anthropic.js");
try {
const s = streamAnthropic(baseModel, context, {
apiKey: "fake-key",
cacheRetention: "none",
onPayload: (payload) => {
capturedPayload = payload;
},
});
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.system[0].cache_control).toBeUndefined();
});
it("should add cache_control to string user messages", async () => {
const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022");
let capturedPayload: any = null;
const { streamAnthropic } = await import("../src/providers/anthropic.js");
try {
const s = streamAnthropic(baseModel, context, {
apiKey: "fake-key",
onPayload: (payload) => {
capturedPayload = payload;
},
});
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
const lastMessage =
capturedPayload.messages[capturedPayload.messages.length - 1];
expect(Array.isArray(lastMessage.content)).toBe(true);
const lastBlock = lastMessage.content[lastMessage.content.length - 1];
expect(lastBlock.cache_control).toEqual({ type: "ephemeral" });
});
it("should set 1h cache TTL when cacheRetention is long", async () => {
const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022");
let capturedPayload: any = null;
const { streamAnthropic } = await import("../src/providers/anthropic.js");
try {
const s = streamAnthropic(baseModel, context, {
apiKey: "fake-key",
cacheRetention: "long",
onPayload: (payload) => {
capturedPayload = payload;
},
});
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.system[0].cache_control).toEqual({
type: "ephemeral",
ttl: "1h",
});
});
});
describe("OpenAI Responses Provider", () => {
it.skipIf(!process.env.OPENAI_API_KEY)(
"should not set prompt_cache_retention when PI_CACHE_RETENTION is not set",
async () => {
const model = getModel("openai", "gpt-4o-mini");
let capturedPayload: any = null;
const s = stream(model, context, {
onPayload: (payload) => {
capturedPayload = payload;
},
});
// Consume the stream to trigger the request
for await (const _ of s) {
// Just consume
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.prompt_cache_retention).toBeUndefined();
},
);
it.skipIf(!process.env.OPENAI_API_KEY)(
"should set prompt_cache_retention to 24h when PI_CACHE_RETENTION=long",
async () => {
process.env.PI_CACHE_RETENTION = "long";
const model = getModel("openai", "gpt-4o-mini");
let capturedPayload: any = null;
const s = stream(model, context, {
onPayload: (payload) => {
capturedPayload = payload;
},
});
// Consume the stream to trigger the request
for await (const _ of s) {
// Just consume
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.prompt_cache_retention).toBe("24h");
},
);
it("should not set prompt_cache_retention when baseUrl is not api.openai.com", async () => {
process.env.PI_CACHE_RETENTION = "long";
// Create a model with a different baseUrl (simulating a proxy)
const baseModel = getModel("openai", "gpt-4o-mini");
const proxyModel = {
...baseModel,
baseUrl: "https://my-proxy.example.com/v1",
};
let capturedPayload: any = null;
const { streamOpenAIResponses } =
await import("../src/providers/openai-responses.js");
try {
const s = streamOpenAIResponses(proxyModel, context, {
apiKey: "fake-key",
onPayload: (payload) => {
capturedPayload = payload;
},
});
// This will fail since we're using a fake key and fake proxy, but the payload should be captured
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
// The payload should have been captured before the error
if (capturedPayload) {
expect(capturedPayload.prompt_cache_retention).toBeUndefined();
}
});
it("should omit prompt_cache_key when cacheRetention is none", async () => {
const model = getModel("openai", "gpt-4o-mini");
let capturedPayload: any = null;
const { streamOpenAIResponses } =
await import("../src/providers/openai-responses.js");
try {
const s = streamOpenAIResponses(model, context, {
apiKey: "fake-key",
cacheRetention: "none",
sessionId: "session-1",
onPayload: (payload) => {
capturedPayload = payload;
},
});
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.prompt_cache_key).toBeUndefined();
expect(capturedPayload.prompt_cache_retention).toBeUndefined();
});
it("should set prompt_cache_retention when cacheRetention is long", async () => {
const model = getModel("openai", "gpt-4o-mini");
let capturedPayload: any = null;
const { streamOpenAIResponses } =
await import("../src/providers/openai-responses.js");
try {
const s = streamOpenAIResponses(model, context, {
apiKey: "fake-key",
cacheRetention: "long",
sessionId: "session-2",
onPayload: (payload) => {
capturedPayload = payload;
},
});
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.prompt_cache_key).toBe("session-2");
expect(capturedPayload.prompt_cache_retention).toBe("24h");
});
});
});

View file

@ -0,0 +1,864 @@
/**
* Test context overflow error handling across providers.
*
* Context overflow occurs when the input (prompt + history) exceeds
* the model's context window. This is different from output token limits.
*
* Expected behavior: All providers should return stopReason: "error"
* with an errorMessage that indicates the context was too large,
* OR (for z.ai) return successfully with usage.input > contextWindow.
*
* The isContextOverflow() function must return true for all providers.
*/
import type { ChildProcess } from "child_process";
import { execSync, spawn } from "child_process";
import { afterAll, beforeAll, describe, expect, it } from "vitest";
import { getModel } from "../src/models.js";
import { complete } from "../src/stream.js";
import type { AssistantMessage, Context, Model, Usage } from "../src/types.js";
import { isContextOverflow } from "../src/utils/overflow.js";
import { hasAzureOpenAICredentials } from "./azure-utils.js";
import { hasBedrockCredentials } from "./bedrock-utils.js";
import { resolveApiKey } from "./oauth.js";
// Resolve OAuth tokens at module level (async, runs before tests)
const oauthTokens = await Promise.all([
resolveApiKey("github-copilot"),
resolveApiKey("google-gemini-cli"),
resolveApiKey("google-antigravity"),
resolveApiKey("openai-codex"),
]);
const [githubCopilotToken, geminiCliToken, antigravityToken, openaiCodexToken] =
oauthTokens;
// Lorem ipsum paragraph for realistic token estimation
const LOREM_IPSUM = `Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. `;
// Generate a string that will exceed the context window
// Using chars/4 as token estimate (works better with varied text than repeated chars)
function generateOverflowContent(contextWindow: number): string {
const targetTokens = contextWindow + 10000; // Exceed by 10k tokens
const targetChars = targetTokens * 4 * 1.5;
const repetitions = Math.ceil(targetChars / LOREM_IPSUM.length);
return LOREM_IPSUM.repeat(repetitions);
}
interface OverflowResult {
provider: string;
model: string;
contextWindow: number;
stopReason: string;
errorMessage: string | undefined;
usage: Usage;
hasUsageData: boolean;
response: AssistantMessage;
}
async function testContextOverflow(
model: Model<any>,
apiKey: string,
): Promise<OverflowResult> {
const overflowContent = generateOverflowContent(model.contextWindow);
const context: Context = {
systemPrompt: "You are a helpful assistant.",
messages: [
{
role: "user",
content: overflowContent,
timestamp: Date.now(),
},
],
};
const response = await complete(model, context, { apiKey });
const hasUsageData = response.usage.input > 0 || response.usage.cacheRead > 0;
return {
provider: model.provider,
model: model.id,
contextWindow: model.contextWindow,
stopReason: response.stopReason,
errorMessage: response.errorMessage,
usage: response.usage,
hasUsageData,
response,
};
}
function logResult(result: OverflowResult) {
console.log(`\n${result.provider} / ${result.model}:`);
console.log(` contextWindow: ${result.contextWindow}`);
console.log(` stopReason: ${result.stopReason}`);
console.log(` errorMessage: ${result.errorMessage}`);
console.log(` usage: ${JSON.stringify(result.usage)}`);
console.log(` hasUsageData: ${result.hasUsageData}`);
}
// =============================================================================
// Anthropic
// Expected pattern: "prompt is too long: X tokens > Y maximum"
// =============================================================================
describe("Context overflow error handling", () => {
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic (API Key)", () => {
it("claude-3-5-haiku - should detect overflow via isContextOverflow", async () => {
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
const result = await testContextOverflow(
model,
process.env.ANTHROPIC_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(/prompt is too long/i);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)(
"Anthropic (OAuth)",
() => {
it("claude-sonnet-4 - should detect overflow via isContextOverflow", async () => {
const model = getModel("anthropic", "claude-sonnet-4-20250514");
const result = await testContextOverflow(
model,
process.env.ANTHROPIC_OAUTH_TOKEN!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(/prompt is too long/i);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
},
);
// =============================================================================
// GitHub Copilot (OAuth)
// Tests both OpenAI and Anthropic models via Copilot
// =============================================================================
describe("GitHub Copilot (OAuth)", () => {
// OpenAI model via Copilot
it.skipIf(!githubCopilotToken)(
"gpt-4o - should detect overflow via isContextOverflow",
async () => {
const model = getModel("github-copilot", "gpt-4o");
const result = await testContextOverflow(model, githubCopilotToken!);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(/exceeds the limit of \d+/i);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
},
120000,
);
// Anthropic model via Copilot
it.skipIf(!githubCopilotToken)(
"claude-sonnet-4 - should detect overflow via isContextOverflow",
async () => {
const model = getModel("github-copilot", "claude-sonnet-4");
const result = await testContextOverflow(model, githubCopilotToken!);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(
/exceeds the limit of \d+|input is too long/i,
);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
},
120000,
);
});
// =============================================================================
// OpenAI
// Expected pattern: "exceeds the context window"
// =============================================================================
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions", () => {
it("gpt-4o-mini - should detect overflow via isContextOverflow", async () => {
const model = { ...getModel("openai", "gpt-4o-mini") };
model.api = "openai-completions" as any;
const result = await testContextOverflow(
model,
process.env.OPENAI_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(/maximum context length/i);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses", () => {
it("gpt-4o - should detect overflow via isContextOverflow", async () => {
const model = getModel("openai", "gpt-4o");
const result = await testContextOverflow(
model,
process.env.OPENAI_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(/exceeds the context window/i);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
describe.skipIf(!hasAzureOpenAICredentials())(
"Azure OpenAI Responses",
() => {
it("gpt-4o-mini - should detect overflow via isContextOverflow", async () => {
const model = getModel("azure-openai-responses", "gpt-4o-mini");
const result = await testContextOverflow(
model,
process.env.AZURE_OPENAI_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(/context|maximum/i);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
},
);
// =============================================================================
// Google
// Expected pattern: "input token count (X) exceeds the maximum"
// =============================================================================
describe.skipIf(!process.env.GEMINI_API_KEY)("Google", () => {
it("gemini-2.0-flash - should detect overflow via isContextOverflow", async () => {
const model = getModel("google", "gemini-2.0-flash");
const result = await testContextOverflow(
model,
process.env.GEMINI_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(
/input token count.*exceeds the maximum/i,
);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
// =============================================================================
// Google Gemini CLI (OAuth)
// Uses same API as Google, expects same error pattern
// =============================================================================
describe("Google Gemini CLI (OAuth)", () => {
it.skipIf(!geminiCliToken)(
"gemini-2.5-flash - should detect overflow via isContextOverflow",
async () => {
const model = getModel("google-gemini-cli", "gemini-2.5-flash");
const result = await testContextOverflow(model, geminiCliToken!);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(
/input token count.*exceeds the maximum/i,
);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
},
120000,
);
});
// =============================================================================
// Google Antigravity (OAuth)
// Tests both Gemini and Anthropic models via Antigravity
// =============================================================================
describe("Google Antigravity (OAuth)", () => {
// Gemini model
it.skipIf(!antigravityToken)(
"gemini-3-flash - should detect overflow via isContextOverflow",
async () => {
const model = getModel("google-antigravity", "gemini-3-flash");
const result = await testContextOverflow(model, antigravityToken!);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(
/input token count.*exceeds the maximum/i,
);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
},
120000,
);
// Anthropic model via Antigravity
it.skipIf(!antigravityToken)(
"claude-sonnet-4-5 - should detect overflow via isContextOverflow",
async () => {
const model = getModel("google-antigravity", "claude-sonnet-4-5");
const result = await testContextOverflow(model, antigravityToken!);
logResult(result);
expect(result.stopReason).toBe("error");
// Anthropic models return "prompt is too long" pattern
expect(result.errorMessage).toMatch(/prompt is too long/i);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
},
120000,
);
});
// =============================================================================
// OpenAI Codex (OAuth)
// Uses ChatGPT Plus/Pro subscription via OAuth
// =============================================================================
describe("OpenAI Codex (OAuth)", () => {
it.skipIf(!openaiCodexToken)(
"gpt-5.2-codex - should detect overflow via isContextOverflow",
async () => {
const model = getModel("openai-codex", "gpt-5.2-codex");
const result = await testContextOverflow(model, openaiCodexToken!);
logResult(result);
expect(result.stopReason).toBe("error");
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
},
120000,
);
});
// =============================================================================
// Amazon Bedrock
// Expected pattern: "Input is too long for requested model"
// =============================================================================
describe.skipIf(!hasBedrockCredentials())("Amazon Bedrock", () => {
it("claude-sonnet-4-5 - should detect overflow via isContextOverflow", async () => {
const model = getModel(
"amazon-bedrock",
"global.anthropic.claude-sonnet-4-5-20250929-v1:0",
);
const result = await testContextOverflow(model, "");
logResult(result);
expect(result.stopReason).toBe("error");
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
// =============================================================================
// xAI
// Expected pattern: "maximum prompt length is X but the request contains Y"
// =============================================================================
describe.skipIf(!process.env.XAI_API_KEY)("xAI", () => {
it("grok-3-fast - should detect overflow via isContextOverflow", async () => {
const model = getModel("xai", "grok-3-fast");
const result = await testContextOverflow(model, process.env.XAI_API_KEY!);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(/maximum prompt length is \d+/i);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
// =============================================================================
// Groq
// Expected pattern: "reduce the length of the messages"
// =============================================================================
describe.skipIf(!process.env.GROQ_API_KEY)("Groq", () => {
it("llama-3.3-70b-versatile - should detect overflow via isContextOverflow", async () => {
const model = getModel("groq", "llama-3.3-70b-versatile");
const result = await testContextOverflow(
model,
process.env.GROQ_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(/reduce the length of the messages/i);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
// =============================================================================
// Cerebras
// Expected: 400/413 status code with no body
// =============================================================================
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras", () => {
it("qwen-3-235b - should detect overflow via isContextOverflow", async () => {
const model = getModel("cerebras", "qwen-3-235b-a22b-instruct-2507");
const result = await testContextOverflow(
model,
process.env.CEREBRAS_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
// Cerebras returns status code with no body (400, 413, or 429 for token rate limit)
expect(result.errorMessage).toMatch(/4(00|13|29).*\(no body\)/i);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
// =============================================================================
// Hugging Face
// Uses OpenAI-compatible Inference Router
// =============================================================================
describe.skipIf(!process.env.HF_TOKEN)("Hugging Face", () => {
it("Kimi-K2.5 - should detect overflow via isContextOverflow", async () => {
const model = getModel("huggingface", "moonshotai/Kimi-K2.5");
const result = await testContextOverflow(model, process.env.HF_TOKEN!);
logResult(result);
expect(result.stopReason).toBe("error");
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
// =============================================================================
// z.ai
// Special case: Sometimes accepts overflow silently, sometimes rate limits
// Detection via usage.input > contextWindow when successful
// =============================================================================
describe.skipIf(!process.env.ZAI_API_KEY)("z.ai", () => {
it("glm-4.5-flash - should detect overflow via isContextOverflow (silent overflow or rate limit)", async () => {
const model = getModel("zai", "glm-4.5-flash");
const result = await testContextOverflow(model, process.env.ZAI_API_KEY!);
logResult(result);
// z.ai behavior is inconsistent:
// - Sometimes accepts overflow and returns successfully with usage.input > contextWindow
// - Sometimes returns rate limit error
// Either way, isContextOverflow should detect it (via usage check or we skip if rate limited)
if (result.stopReason === "stop") {
if (result.hasUsageData && result.usage.input > model.contextWindow) {
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
} else {
console.log(
" z.ai returned stop without overflow usage data, skipping overflow detection",
);
}
} else {
// Rate limited or other error - just log and pass
console.log(
" z.ai returned error (possibly rate limited), skipping overflow detection",
);
}
}, 120000);
});
// =============================================================================
// Mistral
// =============================================================================
describe.skipIf(!process.env.MISTRAL_API_KEY)("Mistral", () => {
it("devstral-medium-latest - should detect overflow via isContextOverflow", async () => {
const model = getModel("mistral", "devstral-medium-latest");
const result = await testContextOverflow(
model,
process.env.MISTRAL_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(
/too large for model with \d+ maximum context length/i,
);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
// =============================================================================
// MiniMax
// Expected pattern: TBD - need to test actual error message
// =============================================================================
describe.skipIf(!process.env.MINIMAX_API_KEY)("MiniMax", () => {
it("MiniMax-M2.1 - should detect overflow via isContextOverflow", async () => {
const model = getModel("minimax", "MiniMax-M2.1");
const result = await testContextOverflow(
model,
process.env.MINIMAX_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
// =============================================================================
// Kimi For Coding
// =============================================================================
describe.skipIf(!process.env.KIMI_API_KEY)("Kimi For Coding", () => {
it("kimi-k2-thinking - should detect overflow via isContextOverflow", async () => {
const model = getModel("kimi-coding", "kimi-k2-thinking");
const result = await testContextOverflow(
model,
process.env.KIMI_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
// =============================================================================
// Vercel AI Gateway - Unified API for multiple providers
// =============================================================================
describe.skipIf(!process.env.AI_GATEWAY_API_KEY)("Vercel AI Gateway", () => {
it("google/gemini-2.5-flash via AI Gateway - should detect overflow via isContextOverflow", async () => {
const model = getModel("vercel-ai-gateway", "google/gemini-2.5-flash");
const result = await testContextOverflow(
model,
process.env.AI_GATEWAY_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
// =============================================================================
// OpenRouter - Multiple backend providers
// Expected pattern: "maximum context length is X tokens"
// =============================================================================
describe.skipIf(!process.env.OPENROUTER_API_KEY)("OpenRouter", () => {
// Anthropic backend
it("anthropic/claude-sonnet-4 via OpenRouter - should detect overflow via isContextOverflow", async () => {
const model = getModel("openrouter", "anthropic/claude-sonnet-4");
const result = await testContextOverflow(
model,
process.env.OPENROUTER_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(
/maximum context length is \d+ tokens/i,
);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
// DeepSeek backend
it("deepseek/deepseek-v3.2 via OpenRouter - should detect overflow via isContextOverflow", async () => {
const model = getModel("openrouter", "deepseek/deepseek-v3.2");
const result = await testContextOverflow(
model,
process.env.OPENROUTER_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(
/maximum context length is \d+ tokens/i,
);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
// Mistral backend
it("mistralai/mistral-large-2512 via OpenRouter - should detect overflow via isContextOverflow", async () => {
const model = getModel("openrouter", "mistralai/mistral-large-2512");
const result = await testContextOverflow(
model,
process.env.OPENROUTER_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(
/maximum context length is \d+ tokens/i,
);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
// Google backend
it("google/gemini-2.5-flash via OpenRouter - should detect overflow via isContextOverflow", async () => {
const model = getModel("openrouter", "google/gemini-2.5-flash");
const result = await testContextOverflow(
model,
process.env.OPENROUTER_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(
/maximum context length is \d+ tokens/i,
);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
// Meta/Llama backend
it("meta-llama/llama-4-maverick via OpenRouter - should detect overflow via isContextOverflow", async () => {
const model = getModel("openrouter", "meta-llama/llama-4-maverick");
const result = await testContextOverflow(
model,
process.env.OPENROUTER_API_KEY!,
);
logResult(result);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toMatch(
/maximum context length is \d+ tokens/i,
);
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
// =============================================================================
// Ollama (local)
// =============================================================================
// Check if ollama is installed and local LLM tests are enabled
let ollamaInstalled = false;
if (!process.env.PI_NO_LOCAL_LLM) {
try {
execSync("which ollama", { stdio: "ignore" });
ollamaInstalled = true;
} catch {
ollamaInstalled = false;
}
}
describe.skipIf(!ollamaInstalled)("Ollama (local)", () => {
let ollamaProcess: ChildProcess | null = null;
let model: Model<"openai-completions">;
beforeAll(async () => {
// Check if model is available, if not pull it
try {
execSync("ollama list | grep -q 'gpt-oss:20b'", { stdio: "ignore" });
} catch {
console.log("Pulling gpt-oss:20b model for Ollama overflow tests...");
try {
execSync("ollama pull gpt-oss:20b", { stdio: "inherit" });
} catch (_e) {
console.warn(
"Failed to pull gpt-oss:20b model, tests will be skipped",
);
return;
}
}
// Start ollama server
ollamaProcess = spawn("ollama", ["serve"], {
detached: false,
stdio: "ignore",
});
// Wait for server to be ready
await new Promise<void>((resolve) => {
const checkServer = async () => {
try {
const response = await fetch("http://localhost:11434/api/tags");
if (response.ok) {
resolve();
} else {
setTimeout(checkServer, 500);
}
} catch {
setTimeout(checkServer, 500);
}
};
setTimeout(checkServer, 1000);
});
model = {
id: "gpt-oss:20b",
api: "openai-completions",
provider: "ollama",
baseUrl: "http://localhost:11434/v1",
reasoning: true,
input: ["text"],
contextWindow: 128000,
maxTokens: 16000,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
name: "Ollama GPT-OSS 20B",
};
}, 60000);
afterAll(() => {
if (ollamaProcess) {
ollamaProcess.kill("SIGTERM");
ollamaProcess = null;
}
});
it("gpt-oss:20b - should detect overflow via isContextOverflow (ollama silently truncates)", async () => {
const result = await testContextOverflow(model, "ollama");
logResult(result);
// Ollama silently truncates input instead of erroring
// It returns stopReason "stop" with truncated usage
// We cannot detect overflow via error message, only via usage comparison
if (result.stopReason === "stop" && result.hasUsageData) {
// Ollama truncated - check if reported usage is less than what we sent
// This is a "silent overflow" - we can detect it if we know expected input size
console.log(
" Ollama silently truncated input to",
result.usage.input,
"tokens",
);
// For now, we accept this behavior - Ollama doesn't give us a way to detect overflow
} else if (result.stopReason === "error") {
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}
}, 300000); // 5 min timeout for local model
});
// =============================================================================
// LM Studio (local) - Skip if not running or local LLM tests disabled
// =============================================================================
let lmStudioRunning = false;
if (!process.env.PI_NO_LOCAL_LLM) {
try {
execSync(
"curl -s --max-time 1 http://localhost:1234/v1/models > /dev/null",
{ stdio: "ignore" },
);
lmStudioRunning = true;
} catch {
lmStudioRunning = false;
}
}
describe.skipIf(!lmStudioRunning)("LM Studio (local)", () => {
it("should detect overflow via isContextOverflow", async () => {
const model: Model<"openai-completions"> = {
id: "local-model",
api: "openai-completions",
provider: "lm-studio",
baseUrl: "http://localhost:1234/v1",
reasoning: false,
input: ["text"],
contextWindow: 8192,
maxTokens: 2048,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
name: "LM Studio Local Model",
};
const result = await testContextOverflow(model, "lm-studio");
logResult(result);
expect(result.stopReason).toBe("error");
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
// =============================================================================
// llama.cpp server (local) - Skip if not running
// =============================================================================
let llamaCppRunning = false;
try {
execSync("curl -s --max-time 1 http://localhost:8081/health > /dev/null", {
stdio: "ignore",
});
llamaCppRunning = true;
} catch {
llamaCppRunning = false;
}
describe.skipIf(!llamaCppRunning)("llama.cpp (local)", () => {
it("should detect overflow via isContextOverflow", async () => {
// Using small context (4096) to match server --ctx-size setting
const model: Model<"openai-completions"> = {
id: "local-model",
api: "openai-completions",
provider: "llama.cpp",
baseUrl: "http://localhost:8081/v1",
reasoning: false,
input: ["text"],
contextWindow: 4096,
maxTokens: 2048,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
name: "llama.cpp Local Model",
};
const result = await testContextOverflow(model, "llama.cpp");
logResult(result);
expect(result.stopReason).toBe("error");
expect(isContextOverflow(result.response, model.contextWindow)).toBe(
true,
);
}, 120000);
});
});

View file

@ -0,0 +1,568 @@
/**
* Cross-Provider Handoff Test
*
* Tests that contexts generated by one provider/model can be consumed by another.
* This catches issues like:
* - Tool call ID format incompatibilities (e.g., OpenAI Codex pipe characters)
* - Thinking block transformation issues
* - Message format incompatibilities
*
* Strategy:
* 1. beforeAll: For each provider/model, generate a "small context" (if not cached):
* - User message asking to use a tool
* - Assistant response with thinking + tool call
* - Tool result
* - Final assistant response
*
* 2. Test: For each target provider/model:
* - Concatenate ALL other contexts into one
* - Ask the model to "say hi"
* - If it fails, there's a compatibility issue
*
* Fixtures are generated fresh on each run.
*/
import { Type } from "@sinclair/typebox";
import { writeFileSync } from "fs";
import { beforeAll, describe, expect, it } from "vitest";
import { getModel } from "../src/models.js";
import { completeSimple, getEnvApiKey } from "../src/stream.js";
import type {
Api,
AssistantMessage,
Message,
Model,
Tool,
ToolResultMessage,
} from "../src/types.js";
import { hasAzureOpenAICredentials } from "./azure-utils.js";
import { resolveApiKey } from "./oauth.js";
// Simple tool for testing
const testToolSchema = Type.Object({
value: Type.Number({ description: "A number to double" }),
});
const testTool: Tool<typeof testToolSchema> = {
name: "double_number",
description: "Doubles a number and returns the result",
parameters: testToolSchema,
};
// Provider/model pairs to test
interface ProviderModelPair {
provider: string;
model: string;
label: string;
apiOverride?: Api;
}
const PROVIDER_MODEL_PAIRS: ProviderModelPair[] = [
// Anthropic
{
provider: "anthropic",
model: "claude-sonnet-4-5",
label: "anthropic-claude-sonnet-4-5",
},
// Google
{
provider: "google",
model: "gemini-3-flash-preview",
label: "google-gemini-3-flash-preview",
},
// OpenAI
{
provider: "openai",
model: "gpt-4o-mini",
label: "openai-completions-gpt-4o-mini",
apiOverride: "openai-completions",
},
{
provider: "openai",
model: "gpt-5-mini",
label: "openai-responses-gpt-5-mini",
},
{
provider: "azure-openai-responses",
model: "gpt-4o-mini",
label: "azure-openai-responses-gpt-4o-mini",
},
// OpenAI Codex
{
provider: "openai-codex",
model: "gpt-5.2-codex",
label: "openai-codex-gpt-5.2-codex",
},
// Google Antigravity
{
provider: "google-antigravity",
model: "gemini-3-flash",
label: "antigravity-gemini-3-flash",
},
{
provider: "google-antigravity",
model: "claude-sonnet-4-5",
label: "antigravity-claude-sonnet-4-5",
},
// GitHub Copilot
{
provider: "github-copilot",
model: "claude-sonnet-4.5",
label: "copilot-claude-sonnet-4.5",
},
{
provider: "github-copilot",
model: "gpt-5.1-codex",
label: "copilot-gpt-5.1-codex",
},
{
provider: "github-copilot",
model: "gemini-3-flash-preview",
label: "copilot-gemini-3-flash-preview",
},
{
provider: "github-copilot",
model: "grok-code-fast-1",
label: "copilot-grok-code-fast-1",
},
// Amazon Bedrock
{
provider: "amazon-bedrock",
model: "global.anthropic.claude-sonnet-4-5-20250929-v1:0",
label: "bedrock-claude-sonnet-4-5",
},
// xAI
{ provider: "xai", model: "grok-code-fast-1", label: "xai-grok-code-fast-1" },
// Cerebras
{ provider: "cerebras", model: "zai-glm-4.7", label: "cerebras-zai-glm-4.7" },
// Groq
{
provider: "groq",
model: "openai/gpt-oss-120b",
label: "groq-gpt-oss-120b",
},
// Hugging Face
{
provider: "huggingface",
model: "moonshotai/Kimi-K2.5",
label: "huggingface-kimi-k2.5",
},
// Kimi For Coding
{
provider: "kimi-coding",
model: "kimi-k2-thinking",
label: "kimi-coding-k2-thinking",
},
// Mistral
{
provider: "mistral",
model: "devstral-medium-latest",
label: "mistral-devstral-medium",
},
// MiniMax
{ provider: "minimax", model: "MiniMax-M2.1", label: "minimax-m2.1" },
// OpenCode Zen
{ provider: "opencode", model: "big-pickle", label: "zen-big-pickle" },
{
provider: "opencode",
model: "claude-sonnet-4-5",
label: "zen-claude-sonnet-4-5",
},
{
provider: "opencode",
model: "gemini-3-flash",
label: "zen-gemini-3-flash",
},
{ provider: "opencode", model: "glm-4.7-free", label: "zen-glm-4.7-free" },
{ provider: "opencode", model: "gpt-5.2-codex", label: "zen-gpt-5.2-codex" },
{
provider: "opencode",
model: "minimax-m2.1-free",
label: "zen-minimax-m2.1-free",
},
// OpenCode Go
{ provider: "opencode-go", model: "kimi-k2.5", label: "go-kimi-k2.5" },
{ provider: "opencode-go", model: "minimax-m2.5", label: "go-minimax-m2.5" },
];
// Cached context structure
interface CachedContext {
label: string;
provider: string;
model: string;
api: Api;
messages: Message[];
generatedAt: string;
}
/**
* Get API key for provider - checks OAuth storage first, then env vars
*/
async function getApiKey(provider: string): Promise<string | undefined> {
const oauthKey = await resolveApiKey(provider);
if (oauthKey) return oauthKey;
return getEnvApiKey(provider);
}
/**
* Synchronous check for API key availability (env vars only, for skipIf)
*/
function hasApiKey(provider: string): boolean {
if (provider === "azure-openai-responses") {
return hasAzureOpenAICredentials();
}
return !!getEnvApiKey(provider);
}
/**
* Check if any provider has API keys available (for skipIf at describe level)
*/
function hasAnyApiKey(): boolean {
return PROVIDER_MODEL_PAIRS.some((pair) => hasApiKey(pair.provider));
}
function dumpFailurePayload(params: {
label: string;
error: string;
payload?: unknown;
messages: Message[];
}): void {
const filename = `/tmp/pi-handoff-${params.label}-${Date.now()}.json`;
const body = {
label: params.label,
error: params.error,
payload: params.payload,
messages: params.messages,
};
writeFileSync(filename, JSON.stringify(body, null, 2));
console.log(`Wrote failure payload to ${filename}`);
}
/**
* Generate a context from a provider/model pair.
* Makes a real API call to get authentic tool call IDs and thinking blocks.
*/
async function generateContext(
pair: ProviderModelPair,
apiKey: string,
): Promise<{ messages: Message[]; api: Api } | null> {
const baseModel = (
getModel as (p: string, m: string) => Model<Api> | undefined
)(pair.provider, pair.model);
if (!baseModel) {
console.log(` Model not found: ${pair.provider}/${pair.model}`);
return null;
}
const model: Model<Api> = pair.apiOverride
? { ...baseModel, api: pair.apiOverride }
: baseModel;
const userMessage: Message = {
role: "user",
content: "Please double the number 21 using the double_number tool.",
timestamp: Date.now(),
};
const supportsReasoning = model.reasoning === true;
let lastPayload: unknown;
let assistantResponse: AssistantMessage;
try {
assistantResponse = await completeSimple(
model,
{
systemPrompt:
"You are a helpful assistant. Use the provided tool to complete the task.",
messages: [userMessage],
tools: [testTool],
},
{
apiKey,
reasoning: supportsReasoning ? "high" : undefined,
onPayload: (payload) => {
lastPayload = payload;
},
},
);
} catch (error) {
const msg = error instanceof Error ? error.message : String(error);
console.log(` Initial request failed: ${msg}`);
dumpFailurePayload({
label: `${pair.label}-initial`,
error: msg,
payload: lastPayload,
messages: [userMessage],
});
return null;
}
if (assistantResponse.stopReason === "error") {
console.log(` Initial request error: ${assistantResponse.errorMessage}`);
dumpFailurePayload({
label: `${pair.label}-initial`,
error: assistantResponse.errorMessage || "Unknown error",
payload: lastPayload,
messages: [userMessage],
});
return null;
}
const toolCall = assistantResponse.content.find((c) => c.type === "toolCall");
if (!toolCall || toolCall.type !== "toolCall") {
console.log(
` No tool call in response (stopReason: ${assistantResponse.stopReason})`,
);
return {
messages: [userMessage, assistantResponse],
api: model.api,
};
}
console.log(` Tool call ID: ${toolCall.id}`);
const toolResult: ToolResultMessage = {
role: "toolResult",
toolCallId: toolCall.id,
toolName: toolCall.name,
content: [{ type: "text", text: "42" }],
isError: false,
timestamp: Date.now(),
};
let finalResponse: AssistantMessage;
const messagesForFinal = [userMessage, assistantResponse, toolResult];
try {
finalResponse = await completeSimple(
model,
{
systemPrompt: "You are a helpful assistant.",
messages: messagesForFinal,
tools: [testTool],
},
{
apiKey,
reasoning: supportsReasoning ? "high" : undefined,
onPayload: (payload) => {
lastPayload = payload;
},
},
);
} catch (error) {
const msg = error instanceof Error ? error.message : String(error);
console.log(` Final request failed: ${msg}`);
dumpFailurePayload({
label: `${pair.label}-final`,
error: msg,
payload: lastPayload,
messages: messagesForFinal,
});
return null;
}
if (finalResponse.stopReason === "error") {
console.log(` Final request error: ${finalResponse.errorMessage}`);
dumpFailurePayload({
label: `${pair.label}-final`,
error: finalResponse.errorMessage || "Unknown error",
payload: lastPayload,
messages: messagesForFinal,
});
return null;
}
return {
messages: [userMessage, assistantResponse, toolResult, finalResponse],
api: model.api,
};
}
describe.skipIf(!hasAnyApiKey())("Cross-Provider Handoff", () => {
let contexts: Record<string, CachedContext>;
let availablePairs: ProviderModelPair[];
beforeAll(async () => {
contexts = {};
availablePairs = [];
console.log("\n=== Generating Fixtures ===\n");
for (const pair of PROVIDER_MODEL_PAIRS) {
const apiKey = await getApiKey(pair.provider);
if (!apiKey) {
console.log(`[${pair.label}] Skipping - no auth for ${pair.provider}`);
continue;
}
console.log(`[${pair.label}] Generating fixture...`);
const result = await generateContext(pair, apiKey);
if (!result || result.messages.length < 4) {
console.log(`[${pair.label}] Failed to generate fixture, skipping`);
continue;
}
contexts[pair.label] = {
label: pair.label,
provider: pair.provider,
model: pair.model,
api: result.api,
messages: result.messages,
generatedAt: new Date().toISOString(),
};
availablePairs.push(pair);
console.log(
`[${pair.label}] Generated ${result.messages.length} messages`,
);
}
console.log(
`\n=== ${availablePairs.length}/${PROVIDER_MODEL_PAIRS.length} contexts available ===\n`,
);
}, 300000);
it.skipIf(!hasAnyApiKey())(
"should have at least 2 fixtures to test handoffs",
() => {
expect(Object.keys(contexts).length).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!hasAnyApiKey())(
"should handle cross-provider handoffs for each target",
async () => {
const contextLabels = Object.keys(contexts);
if (contextLabels.length < 2) {
console.log("Not enough fixtures for handoff test, skipping");
return;
}
console.log("\n=== Testing Cross-Provider Handoffs ===\n");
const results: { target: string; success: boolean; error?: string }[] =
[];
for (const targetPair of availablePairs) {
const apiKey = await getApiKey(targetPair.provider);
if (!apiKey) {
console.log(`[Target: ${targetPair.label}] Skipping - no auth`);
continue;
}
// Collect messages from ALL OTHER contexts
const otherMessages: Message[] = [];
for (const [label, ctx] of Object.entries(contexts)) {
if (label === targetPair.label) continue;
otherMessages.push(...ctx.messages);
}
if (otherMessages.length === 0) {
console.log(
`[Target: ${targetPair.label}] Skipping - no other contexts`,
);
continue;
}
const allMessages: Message[] = [
...otherMessages,
{
role: "user",
content:
"Great, thanks for all that help! Now just say 'Hello, handoff successful!' to confirm you received everything.",
timestamp: Date.now(),
},
];
const baseModel = (
getModel as (p: string, m: string) => Model<Api> | undefined
)(targetPair.provider, targetPair.model);
if (!baseModel) {
console.log(`[Target: ${targetPair.label}] Model not found`);
continue;
}
const model: Model<Api> = targetPair.apiOverride
? { ...baseModel, api: targetPair.apiOverride }
: baseModel;
const supportsReasoning = model.reasoning === true;
console.log(
`[Target: ${targetPair.label}] Testing with ${otherMessages.length} messages from other providers...`,
);
let lastPayload: unknown;
try {
const response = await completeSimple(
model,
{
systemPrompt: "You are a helpful assistant.",
messages: allMessages,
tools: [testTool],
},
{
apiKey,
reasoning: supportsReasoning ? "high" : undefined,
onPayload: (payload) => {
lastPayload = payload;
},
},
);
if (response.stopReason === "error") {
console.log(
`[Target: ${targetPair.label}] FAILED: ${response.errorMessage}`,
);
dumpFailurePayload({
label: targetPair.label,
error: response.errorMessage || "Unknown error",
payload: lastPayload,
messages: allMessages,
});
results.push({
target: targetPair.label,
success: false,
error: response.errorMessage,
});
} else {
const text = response.content
.filter((c) => c.type === "text")
.map((c) => c.text)
.join(" ");
const preview = text.slice(0, 100).replace(/\n/g, " ");
console.log(`[Target: ${targetPair.label}] SUCCESS: ${preview}...`);
results.push({ target: targetPair.label, success: true });
}
} catch (error) {
const msg = error instanceof Error ? error.message : String(error);
console.log(`[Target: ${targetPair.label}] EXCEPTION: ${msg}`);
dumpFailurePayload({
label: targetPair.label,
error: msg,
payload: lastPayload,
messages: allMessages,
});
results.push({
target: targetPair.label,
success: false,
error: msg,
});
}
}
console.log("\n=== Results Summary ===\n");
const successes = results.filter((r) => r.success);
const failures = results.filter((r) => !r.success);
console.log(`Passed: ${successes.length}/${results.length}`);
if (failures.length > 0) {
console.log("\nFailures:");
for (const f of failures) {
console.log(` - ${f.target}: ${f.error}`);
}
}
expect(failures.length).toBe(0);
},
600000,
);
});

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,115 @@
import { describe, expect, it, vi } from "vitest";
import { getModel } from "../src/models.js";
import type { Context } from "../src/types.js";
const mockState = vi.hoisted(() => ({
constructorOpts: undefined as Record<string, unknown> | undefined,
streamParams: undefined as Record<string, unknown> | undefined,
}));
vi.mock("@anthropic-ai/sdk", () => {
const fakeStream = {
async *[Symbol.asyncIterator]() {
yield {
type: "message_start",
message: {
usage: { input_tokens: 10, output_tokens: 0 },
},
};
yield {
type: "message_delta",
delta: { stop_reason: "end_turn" },
usage: { output_tokens: 5 },
};
},
finalMessage: async () => ({
usage: {
input_tokens: 10,
output_tokens: 5,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
}),
};
class FakeAnthropic {
constructor(opts: Record<string, unknown>) {
mockState.constructorOpts = opts;
}
messages = {
stream: (params: Record<string, unknown>) => {
mockState.streamParams = params;
return fakeStream;
},
};
}
return { default: FakeAnthropic };
});
describe("Copilot Claude via Anthropic Messages", () => {
const context: Context = {
systemPrompt: "You are a helpful assistant.",
messages: [{ role: "user", content: "Hello", timestamp: Date.now() }],
};
it("uses Bearer auth, Copilot headers, and valid Anthropic Messages payload", async () => {
const model = getModel("github-copilot", "claude-sonnet-4");
expect(model.api).toBe("anthropic-messages");
const { streamAnthropic } = await import("../src/providers/anthropic.js");
const s = streamAnthropic(model, context, {
apiKey: "tid_copilot_session_test_token",
});
for await (const event of s) {
if (event.type === "error") break;
}
const opts = mockState.constructorOpts!;
expect(opts).toBeDefined();
// Auth: apiKey null, authToken for Bearer
expect(opts.apiKey).toBeNull();
expect(opts.authToken).toBe("tid_copilot_session_test_token");
const headers = opts.defaultHeaders as Record<string, string>;
// Copilot static headers from model.headers
expect(headers["User-Agent"]).toContain("GitHubCopilotChat");
expect(headers["Copilot-Integration-Id"]).toBe("vscode-chat");
// Dynamic headers
expect(headers["X-Initiator"]).toBe("user");
expect(headers["Openai-Intent"]).toBe("conversation-edits");
// No fine-grained-tool-streaming (Copilot doesn't support it)
const beta = headers["anthropic-beta"] ?? "";
expect(beta).not.toContain("fine-grained-tool-streaming");
// Payload is valid Anthropic Messages format
const params = mockState.streamParams!;
expect(params.model).toBe("claude-sonnet-4");
expect(params.stream).toBe(true);
expect(params.max_tokens).toBeGreaterThan(0);
expect(Array.isArray(params.messages)).toBe(true);
});
it("includes interleaved-thinking beta when reasoning is enabled", async () => {
const model = getModel("github-copilot", "claude-sonnet-4");
const { streamAnthropic } = await import("../src/providers/anthropic.js");
const s = streamAnthropic(model, context, {
apiKey: "tid_copilot_session_test_token",
interleavedThinking: true,
});
for await (const event of s) {
if (event.type === "error") break;
}
const headers = mockState.constructorOpts!.defaultHeaders as Record<
string,
string
>;
expect(headers["anthropic-beta"]).toContain(
"interleaved-thinking-2025-05-14",
);
});
});

View file

@ -0,0 +1,109 @@
import { afterEach, describe, expect, it, vi } from "vitest";
import { streamGoogleGeminiCli } from "../src/providers/google-gemini-cli.js";
import type { Context, Model } from "../src/types.js";
const originalFetch = global.fetch;
const apiKey = JSON.stringify({ token: "token", projectId: "project" });
const createSseResponse = () => {
const sse = `${[
`data: ${JSON.stringify({
response: {
candidates: [
{
content: { role: "model", parts: [{ text: "Hello" }] },
finishReason: "STOP",
},
],
},
})}`,
].join("\n\n")}\n\n`;
const encoder = new TextEncoder();
const stream = new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(encoder.encode(sse));
controller.close();
},
});
return new Response(stream, {
status: 200,
headers: { "content-type": "text/event-stream" },
});
};
afterEach(() => {
global.fetch = originalFetch;
vi.restoreAllMocks();
});
describe("google-gemini-cli Claude thinking header", () => {
const context: Context = {
messages: [{ role: "user", content: "Say hello", timestamp: Date.now() }],
};
it("adds anthropic-beta for Claude thinking models", async () => {
const fetchMock = vi.fn(
async (_input: string | URL, init?: RequestInit) => {
const headers = new Headers(init?.headers);
expect(headers.get("anthropic-beta")).toBe(
"interleaved-thinking-2025-05-14",
);
return createSseResponse();
},
);
global.fetch = fetchMock as typeof fetch;
const model: Model<"google-gemini-cli"> = {
id: "claude-opus-4-5-thinking",
name: "Claude Opus 4.5 Thinking",
api: "google-gemini-cli",
provider: "google-antigravity",
baseUrl: "https://cloudcode-pa.googleapis.com",
reasoning: true,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 128000,
maxTokens: 8192,
};
const stream = streamGoogleGeminiCli(model, context, { apiKey });
for await (const _event of stream) {
// exhaust stream
}
await stream.result();
});
it("does not add anthropic-beta for Gemini models", async () => {
const fetchMock = vi.fn(
async (_input: string | URL, init?: RequestInit) => {
const headers = new Headers(init?.headers);
expect(headers.has("anthropic-beta")).toBe(false);
return createSseResponse();
},
);
global.fetch = fetchMock as typeof fetch;
const model: Model<"google-gemini-cli"> = {
id: "gemini-2.5-flash",
name: "Gemini 2.5 Flash",
api: "google-gemini-cli",
provider: "google-gemini-cli",
baseUrl: "https://cloudcode-pa.googleapis.com",
reasoning: false,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 128000,
maxTokens: 8192,
};
const stream = streamGoogleGeminiCli(model, context, { apiKey });
for await (const _event of stream) {
// exhaust stream
}
await stream.result();
});
});

View file

@ -0,0 +1,108 @@
import { afterEach, describe, expect, it, vi } from "vitest";
import { streamGoogleGeminiCli } from "../src/providers/google-gemini-cli.js";
import type { Context, Model } from "../src/types.js";
const originalFetch = global.fetch;
afterEach(() => {
global.fetch = originalFetch;
vi.restoreAllMocks();
});
describe("google-gemini-cli empty stream retry", () => {
it("retries empty SSE responses without duplicate start", async () => {
const emptyStream = new ReadableStream<Uint8Array>({
start(controller) {
controller.close();
},
});
const sse = `${[
`data: ${JSON.stringify({
response: {
candidates: [
{
content: { role: "model", parts: [{ text: "Hello" }] },
finishReason: "STOP",
},
],
usageMetadata: {
promptTokenCount: 1,
candidatesTokenCount: 1,
totalTokenCount: 2,
},
},
})}`,
].join("\n\n")}\n\n`;
const encoder = new TextEncoder();
const dataStream = new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(encoder.encode(sse));
controller.close();
},
});
let callCount = 0;
const fetchMock = vi.fn(async () => {
callCount += 1;
if (callCount === 1) {
return new Response(emptyStream, {
status: 200,
headers: { "content-type": "text/event-stream" },
});
}
return new Response(dataStream, {
status: 200,
headers: { "content-type": "text/event-stream" },
});
});
global.fetch = fetchMock as typeof fetch;
const model: Model<"google-gemini-cli"> = {
id: "gemini-2.5-flash",
name: "Gemini 2.5 Flash",
api: "google-gemini-cli",
provider: "google-gemini-cli",
baseUrl: "https://cloudcode-pa.googleapis.com",
reasoning: false,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 128000,
maxTokens: 8192,
};
const context: Context = {
messages: [{ role: "user", content: "Say hello", timestamp: Date.now() }],
};
const stream = streamGoogleGeminiCli(model, context, {
apiKey: JSON.stringify({ token: "token", projectId: "project" }),
});
let startCount = 0;
let doneCount = 0;
let text = "";
for await (const event of stream) {
if (event.type === "start") {
startCount += 1;
}
if (event.type === "done") {
doneCount += 1;
}
if (event.type === "text_delta") {
text += event.delta;
}
}
const result = await stream.result();
expect(text).toBe("Hello");
expect(result.stopReason).toBe("stop");
expect(startCount).toBe(1);
expect(doneCount).toBe(1);
expect(fetchMock).toHaveBeenCalledTimes(2);
});
});

View file

@ -0,0 +1,57 @@
import { afterEach, describe, expect, it, vi } from "vitest";
import { extractRetryDelay } from "../src/providers/google-gemini-cli.js";
describe("extractRetryDelay header parsing", () => {
afterEach(() => {
vi.useRealTimers();
});
it("prefers Retry-After seconds header", () => {
vi.useFakeTimers();
vi.setSystemTime(new Date("2025-01-01T00:00:00Z"));
const response = new Response("", { headers: { "Retry-After": "5" } });
const delay = extractRetryDelay("Please retry in 1s", response);
expect(delay).toBe(6000);
});
it("parses Retry-After HTTP date header", () => {
vi.useFakeTimers();
const now = new Date("2025-01-01T00:00:00Z");
vi.setSystemTime(now);
const retryAt = new Date(now.getTime() + 12000).toUTCString();
const response = new Response("", { headers: { "Retry-After": retryAt } });
const delay = extractRetryDelay("", response);
expect(delay).toBe(13000);
});
it("parses x-ratelimit-reset header", () => {
vi.useFakeTimers();
const now = new Date("2025-01-01T00:00:00Z");
vi.setSystemTime(now);
const resetAtMs = now.getTime() + 20000;
const resetSeconds = Math.floor(resetAtMs / 1000).toString();
const response = new Response("", {
headers: { "x-ratelimit-reset": resetSeconds },
});
const delay = extractRetryDelay("", response);
expect(delay).toBe(21000);
});
it("parses x-ratelimit-reset-after header", () => {
vi.useFakeTimers();
vi.setSystemTime(new Date("2025-01-01T00:00:00Z"));
const response = new Response("", {
headers: { "x-ratelimit-reset-after": "30" },
});
const delay = extractRetryDelay("", response);
expect(delay).toBe(31000);
});
});

View file

@ -0,0 +1,195 @@
import { describe, expect, it } from "vitest";
import { convertMessages } from "../src/providers/google-shared.js";
import type { Context, Model } from "../src/types.js";
const SKIP_THOUGHT_SIGNATURE = "skip_thought_signature_validator";
function makeGemini3Model(
id = "gemini-3-pro-preview",
): Model<"google-generative-ai"> {
return {
id,
name: "Gemini 3 Pro Preview",
api: "google-generative-ai",
provider: "google",
baseUrl: "https://generativelanguage.googleapis.com",
reasoning: true,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 128000,
maxTokens: 8192,
};
}
describe("google-shared convertMessages — Gemini 3 unsigned tool calls", () => {
it("uses skip_thought_signature_validator for unsigned tool calls on Gemini 3", () => {
const model = makeGemini3Model();
const now = Date.now();
const context: Context = {
messages: [
{ role: "user", content: "Hi", timestamp: now },
{
role: "assistant",
content: [
{
type: "toolCall",
id: "call_1",
name: "bash",
arguments: { command: "ls -la" },
// No thoughtSignature: simulates Claude via Antigravity.
},
],
api: "google-gemini-cli",
provider: "google-antigravity",
model: "claude-sonnet-4-20250514",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
total: 0,
},
},
stopReason: "stop",
timestamp: now,
},
],
};
const contents = convertMessages(model, context);
const modelTurn = contents.find((c) => c.role === "model");
expect(modelTurn).toBeTruthy();
// Should be a structured functionCall, NOT text fallback
const fcPart = modelTurn?.parts?.find((p) => p.functionCall !== undefined);
expect(fcPart).toBeTruthy();
expect(fcPart?.functionCall?.name).toBe("bash");
expect(fcPart?.functionCall?.args).toEqual({ command: "ls -la" });
expect(fcPart?.thoughtSignature).toBe(SKIP_THOUGHT_SIGNATURE);
// No text fallback should exist
const textParts =
modelTurn?.parts?.filter((p) => p.text !== undefined) ?? [];
const historicalText = textParts.filter((p) =>
p.text?.includes("Historical context"),
);
expect(historicalText).toHaveLength(0);
});
it("preserves valid thoughtSignature when present (same provider/model)", () => {
const model = makeGemini3Model();
const now = Date.now();
// Valid base64 signature (16 bytes = 24 chars base64)
const validSig = "AAAAAAAAAAAAAAAAAAAAAA==";
const context: Context = {
messages: [
{ role: "user", content: "Hi", timestamp: now },
{
role: "assistant",
content: [
{
type: "toolCall",
id: "call_1",
name: "bash",
arguments: { command: "echo hi" },
thoughtSignature: validSig,
},
],
api: "google-generative-ai",
provider: "google",
model: "gemini-3-pro-preview",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
total: 0,
},
},
stopReason: "stop",
timestamp: now,
},
],
};
const contents = convertMessages(model, context);
const modelTurn = contents.find((c) => c.role === "model");
const fcPart = modelTurn?.parts?.find((p) => p.functionCall !== undefined);
expect(fcPart).toBeTruthy();
expect(fcPart?.thoughtSignature).toBe(validSig);
});
it("does not add sentinel for non-Gemini-3 models", () => {
const model: Model<"google-generative-ai"> = {
id: "gemini-2.5-flash",
name: "Gemini 2.5 Flash",
api: "google-generative-ai",
provider: "google",
baseUrl: "https://generativelanguage.googleapis.com",
reasoning: true,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 128000,
maxTokens: 8192,
};
const now = Date.now();
const context: Context = {
messages: [
{ role: "user", content: "Hi", timestamp: now },
{
role: "assistant",
content: [
{
type: "toolCall",
id: "call_1",
name: "bash",
arguments: { command: "ls" },
// No thoughtSignature
},
],
api: "google-gemini-cli",
provider: "google-antigravity",
model: "claude-sonnet-4-20250514",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
total: 0,
},
},
stopReason: "stop",
timestamp: now,
},
],
};
const contents = convertMessages(model, context);
const modelTurn = contents.find((c) => c.role === "model");
const fcPart = modelTurn?.parts?.find((p) => p.functionCall !== undefined);
expect(fcPart).toBeTruthy();
// No sentinel, no thoughtSignature at all
expect(fcPart?.thoughtSignature).toBeUndefined();
});
});

View file

@ -0,0 +1,56 @@
import { describe, expect, it } from "vitest";
import {
isThinkingPart,
retainThoughtSignature,
} from "../src/providers/google-shared.js";
describe("Google thinking detection (thoughtSignature)", () => {
it("treats part.thought === true as thinking", () => {
expect(isThinkingPart({ thought: true, thoughtSignature: undefined })).toBe(
true,
);
expect(
isThinkingPart({ thought: true, thoughtSignature: "opaque-signature" }),
).toBe(true);
});
it("does not treat thoughtSignature alone as thinking", () => {
// Per Google docs, thoughtSignature is for context replay and can appear on any part type.
// Only thought === true indicates thinking content.
// See: https://ai.google.dev/gemini-api/docs/thought-signatures
expect(
isThinkingPart({
thought: undefined,
thoughtSignature: "opaque-signature",
}),
).toBe(false);
expect(
isThinkingPart({ thought: false, thoughtSignature: "opaque-signature" }),
).toBe(false);
});
it("does not treat empty/missing signatures as thinking if thought is not set", () => {
expect(
isThinkingPart({ thought: undefined, thoughtSignature: undefined }),
).toBe(false);
expect(isThinkingPart({ thought: false, thoughtSignature: "" })).toBe(
false,
);
});
it("preserves the existing signature when subsequent deltas omit thoughtSignature", () => {
const first = retainThoughtSignature(undefined, "sig-1");
expect(first).toBe("sig-1");
const second = retainThoughtSignature(first, undefined);
expect(second).toBe("sig-1");
const third = retainThoughtSignature(second, "");
expect(third).toBe("sig-1");
});
it("updates the signature when a new non-empty signature arrives", () => {
const updated = retainThoughtSignature("sig-1", "sig-2");
expect(updated).toBe("sig-2");
});
});

View file

@ -0,0 +1,107 @@
import { Type } from "@sinclair/typebox";
import { afterEach, describe, expect, it, vi } from "vitest";
import { streamGoogleGeminiCli } from "../src/providers/google-gemini-cli.js";
import type { Context, Model, ToolCall } from "../src/types.js";
const emptySchema = Type.Object({});
const originalFetch = global.fetch;
afterEach(() => {
global.fetch = originalFetch;
vi.restoreAllMocks();
});
describe("google providers tool call missing args", () => {
it("defaults arguments to empty object when provider omits args field", async () => {
// Simulate a tool call response where args is missing (no-arg tool)
const sse = `${[
`data: ${JSON.stringify({
response: {
candidates: [
{
content: {
role: "model",
parts: [
{
functionCall: {
name: "get_status",
// args intentionally omitted
},
},
],
},
finishReason: "STOP",
},
],
usageMetadata: {
promptTokenCount: 10,
candidatesTokenCount: 5,
totalTokenCount: 15,
},
},
})}`,
].join("\n\n")}\n\n`;
const encoder = new TextEncoder();
const dataStream = new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(encoder.encode(sse));
controller.close();
},
});
const fetchMock = vi.fn(async () => {
return new Response(dataStream, {
status: 200,
headers: { "content-type": "text/event-stream" },
});
});
global.fetch = fetchMock as typeof fetch;
const model: Model<"google-gemini-cli"> = {
id: "gemini-2.5-flash",
name: "Gemini 2.5 Flash",
api: "google-gemini-cli",
provider: "google-gemini-cli",
baseUrl: "https://cloudcode-pa.googleapis.com",
reasoning: false,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 128000,
maxTokens: 8192,
};
const context: Context = {
messages: [
{ role: "user", content: "Check status", timestamp: Date.now() },
],
tools: [
{
name: "get_status",
description: "Get current status",
parameters: emptySchema,
},
],
};
const stream = streamGoogleGeminiCli(model, context, {
apiKey: JSON.stringify({ token: "token", projectId: "project" }),
});
for await (const _ of stream) {
// consume stream
}
const result = await stream.result();
expect(result.stopReason).toBe("toolUse");
expect(result.content).toHaveLength(1);
const toolCall = result.content[0] as ToolCall;
expect(toolCall.type).toBe("toolCall");
expect(toolCall.name).toBe("get_status");
expect(toolCall.arguments).toEqual({});
});
});

View file

@ -0,0 +1,630 @@
import { readFileSync } from "node:fs";
import { join } from "node:path";
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import type {
Api,
Context,
Model,
Tool,
ToolResultMessage,
} from "../src/index.js";
import { complete, getModel } from "../src/index.js";
import type { StreamOptions } from "../src/types.js";
type StreamOptionsWithExtras = StreamOptions & Record<string, unknown>;
import {
hasAzureOpenAICredentials,
resolveAzureDeploymentName,
} from "./azure-utils.js";
import { hasBedrockCredentials } from "./bedrock-utils.js";
import { resolveApiKey } from "./oauth.js";
// Resolve OAuth tokens at module level (async, runs before tests)
const oauthTokens = await Promise.all([
resolveApiKey("anthropic"),
resolveApiKey("github-copilot"),
resolveApiKey("google-gemini-cli"),
resolveApiKey("google-antigravity"),
resolveApiKey("openai-codex"),
]);
const [
anthropicOAuthToken,
githubCopilotToken,
geminiCliToken,
antigravityToken,
openaiCodexToken,
] = oauthTokens;
/**
* Test that tool results containing only images work correctly across all providers.
* This verifies that:
* 1. Tool results can contain image content blocks
* 2. Providers correctly pass images from tool results to the LLM
* 3. The LLM can see and describe images returned by tools
*/
async function handleToolWithImageResult<TApi extends Api>(
model: Model<TApi>,
options?: StreamOptionsWithExtras,
) {
// Check if the model supports images
if (!model.input.includes("image")) {
console.log(
`Skipping tool image result test - model ${model.id} doesn't support images`,
);
return;
}
// Read the test image
const imagePath = join(__dirname, "data", "red-circle.png");
const imageBuffer = readFileSync(imagePath);
const base64Image = imageBuffer.toString("base64");
// Define a tool that returns only an image (no text)
const getImageSchema = Type.Object({});
const getImageTool: Tool<typeof getImageSchema> = {
name: "get_circle",
description: "Returns a circle image for visualization",
parameters: getImageSchema,
};
const context: Context = {
systemPrompt: "You are a helpful assistant that uses tools when asked.",
messages: [
{
role: "user",
content:
"Call the get_circle tool to get an image, and describe what you see, shapes, colors, etc.",
timestamp: Date.now(),
},
],
tools: [getImageTool],
};
// First request - LLM should call the tool
const firstResponse = await complete(model, context, options);
expect(firstResponse.stopReason).toBe("toolUse");
// Find the tool call
const toolCall = firstResponse.content.find((b) => b.type === "toolCall");
expect(toolCall).toBeTruthy();
if (!toolCall || toolCall.type !== "toolCall") {
throw new Error("Expected tool call");
}
expect(toolCall.name).toBe("get_circle");
// Add the tool call to context
context.messages.push(firstResponse);
// Create tool result with ONLY an image (no text)
const toolResult: ToolResultMessage = {
role: "toolResult",
toolCallId: toolCall.id,
toolName: toolCall.name,
content: [
{
type: "image",
data: base64Image,
mimeType: "image/png",
},
],
isError: false,
timestamp: Date.now(),
};
context.messages.push(toolResult);
// Second request - LLM should describe the image from the tool result
const secondResponse = await complete(model, context, options);
expect(secondResponse.stopReason).toBe("stop");
expect(secondResponse.errorMessage).toBeFalsy();
// Verify the LLM can see and describe the image
const textContent = secondResponse.content.find((b) => b.type === "text");
expect(textContent).toBeTruthy();
if (textContent && textContent.type === "text") {
const lowerContent = textContent.text.toLowerCase();
// Should mention red and circle since that's what the image shows
expect(lowerContent).toContain("red");
expect(lowerContent).toContain("circle");
}
}
/**
* Test that tool results containing both text and images work correctly across all providers.
* This verifies that:
* 1. Tool results can contain mixed content blocks (text + images)
* 2. Providers correctly pass both text and images from tool results to the LLM
* 3. The LLM can see both the text and images in tool results
*/
async function handleToolWithTextAndImageResult<TApi extends Api>(
model: Model<TApi>,
options?: StreamOptionsWithExtras,
) {
// Check if the model supports images
if (!model.input.includes("image")) {
console.log(
`Skipping tool text+image result test - model ${model.id} doesn't support images`,
);
return;
}
// Read the test image
const imagePath = join(__dirname, "data", "red-circle.png");
const imageBuffer = readFileSync(imagePath);
const base64Image = imageBuffer.toString("base64");
// Define a tool that returns both text and an image
const getImageSchema = Type.Object({});
const getImageTool: Tool<typeof getImageSchema> = {
name: "get_circle_with_description",
description: "Returns a circle image with a text description",
parameters: getImageSchema,
};
const context: Context = {
systemPrompt: "You are a helpful assistant that uses tools when asked.",
messages: [
{
role: "user",
content:
"Use the get_circle_with_description tool and tell me what you learned. Also say what color the shape is.",
timestamp: Date.now(),
},
],
tools: [getImageTool],
};
// First request - LLM should call the tool
const firstResponse = await complete(model, context, options);
expect(firstResponse.stopReason).toBe("toolUse");
// Find the tool call
const toolCall = firstResponse.content.find((b) => b.type === "toolCall");
expect(toolCall).toBeTruthy();
if (!toolCall || toolCall.type !== "toolCall") {
throw new Error("Expected tool call");
}
expect(toolCall.name).toBe("get_circle_with_description");
// Add the tool call to context
context.messages.push(firstResponse);
// Create tool result with BOTH text and image
const toolResult: ToolResultMessage = {
role: "toolResult",
toolCallId: toolCall.id,
toolName: toolCall.name,
content: [
{
type: "text",
text: "This is a geometric shape with specific properties: it has a diameter of 100 pixels.",
},
{
type: "image",
data: base64Image,
mimeType: "image/png",
},
],
isError: false,
timestamp: Date.now(),
};
context.messages.push(toolResult);
// Second request - LLM should describe both the text and image from the tool result
const secondResponse = await complete(model, context, options);
expect(secondResponse.stopReason).toBe("stop");
expect(secondResponse.errorMessage).toBeFalsy();
// Verify the LLM can see both text and image
const textContent = secondResponse.content.find((b) => b.type === "text");
expect(textContent).toBeTruthy();
if (textContent && textContent.type === "text") {
const lowerContent = textContent.text.toLowerCase();
// Should mention details from the text (diameter/pixels)
expect(lowerContent.match(/diameter|100|pixel/)).toBeTruthy();
// Should also mention the visual properties (red and circle)
expect(lowerContent).toContain("red");
expect(lowerContent).toContain("circle");
}
}
describe("Tool Results with Images", () => {
describe.skipIf(!process.env.GEMINI_API_KEY)(
"Google Provider (gemini-2.5-flash)",
() => {
const llm = getModel("google", "gemini-2.5-flash");
it(
"should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithImageResult(llm);
},
);
it(
"should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithTextAndImageResult(llm);
},
);
},
);
describe.skipIf(!process.env.OPENAI_API_KEY)(
"OpenAI Completions Provider (gpt-4o-mini)",
() => {
const { compat: _compat, ...baseModel } = getModel(
"openai",
"gpt-4o-mini",
);
void _compat;
const llm: Model<"openai-completions"> = {
...baseModel,
api: "openai-completions",
};
it(
"should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithImageResult(llm);
},
);
it(
"should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithTextAndImageResult(llm);
},
);
},
);
describe.skipIf(!process.env.OPENAI_API_KEY)(
"OpenAI Responses Provider (gpt-5-mini)",
() => {
const llm = getModel("openai", "gpt-5-mini");
it(
"should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithImageResult(llm);
},
);
it(
"should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithTextAndImageResult(llm);
},
);
},
);
describe.skipIf(!hasAzureOpenAICredentials())(
"Azure OpenAI Responses Provider (gpt-4o-mini)",
() => {
const llm = getModel("azure-openai-responses", "gpt-4o-mini");
const azureDeploymentName = resolveAzureDeploymentName(llm.id);
const azureOptions = azureDeploymentName ? { azureDeploymentName } : {};
it(
"should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithImageResult(llm, azureOptions);
},
);
it(
"should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithTextAndImageResult(llm, azureOptions);
},
);
},
);
describe.skipIf(!process.env.ANTHROPIC_API_KEY)(
"Anthropic Provider (claude-haiku-4-5)",
() => {
const model = getModel("anthropic", "claude-haiku-4-5");
it(
"should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithImageResult(model);
},
);
it(
"should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithTextAndImageResult(model);
},
);
},
);
describe.skipIf(!process.env.OPENROUTER_API_KEY)(
"OpenRouter Provider (glm-4.5v)",
() => {
const llm = getModel("openrouter", "z-ai/glm-4.5v");
it(
"should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithImageResult(llm);
},
);
it(
"should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithTextAndImageResult(llm);
},
);
},
);
describe.skipIf(!process.env.MISTRAL_API_KEY)(
"Mistral Provider (pixtral-12b)",
() => {
const llm = getModel("mistral", "pixtral-12b");
it(
"should handle tool result with only image",
{ retry: 5, timeout: 30000 },
async () => {
await handleToolWithImageResult(llm);
},
);
it(
"should handle tool result with text and image",
{ retry: 5, timeout: 30000 },
async () => {
await handleToolWithTextAndImageResult(llm);
},
);
},
);
describe.skipIf(!process.env.KIMI_API_KEY)(
"Kimi For Coding Provider (k2p5)",
() => {
const llm = getModel("kimi-coding", "k2p5");
it(
"should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithImageResult(llm);
},
);
it(
"should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithTextAndImageResult(llm);
},
);
},
);
describe.skipIf(!process.env.AI_GATEWAY_API_KEY)(
"Vercel AI Gateway Provider (google/gemini-2.5-flash)",
() => {
const llm = getModel("vercel-ai-gateway", "google/gemini-2.5-flash");
it(
"should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithImageResult(llm);
},
);
it(
"should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithTextAndImageResult(llm);
},
);
},
);
describe.skipIf(!hasBedrockCredentials())(
"Amazon Bedrock Provider (claude-sonnet-4-5)",
() => {
const llm = getModel(
"amazon-bedrock",
"global.anthropic.claude-sonnet-4-5-20250929-v1:0",
);
it(
"should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithImageResult(llm);
},
);
it(
"should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithTextAndImageResult(llm);
},
);
},
);
// =========================================================================
// OAuth-based providers (credentials from ~/.pi/agent/oauth.json)
// =========================================================================
describe("Anthropic OAuth Provider (claude-sonnet-4-5)", () => {
const model = getModel("anthropic", "claude-sonnet-4-5");
it.skipIf(!anthropicOAuthToken)(
"should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithImageResult(model, { apiKey: anthropicOAuthToken });
},
);
it.skipIf(!anthropicOAuthToken)(
"should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
await handleToolWithTextAndImageResult(model, {
apiKey: anthropicOAuthToken,
});
},
);
});
describe("GitHub Copilot Provider", () => {
it.skipIf(!githubCopilotToken)(
"gpt-4o - should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("github-copilot", "gpt-4o");
await handleToolWithImageResult(llm, { apiKey: githubCopilotToken });
},
);
it.skipIf(!githubCopilotToken)(
"gpt-4o - should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("github-copilot", "gpt-4o");
await handleToolWithTextAndImageResult(llm, {
apiKey: githubCopilotToken,
});
},
);
it.skipIf(!githubCopilotToken)(
"claude-sonnet-4 - should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("github-copilot", "claude-sonnet-4");
await handleToolWithImageResult(llm, { apiKey: githubCopilotToken });
},
);
it.skipIf(!githubCopilotToken)(
"claude-sonnet-4 - should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("github-copilot", "claude-sonnet-4");
await handleToolWithTextAndImageResult(llm, {
apiKey: githubCopilotToken,
});
},
);
});
describe("Google Gemini CLI Provider", () => {
it.skipIf(!geminiCliToken)(
"gemini-2.5-flash - should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("google-gemini-cli", "gemini-2.5-flash");
await handleToolWithImageResult(llm, { apiKey: geminiCliToken });
},
);
it.skipIf(!geminiCliToken)(
"gemini-2.5-flash - should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("google-gemini-cli", "gemini-2.5-flash");
await handleToolWithTextAndImageResult(llm, { apiKey: geminiCliToken });
},
);
});
describe("Google Antigravity Provider", () => {
it.skipIf(!antigravityToken)(
"gemini-3-flash - should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("google-antigravity", "gemini-3-flash");
await handleToolWithImageResult(llm, { apiKey: antigravityToken });
},
);
it.skipIf(!antigravityToken)(
"gemini-3-flash - should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("google-antigravity", "gemini-3-flash");
await handleToolWithTextAndImageResult(llm, {
apiKey: antigravityToken,
});
},
);
/** These two don't work, the model simply won't call the tool, works in pi
it.skipIf(!antigravityToken)(
"claude-sonnet-4-5 - should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("google-antigravity", "claude-sonnet-4-5");
await handleToolWithImageResult(llm, { apiKey: antigravityToken });
},
);
it.skipIf(!antigravityToken)(
"claude-sonnet-4-5 - should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("google-antigravity", "claude-sonnet-4-5");
await handleToolWithTextAndImageResult(llm, { apiKey: antigravityToken });
},
);**/
// Note: gpt-oss-120b-medium does not support images, so not tested here
});
describe("OpenAI Codex Provider", () => {
it.skipIf(!openaiCodexToken)(
"gpt-5.2-codex - should handle tool result with only image",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("openai-codex", "gpt-5.2-codex");
await handleToolWithImageResult(llm, { apiKey: openaiCodexToken });
},
);
it.skipIf(!openaiCodexToken)(
"gpt-5.2-codex - should handle tool result with text and image",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("openai-codex", "gpt-5.2-codex");
await handleToolWithTextAndImageResult(llm, {
apiKey: openaiCodexToken,
});
},
);
});
});

View file

@ -0,0 +1,206 @@
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import { getEnvApiKey } from "../src/env-api-keys.js";
import { getModel } from "../src/models.js";
import { completeSimple } from "../src/stream.js";
import type {
Api,
Context,
Model,
StopReason,
Tool,
ToolCall,
ToolResultMessage,
} from "../src/types.js";
import { StringEnum } from "../src/utils/typebox-helpers.js";
import { hasBedrockCredentials } from "./bedrock-utils.js";
const calculatorSchema = Type.Object({
a: Type.Number({ description: "First number" }),
b: Type.Number({ description: "Second number" }),
operation: StringEnum(["add", "subtract", "multiply", "divide"], {
description: "The operation to perform.",
}),
});
const calculatorTool: Tool<typeof calculatorSchema> = {
name: "calculator",
description: "Perform basic arithmetic operations",
parameters: calculatorSchema,
};
type CalculatorOperation = "add" | "subtract" | "multiply" | "divide";
type CalculatorArguments = {
a: number;
b: number;
operation: CalculatorOperation;
};
function asCalculatorArguments(
args: ToolCall["arguments"],
): CalculatorArguments {
if (typeof args !== "object" || args === null) {
throw new Error("Tool arguments must be an object");
}
const value = args as Record<string, unknown>;
const operation = value.operation;
if (
typeof value.a !== "number" ||
typeof value.b !== "number" ||
(operation !== "add" &&
operation !== "subtract" &&
operation !== "multiply" &&
operation !== "divide")
) {
throw new Error("Invalid calculator arguments");
}
return { a: value.a, b: value.b, operation };
}
function evaluateCalculatorCall(toolCall: ToolCall): number {
const { a, b, operation } = asCalculatorArguments(toolCall.arguments);
switch (operation) {
case "add":
return a + b;
case "subtract":
return a - b;
case "multiply":
return a * b;
case "divide":
return a / b;
}
}
async function assertSecondToolCallWithInterleavedThinking<TApi extends Api>(
llm: Model<TApi>,
reasoning: "high" | "xhigh",
) {
const context: Context = {
systemPrompt: [
"You are a helpful assistant that must use tools for arithmetic.",
"Always think before every tool call, not just the first one.",
"Do not answer with plain text when a tool call is required.",
].join(" "),
messages: [
{
role: "user",
content: [
"Use calculator to calculate 328 * 29.",
"You must call the calculator tool exactly once.",
"Provide the final answer based on the best guess given the tool result, even if it seems unreliable.",
"Start by thinking about the steps you will take to solve the problem.",
].join(" "),
timestamp: Date.now(),
},
],
tools: [calculatorTool],
};
const firstResponse = await completeSimple(llm, context, { reasoning });
expect(firstResponse.stopReason, `Error: ${firstResponse.errorMessage}`).toBe(
"toolUse" satisfies StopReason,
);
expect(firstResponse.content.some((block) => block.type === "thinking")).toBe(
true,
);
expect(firstResponse.content.some((block) => block.type === "toolCall")).toBe(
true,
);
const firstToolCall = firstResponse.content.find(
(block) => block.type === "toolCall",
);
expect(firstToolCall?.type).toBe("toolCall");
if (!firstToolCall || firstToolCall.type !== "toolCall") {
throw new Error("Expected first response to include a tool call");
}
context.messages.push(firstResponse);
const correctAnswer = evaluateCalculatorCall(firstToolCall);
const firstToolResult: ToolResultMessage = {
role: "toolResult",
toolCallId: firstToolCall.id,
toolName: firstToolCall.name,
content: [
{
type: "text",
text: `The answer is ${correctAnswer} or ${correctAnswer * 2}.`,
},
],
isError: false,
timestamp: Date.now(),
};
context.messages.push(firstToolResult);
const secondResponse = await completeSimple(llm, context, { reasoning });
expect(
secondResponse.stopReason,
`Error: ${secondResponse.errorMessage}`,
).toBe("stop" satisfies StopReason);
expect(
secondResponse.content.some((block) => block.type === "thinking"),
).toBe(true);
expect(secondResponse.content.some((block) => block.type === "text")).toBe(
true,
);
}
const hasAnthropicCredentials = !!getEnvApiKey("anthropic");
describe.skipIf(!hasBedrockCredentials())(
"Amazon Bedrock interleaved thinking",
() => {
it(
"should do interleaved thinking on Claude Opus 4.5",
{ retry: 3 },
async () => {
const llm = getModel(
"amazon-bedrock",
"global.anthropic.claude-opus-4-5-20251101-v1:0",
);
await assertSecondToolCallWithInterleavedThinking(llm, "high");
},
);
it(
"should do interleaved thinking on Claude Opus 4.6",
{ retry: 3 },
async () => {
const llm = getModel(
"amazon-bedrock",
"global.anthropic.claude-opus-4-6-v1",
);
await assertSecondToolCallWithInterleavedThinking(llm, "high");
},
);
},
);
describe.skipIf(!hasAnthropicCredentials)(
"Anthropic interleaved thinking",
() => {
it(
"should do interleaved thinking on Claude Opus 4.5",
{ retry: 3 },
async () => {
const llm = getModel("anthropic", "claude-opus-4-5");
await assertSecondToolCallWithInterleavedThinking(llm, "high");
},
);
it(
"should do interleaved thinking on Claude Opus 4.6",
{ retry: 3 },
async () => {
const llm = getModel("anthropic", "claude-opus-4-6");
await assertSecondToolCallWithInterleavedThinking(llm, "high");
},
);
},
);

103
packages/ai/test/oauth.ts Normal file
View file

@ -0,0 +1,103 @@
/**
* Test helper for resolving API keys from ~/.pi/agent/auth.json
*
* Supports both API key and OAuth credentials.
* OAuth tokens are automatically refreshed if expired and saved back to auth.json.
*/
import {
chmodSync,
existsSync,
mkdirSync,
readFileSync,
writeFileSync,
} from "fs";
import { homedir } from "os";
import { dirname, join } from "path";
import { getOAuthApiKey } from "../src/utils/oauth/index.js";
import type {
OAuthCredentials,
OAuthProvider,
} from "../src/utils/oauth/types.js";
const AUTH_PATH = join(homedir(), ".pi", "agent", "auth.json");
type ApiKeyCredential = {
type: "api_key";
key: string;
};
type OAuthCredentialEntry = {
type: "oauth";
} & OAuthCredentials;
type AuthCredential = ApiKeyCredential | OAuthCredentialEntry;
type AuthStorage = Record<string, AuthCredential>;
function loadAuthStorage(): AuthStorage {
if (!existsSync(AUTH_PATH)) {
return {};
}
try {
const content = readFileSync(AUTH_PATH, "utf-8");
return JSON.parse(content);
} catch {
return {};
}
}
function saveAuthStorage(storage: AuthStorage): void {
const configDir = dirname(AUTH_PATH);
if (!existsSync(configDir)) {
mkdirSync(configDir, { recursive: true, mode: 0o700 });
}
writeFileSync(AUTH_PATH, JSON.stringify(storage, null, 2), "utf-8");
chmodSync(AUTH_PATH, 0o600);
}
/**
* Resolve API key for a provider from ~/.pi/agent/auth.json
*
* For API key credentials, returns the key directly.
* For OAuth credentials, returns the access token (refreshing if expired and saving back).
*
* For google-gemini-cli and google-antigravity, returns JSON-encoded { token, projectId }
*/
export async function resolveApiKey(
provider: string,
): Promise<string | undefined> {
const storage = loadAuthStorage();
const entry = storage[provider];
if (!entry) return undefined;
if (entry.type === "api_key") {
return entry.key;
}
if (entry.type === "oauth") {
// Build OAuthCredentials record for getOAuthApiKey
const oauthCredentials: Record<string, OAuthCredentials> = {};
for (const [key, value] of Object.entries(storage)) {
if (value.type === "oauth") {
const { type: _, ...creds } = value;
oauthCredentials[key] = creds;
}
}
const result = await getOAuthApiKey(
provider as OAuthProvider,
oauthCredentials,
);
if (!result) return undefined;
// Save refreshed credentials back to auth.json
storage[provider] = { type: "oauth", ...result.newCredentials };
saveAuthStorage(storage);
return result.apiKey;
}
return undefined;
}

View file

@ -0,0 +1,506 @@
import { mkdtempSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import { afterEach, describe, expect, it, vi } from "vitest";
import { streamOpenAICodexResponses } from "../src/providers/openai-codex-responses.js";
import type { Context, Model } from "../src/types.js";
const originalFetch = global.fetch;
const originalAgentDir = process.env.PI_CODING_AGENT_DIR;
afterEach(() => {
global.fetch = originalFetch;
if (originalAgentDir === undefined) {
delete process.env.PI_CODING_AGENT_DIR;
} else {
process.env.PI_CODING_AGENT_DIR = originalAgentDir;
}
vi.restoreAllMocks();
});
describe("openai-codex streaming", () => {
it("streams SSE responses into AssistantMessageEventStream", async () => {
const tempDir = mkdtempSync(join(tmpdir(), "pi-codex-stream-"));
process.env.PI_CODING_AGENT_DIR = tempDir;
const payload = Buffer.from(
JSON.stringify({
"https://api.openai.com/auth": { chatgpt_account_id: "acc_test" },
}),
"utf8",
).toString("base64");
const token = `aaa.${payload}.bbb`;
const sse = `${[
`data: ${JSON.stringify({
type: "response.output_item.added",
item: {
type: "message",
id: "msg_1",
role: "assistant",
status: "in_progress",
content: [],
},
})}`,
`data: ${JSON.stringify({ type: "response.content_part.added", part: { type: "output_text", text: "" } })}`,
`data: ${JSON.stringify({ type: "response.output_text.delta", delta: "Hello" })}`,
`data: ${JSON.stringify({
type: "response.output_item.done",
item: {
type: "message",
id: "msg_1",
role: "assistant",
status: "completed",
content: [{ type: "output_text", text: "Hello" }],
},
})}`,
`data: ${JSON.stringify({
type: "response.completed",
response: {
status: "completed",
usage: {
input_tokens: 5,
output_tokens: 3,
total_tokens: 8,
input_tokens_details: { cached_tokens: 0 },
},
},
})}`,
].join("\n\n")}\n\n`;
const encoder = new TextEncoder();
const stream = new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(encoder.encode(sse));
controller.close();
},
});
const fetchMock = vi.fn(async (input: string | URL, init?: RequestInit) => {
const url = typeof input === "string" ? input : input.toString();
if (url === "https://api.github.com/repos/openai/codex/releases/latest") {
return new Response(JSON.stringify({ tag_name: "rust-v0.0.0" }), {
status: 200,
});
}
if (url.startsWith("https://raw.githubusercontent.com/openai/codex/")) {
return new Response("PROMPT", {
status: 200,
headers: { etag: '"etag"' },
});
}
if (url === "https://chatgpt.com/backend-api/codex/responses") {
const headers =
init?.headers instanceof Headers ? init.headers : undefined;
expect(headers?.get("Authorization")).toBe(`Bearer ${token}`);
expect(headers?.get("chatgpt-account-id")).toBe("acc_test");
expect(headers?.get("OpenAI-Beta")).toBe("responses=experimental");
expect(headers?.get("originator")).toBe("pi");
expect(headers?.get("accept")).toBe("text/event-stream");
expect(headers?.has("x-api-key")).toBe(false);
return new Response(stream, {
status: 200,
headers: { "content-type": "text/event-stream" },
});
}
return new Response("not found", { status: 404 });
});
global.fetch = fetchMock as typeof fetch;
const model: Model<"openai-codex-responses"> = {
id: "gpt-5.1-codex",
name: "GPT-5.1 Codex",
api: "openai-codex-responses",
provider: "openai-codex",
baseUrl: "https://chatgpt.com/backend-api",
reasoning: true,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 400000,
maxTokens: 128000,
};
const context: Context = {
systemPrompt: "You are a helpful assistant.",
messages: [{ role: "user", content: "Say hello", timestamp: Date.now() }],
};
const streamResult = streamOpenAICodexResponses(model, context, {
apiKey: token,
});
let sawTextDelta = false;
let sawDone = false;
for await (const event of streamResult) {
if (event.type === "text_delta") {
sawTextDelta = true;
}
if (event.type === "done") {
sawDone = true;
expect(event.message.content.find((c) => c.type === "text")?.text).toBe(
"Hello",
);
}
}
expect(sawTextDelta).toBe(true);
expect(sawDone).toBe(true);
});
it("sets conversation_id/session_id headers and prompt_cache_key when sessionId is provided", async () => {
const tempDir = mkdtempSync(join(tmpdir(), "pi-codex-stream-"));
process.env.PI_CODING_AGENT_DIR = tempDir;
const payload = Buffer.from(
JSON.stringify({
"https://api.openai.com/auth": { chatgpt_account_id: "acc_test" },
}),
"utf8",
).toString("base64");
const token = `aaa.${payload}.bbb`;
const sse = `${[
`data: ${JSON.stringify({
type: "response.output_item.added",
item: {
type: "message",
id: "msg_1",
role: "assistant",
status: "in_progress",
content: [],
},
})}`,
`data: ${JSON.stringify({ type: "response.content_part.added", part: { type: "output_text", text: "" } })}`,
`data: ${JSON.stringify({ type: "response.output_text.delta", delta: "Hello" })}`,
`data: ${JSON.stringify({
type: "response.output_item.done",
item: {
type: "message",
id: "msg_1",
role: "assistant",
status: "completed",
content: [{ type: "output_text", text: "Hello" }],
},
})}`,
`data: ${JSON.stringify({
type: "response.completed",
response: {
status: "completed",
usage: {
input_tokens: 5,
output_tokens: 3,
total_tokens: 8,
input_tokens_details: { cached_tokens: 0 },
},
},
})}`,
].join("\n\n")}\n\n`;
const encoder = new TextEncoder();
const stream = new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(encoder.encode(sse));
controller.close();
},
});
const sessionId = "test-session-123";
const fetchMock = vi.fn(async (input: string | URL, init?: RequestInit) => {
const url = typeof input === "string" ? input : input.toString();
if (url === "https://api.github.com/repos/openai/codex/releases/latest") {
return new Response(JSON.stringify({ tag_name: "rust-v0.0.0" }), {
status: 200,
});
}
if (url.startsWith("https://raw.githubusercontent.com/openai/codex/")) {
return new Response("PROMPT", {
status: 200,
headers: { etag: '"etag"' },
});
}
if (url === "https://chatgpt.com/backend-api/codex/responses") {
const headers =
init?.headers instanceof Headers ? init.headers : undefined;
// Verify sessionId is set in headers
expect(headers?.get("conversation_id")).toBe(sessionId);
expect(headers?.get("session_id")).toBe(sessionId);
// Verify sessionId is set in request body as prompt_cache_key
const body =
typeof init?.body === "string"
? (JSON.parse(init.body) as Record<string, unknown>)
: null;
expect(body?.prompt_cache_key).toBe(sessionId);
expect(body?.prompt_cache_retention).toBe("in-memory");
return new Response(stream, {
status: 200,
headers: { "content-type": "text/event-stream" },
});
}
return new Response("not found", { status: 404 });
});
global.fetch = fetchMock as typeof fetch;
const model: Model<"openai-codex-responses"> = {
id: "gpt-5.1-codex",
name: "GPT-5.1 Codex",
api: "openai-codex-responses",
provider: "openai-codex",
baseUrl: "https://chatgpt.com/backend-api",
reasoning: true,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 400000,
maxTokens: 128000,
};
const context: Context = {
systemPrompt: "You are a helpful assistant.",
messages: [{ role: "user", content: "Say hello", timestamp: Date.now() }],
};
const streamResult = streamOpenAICodexResponses(model, context, {
apiKey: token,
sessionId,
});
await streamResult.result();
});
it.each(["gpt-5.3-codex", "gpt-5.4"])(
"clamps %s minimal reasoning effort to low",
async (modelId) => {
const tempDir = mkdtempSync(join(tmpdir(), "pi-codex-stream-"));
process.env.PI_CODING_AGENT_DIR = tempDir;
const payload = Buffer.from(
JSON.stringify({
"https://api.openai.com/auth": { chatgpt_account_id: "acc_test" },
}),
"utf8",
).toString("base64");
const token = `aaa.${payload}.bbb`;
const sse = `${[
`data: ${JSON.stringify({
type: "response.output_item.added",
item: {
type: "message",
id: "msg_1",
role: "assistant",
status: "in_progress",
content: [],
},
})}`,
`data: ${JSON.stringify({ type: "response.content_part.added", part: { type: "output_text", text: "" } })}`,
`data: ${JSON.stringify({ type: "response.output_text.delta", delta: "Hello" })}`,
`data: ${JSON.stringify({
type: "response.output_item.done",
item: {
type: "message",
id: "msg_1",
role: "assistant",
status: "completed",
content: [{ type: "output_text", text: "Hello" }],
},
})}`,
`data: ${JSON.stringify({
type: "response.completed",
response: {
status: "completed",
usage: {
input_tokens: 5,
output_tokens: 3,
total_tokens: 8,
input_tokens_details: { cached_tokens: 0 },
},
},
})}`,
].join("\n\n")}\n\n`;
const encoder = new TextEncoder();
const stream = new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(encoder.encode(sse));
controller.close();
},
});
const fetchMock = vi.fn(
async (input: string | URL, init?: RequestInit) => {
const url = typeof input === "string" ? input : input.toString();
if (
url === "https://api.github.com/repos/openai/codex/releases/latest"
) {
return new Response(JSON.stringify({ tag_name: "rust-v0.0.0" }), {
status: 200,
});
}
if (
url.startsWith("https://raw.githubusercontent.com/openai/codex/")
) {
return new Response("PROMPT", {
status: 200,
headers: { etag: '"etag"' },
});
}
if (url === "https://chatgpt.com/backend-api/codex/responses") {
const body =
typeof init?.body === "string"
? (JSON.parse(init.body) as Record<string, unknown>)
: null;
expect(body?.reasoning).toEqual({ effort: "low", summary: "auto" });
return new Response(stream, {
status: 200,
headers: { "content-type": "text/event-stream" },
});
}
return new Response("not found", { status: 404 });
},
);
global.fetch = fetchMock as typeof fetch;
const model: Model<"openai-codex-responses"> = {
id: modelId,
name: modelId,
api: "openai-codex-responses",
provider: "openai-codex",
baseUrl: "https://chatgpt.com/backend-api",
reasoning: true,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 400000,
maxTokens: 128000,
};
const context: Context = {
systemPrompt: "You are a helpful assistant.",
messages: [
{ role: "user", content: "Say hello", timestamp: Date.now() },
],
};
const streamResult = streamOpenAICodexResponses(model, context, {
apiKey: token,
reasoningEffort: "minimal",
});
await streamResult.result();
},
);
it("does not set conversation_id/session_id headers when sessionId is not provided", async () => {
const tempDir = mkdtempSync(join(tmpdir(), "pi-codex-stream-"));
process.env.PI_CODING_AGENT_DIR = tempDir;
const payload = Buffer.from(
JSON.stringify({
"https://api.openai.com/auth": { chatgpt_account_id: "acc_test" },
}),
"utf8",
).toString("base64");
const token = `aaa.${payload}.bbb`;
const sse = `${[
`data: ${JSON.stringify({
type: "response.output_item.added",
item: {
type: "message",
id: "msg_1",
role: "assistant",
status: "in_progress",
content: [],
},
})}`,
`data: ${JSON.stringify({ type: "response.content_part.added", part: { type: "output_text", text: "" } })}`,
`data: ${JSON.stringify({ type: "response.output_text.delta", delta: "Hello" })}`,
`data: ${JSON.stringify({
type: "response.output_item.done",
item: {
type: "message",
id: "msg_1",
role: "assistant",
status: "completed",
content: [{ type: "output_text", text: "Hello" }],
},
})}`,
`data: ${JSON.stringify({
type: "response.completed",
response: {
status: "completed",
usage: {
input_tokens: 5,
output_tokens: 3,
total_tokens: 8,
input_tokens_details: { cached_tokens: 0 },
},
},
})}`,
].join("\n\n")}\n\n`;
const encoder = new TextEncoder();
const stream = new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(encoder.encode(sse));
controller.close();
},
});
const fetchMock = vi.fn(async (input: string | URL, init?: RequestInit) => {
const url = typeof input === "string" ? input : input.toString();
if (url === "https://api.github.com/repos/openai/codex/releases/latest") {
return new Response(JSON.stringify({ tag_name: "rust-v0.0.0" }), {
status: 200,
});
}
if (url.startsWith("https://raw.githubusercontent.com/openai/codex/")) {
return new Response("PROMPT", {
status: 200,
headers: { etag: '"etag"' },
});
}
if (url === "https://chatgpt.com/backend-api/codex/responses") {
const headers =
init?.headers instanceof Headers ? init.headers : undefined;
// Verify headers are not set when sessionId is not provided
expect(headers?.has("conversation_id")).toBe(false);
expect(headers?.has("session_id")).toBe(false);
return new Response(stream, {
status: 200,
headers: { "content-type": "text/event-stream" },
});
}
return new Response("not found", { status: 404 });
});
global.fetch = fetchMock as typeof fetch;
const model: Model<"openai-codex-responses"> = {
id: "gpt-5.1-codex",
name: "GPT-5.1 Codex",
api: "openai-codex-responses",
provider: "openai-codex",
baseUrl: "https://chatgpt.com/backend-api",
reasoning: true,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 400000,
maxTokens: 128000,
};
const context: Context = {
systemPrompt: "You are a helpful assistant.",
messages: [{ role: "user", content: "Say hello", timestamp: Date.now() }],
};
// No sessionId provided
const streamResult = streamOpenAICodexResponses(model, context, {
apiKey: token,
});
await streamResult.result();
});
});

View file

@ -0,0 +1,193 @@
import { Type } from "@sinclair/typebox";
import { describe, expect, it, vi } from "vitest";
import { getModel } from "../src/models.js";
import { streamSimple } from "../src/stream.js";
import type { Tool } from "../src/types.js";
const mockState = vi.hoisted(() => ({ lastParams: undefined as unknown }));
vi.mock("openai", () => {
class FakeOpenAI {
chat = {
completions: {
create: async (params: unknown) => {
mockState.lastParams = params;
return {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: {}, finish_reason: "stop" }],
usage: {
prompt_tokens: 1,
completion_tokens: 1,
prompt_tokens_details: { cached_tokens: 0 },
completion_tokens_details: { reasoning_tokens: 0 },
},
};
},
};
},
},
};
}
return { default: FakeOpenAI };
});
describe("openai-completions tool_choice", () => {
it("forwards toolChoice from simple options to payload", async () => {
const { compat: _compat, ...baseModel } = getModel(
"openai",
"gpt-4o-mini",
)!;
const model = { ...baseModel, api: "openai-completions" } as const;
const tools: Tool[] = [
{
name: "ping",
description: "Ping tool",
parameters: Type.Object({
ok: Type.Boolean(),
}),
},
];
let payload: unknown;
await streamSimple(
model,
{
messages: [
{
role: "user",
content: "Call ping with ok=true",
timestamp: Date.now(),
},
],
tools,
},
{
apiKey: "test",
toolChoice: "required",
onPayload: (params: unknown) => {
payload = params;
},
} as unknown as Parameters<typeof streamSimple>[2],
).result();
const params = (payload ?? mockState.lastParams) as {
tool_choice?: string;
tools?: unknown[];
};
expect(params.tool_choice).toBe("required");
expect(Array.isArray(params.tools)).toBe(true);
expect(params.tools?.length ?? 0).toBeGreaterThan(0);
});
it("omits strict when compat disables strict mode", async () => {
const { compat: _compat, ...baseModel } = getModel(
"openai",
"gpt-4o-mini",
)!;
const model = {
...baseModel,
api: "openai-completions",
compat: { supportsStrictMode: false },
} as const;
const tools: Tool[] = [
{
name: "ping",
description: "Ping tool",
parameters: Type.Object({
ok: Type.Boolean(),
}),
},
];
let payload: unknown;
await streamSimple(
model,
{
messages: [
{
role: "user",
content: "Call ping with ok=true",
timestamp: Date.now(),
},
],
tools,
},
{
apiKey: "test",
onPayload: (params: unknown) => {
payload = params;
},
} as unknown as Parameters<typeof streamSimple>[2],
).result();
const params = (payload ?? mockState.lastParams) as {
tools?: Array<{ function?: Record<string, unknown> }>;
};
const tool = params.tools?.[0]?.function;
expect(tool).toBeTruthy();
expect(tool?.strict).toBeUndefined();
expect("strict" in (tool ?? {})).toBe(false);
});
it("maps groq qwen3 reasoning levels to default reasoning_effort", async () => {
const model = getModel("groq", "qwen/qwen3-32b")!;
let payload: unknown;
await streamSimple(
model,
{
messages: [
{
role: "user",
content: "Hi",
timestamp: Date.now(),
},
],
},
{
apiKey: "test",
reasoning: "medium",
onPayload: (params: unknown) => {
payload = params;
},
},
).result();
const params = (payload ?? mockState.lastParams) as {
reasoning_effort?: string;
};
expect(params.reasoning_effort).toBe("default");
});
it("keeps normal reasoning_effort for groq models without compat mapping", async () => {
const model = getModel("groq", "openai/gpt-oss-20b")!;
let payload: unknown;
await streamSimple(
model,
{
messages: [
{
role: "user",
content: "Hi",
timestamp: Date.now(),
},
],
},
{
apiKey: "test",
reasoning: "medium",
onPayload: (params: unknown) => {
payload = params;
},
},
).result();
const params = (payload ?? mockState.lastParams) as {
reasoning_effort?: string;
};
expect(params.reasoning_effort).toBe("medium");
});
});

View file

@ -0,0 +1,111 @@
import { describe, expect, it } from "vitest";
import { getModel } from "../src/models.js";
import { convertMessages } from "../src/providers/openai-completions.js";
import type {
AssistantMessage,
Context,
Model,
OpenAICompletionsCompat,
ToolResultMessage,
Usage,
} from "../src/types.js";
const emptyUsage: Usage = {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
};
const compat: Required<OpenAICompletionsCompat> = {
supportsStore: true,
supportsDeveloperRole: true,
supportsReasoningEffort: true,
reasoningEffortMap: {},
supportsUsageInStreaming: true,
maxTokensField: "max_completion_tokens",
requiresToolResultName: false,
requiresAssistantAfterToolResult: false,
requiresThinkingAsText: false,
thinkingFormat: "openai",
openRouterRouting: {},
vercelGatewayRouting: {},
supportsStrictMode: true,
};
function buildToolResult(
toolCallId: string,
timestamp: number,
): ToolResultMessage {
return {
role: "toolResult",
toolCallId,
toolName: "read",
content: [
{ type: "text", text: "Read image file [image/png]" },
{ type: "image", data: "ZmFrZQ==", mimeType: "image/png" },
],
isError: false,
timestamp,
};
}
describe("openai-completions convertMessages", () => {
it("batches tool-result images after consecutive tool results", () => {
const baseModel = getModel("openai", "gpt-4o-mini");
const model: Model<"openai-completions"> = {
...baseModel,
api: "openai-completions",
input: ["text", "image"],
};
const now = Date.now();
const assistantMessage: AssistantMessage = {
role: "assistant",
content: [
{
type: "toolCall",
id: "tool-1",
name: "read",
arguments: { path: "img-1.png" },
},
{
type: "toolCall",
id: "tool-2",
name: "read",
arguments: { path: "img-2.png" },
},
],
api: model.api,
provider: model.provider,
model: model.id,
usage: emptyUsage,
stopReason: "toolUse",
timestamp: now,
};
const context: Context = {
messages: [
{ role: "user", content: "Read the images", timestamp: now - 2 },
assistantMessage,
buildToolResult("tool-1", now + 1),
buildToolResult("tool-2", now + 2),
],
};
const messages = convertMessages(model, context, compat);
const roles = messages.map((message) => message.role);
expect(roles).toEqual(["user", "assistant", "tool", "tool", "user"]);
const imageMessage = messages[messages.length - 1];
expect(imageMessage.role).toBe("user");
expect(Array.isArray(imageMessage.content)).toBe(true);
const imageParts = (
imageMessage.content as Array<{ type?: string }>
).filter((part) => part?.type === "image_url");
expect(imageParts.length).toBe(2);
});
});

View file

@ -0,0 +1,326 @@
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import { getModel } from "../src/models.js";
import { complete, getEnvApiKey } from "../src/stream.js";
import type {
AssistantMessage,
Context,
Message,
Tool,
ToolCall,
} from "../src/types.js";
const testToolSchema = Type.Object({
value: Type.Number({ description: "A number to double" }),
});
const testTool: Tool<typeof testToolSchema> = {
name: "double_number",
description: "Doubles a number and returns the result",
parameters: testToolSchema,
};
describe.skipIf(!process.env.OPENAI_API_KEY || !process.env.ANTHROPIC_API_KEY)(
"OpenAI Responses reasoning replay e2e",
() => {
it(
"skips reasoning-only history after an aborted turn",
{ retry: 2 },
async () => {
const model = getModel("openai", "gpt-5-mini");
const apiKey = getEnvApiKey("openai");
if (!apiKey) {
throw new Error("Missing OPENAI_API_KEY");
}
const userMessage: Message = {
role: "user",
content: "Use the double_number tool to double 21.",
timestamp: Date.now(),
};
const assistantResponse = await complete(
model,
{
systemPrompt: "You are a helpful assistant. Use the tool.",
messages: [userMessage],
tools: [testTool],
},
{
apiKey,
reasoningEffort: "high",
},
);
const thinkingBlock = assistantResponse.content.find(
(block) => block.type === "thinking" && block.thinkingSignature,
);
if (!thinkingBlock || thinkingBlock.type !== "thinking") {
throw new Error("Missing thinking signature from OpenAI Responses");
}
const corruptedAssistant: AssistantMessage = {
...assistantResponse,
content: [thinkingBlock],
stopReason: "aborted",
};
const followUp: Message = {
role: "user",
content: "Say hello to confirm you can continue.",
timestamp: Date.now(),
};
const context: Context = {
systemPrompt: "You are a helpful assistant.",
messages: [userMessage, corruptedAssistant, followUp],
tools: [testTool],
};
const response = await complete(model, context, {
apiKey,
reasoningEffort: "high",
});
// The key assertion: no 400 error from orphaned reasoning item
expect(response.stopReason, `Error: ${response.errorMessage}`).not.toBe(
"error",
);
expect(response.errorMessage).toBeFalsy();
// Model should respond (text or tool call)
expect(response.content.length).toBeGreaterThan(0);
},
);
it(
"handles same-provider different-model handoff with tool calls",
{ retry: 2 },
async () => {
// This tests the scenario where:
// 1. Model A (gpt-5-mini) generates reasoning + function_call
// 2. User switches to Model B (gpt-5.2-codex) - same provider, different model
// 3. transform-messages: isSameModel=false, thinking converted to text
// 4. But tool call ID still has OpenAI pairing history (fc_xxx paired with rs_xxx)
// 5. Without fix: OpenAI returns 400 "function_call without required reasoning item"
// 6. With fix: tool calls/results converted to text, conversation continues
const modelA = getModel("openai", "gpt-5-mini");
const modelB = getModel("openai", "gpt-5.2-codex");
const apiKey = getEnvApiKey("openai");
if (!apiKey) {
throw new Error("Missing OPENAI_API_KEY");
}
const userMessage: Message = {
role: "user",
content: "Use the double_number tool to double 21.",
timestamp: Date.now(),
};
// Get a real response from Model A with reasoning + tool call
const assistantResponse = await complete(
modelA,
{
systemPrompt:
"You are a helpful assistant. Always use the tool when asked.",
messages: [userMessage],
tools: [testTool],
},
{
apiKey,
reasoningEffort: "high",
},
);
const toolCallBlock = assistantResponse.content.find(
(block) => block.type === "toolCall",
) as ToolCall | undefined;
if (!toolCallBlock) {
throw new Error(
"Missing tool call from OpenAI Responses - model did not use the tool",
);
}
// Provide a tool result
const toolResult: Message = {
role: "toolResult",
toolCallId: toolCallBlock.id,
toolName: toolCallBlock.name,
content: [{ type: "text", text: "42" }],
isError: false,
timestamp: Date.now(),
};
const followUp: Message = {
role: "user",
content: "What was the result? Answer with just the number.",
timestamp: Date.now(),
};
// Now continue with Model B (different model, same provider)
const context: Context = {
systemPrompt: "You are a helpful assistant. Answer concisely.",
messages: [userMessage, assistantResponse, toolResult, followUp],
tools: [testTool],
};
let capturedPayload: any = null;
const response = await complete(modelB, context, {
apiKey,
reasoningEffort: "high",
onPayload: (payload) => {
capturedPayload = payload;
},
});
// The key assertion: no 400 error from orphaned function_call
expect(response.stopReason, `Error: ${response.errorMessage}`).not.toBe(
"error",
);
expect(response.errorMessage).toBeFalsy();
expect(response.content.length).toBeGreaterThan(0);
// Log what was sent for debugging
const input = capturedPayload?.input as any[];
const functionCalls =
input?.filter((item: any) => item.type === "function_call") || [];
const reasoningItems =
input?.filter((item: any) => item.type === "reasoning") || [];
console.log("Payload sent to API:");
console.log("- function_calls:", functionCalls.length);
console.log("- reasoning items:", reasoningItems.length);
console.log("- full input:", JSON.stringify(input, null, 2));
// Verify the model understood the context
const responseText = response.content
.filter((b) => b.type === "text")
.map((b) => (b as any).text)
.join("");
expect(responseText).toContain("42");
},
);
it(
"handles cross-provider handoff from Anthropic to OpenAI Codex",
{ retry: 2 },
async () => {
// This tests cross-provider handoff:
// 1. Anthropic model generates thinking + function_call (toolu_xxx ID)
// 2. User switches to OpenAI Codex
// 3. transform-messages: isSameModel=false, thinking converted to text
// 4. Tool call ID is Anthropic format (toolu_xxx), no OpenAI pairing history
// 5. Should work because foreign IDs have no pairing expectation
const anthropicModel = getModel("anthropic", "claude-sonnet-4-5");
const codexModel = getModel("openai", "gpt-5.2-codex");
const anthropicApiKey = getEnvApiKey("anthropic");
const openaiApiKey = getEnvApiKey("openai");
if (!anthropicApiKey || !openaiApiKey) {
throw new Error("Missing API keys");
}
const userMessage: Message = {
role: "user",
content: "Use the double_number tool to double 21.",
timestamp: Date.now(),
};
// Get a real response from Anthropic with thinking + tool call
const assistantResponse = await complete(
anthropicModel,
{
systemPrompt:
"You are a helpful assistant. Always use the tool when asked.",
messages: [userMessage],
tools: [testTool],
},
{
apiKey: anthropicApiKey,
thinkingEnabled: true,
thinkingBudgetTokens: 5000,
},
);
const toolCallBlock = assistantResponse.content.find(
(block) => block.type === "toolCall",
) as ToolCall | undefined;
if (!toolCallBlock) {
throw new Error(
"Missing tool call from Anthropic - model did not use the tool",
);
}
console.log("Anthropic tool call ID:", toolCallBlock.id);
// Provide a tool result
const toolResult: Message = {
role: "toolResult",
toolCallId: toolCallBlock.id,
toolName: toolCallBlock.name,
content: [{ type: "text", text: "42" }],
isError: false,
timestamp: Date.now(),
};
const followUp: Message = {
role: "user",
content: "What was the result? Answer with just the number.",
timestamp: Date.now(),
};
// Now continue with Codex (different provider)
const context: Context = {
systemPrompt: "You are a helpful assistant. Answer concisely.",
messages: [userMessage, assistantResponse, toolResult, followUp],
tools: [testTool],
};
let capturedPayload: any = null;
const response = await complete(codexModel, context, {
apiKey: openaiApiKey,
reasoningEffort: "high",
onPayload: (payload) => {
capturedPayload = payload;
},
});
// Log what was sent
const input = capturedPayload?.input as any[];
const functionCalls =
input?.filter((item: any) => item.type === "function_call") || [];
const reasoningItems =
input?.filter((item: any) => item.type === "reasoning") || [];
console.log("Payload sent to Codex:");
console.log("- function_calls:", functionCalls.length);
console.log("- reasoning items:", reasoningItems.length);
if (functionCalls.length > 0) {
console.log(
"- function_call IDs:",
functionCalls.map((fc: any) => fc.id),
);
}
// The key assertion: no 400 error
expect(response.stopReason, `Error: ${response.errorMessage}`).not.toBe(
"error",
);
expect(response.errorMessage).toBeFalsy();
expect(response.content.length).toBeGreaterThan(0);
// Verify the model understood the context
const responseText = response.content
.filter((b) => b.type === "text")
.map((b) => (b as any).text)
.join("");
expect(responseText).toContain("42");
},
);
},
);

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,28 @@
import { describe, expect, it } from "vitest";
import { getModel, supportsXhigh } from "../src/models.js";
describe("supportsXhigh", () => {
it("returns true for Anthropic Opus 4.6 on anthropic-messages API", () => {
const model = getModel("anthropic", "claude-opus-4-6");
expect(model).toBeDefined();
expect(supportsXhigh(model!)).toBe(true);
});
it("returns false for non-Opus Anthropic models", () => {
const model = getModel("anthropic", "claude-sonnet-4-5");
expect(model).toBeDefined();
expect(supportsXhigh(model!)).toBe(false);
});
it("returns true for GPT-5.4 models", () => {
const model = getModel("openai-codex", "gpt-5.4");
expect(model).toBeDefined();
expect(supportsXhigh(model!)).toBe(true);
});
it("returns false for OpenRouter Opus 4.6 (openai-completions API)", () => {
const model = getModel("openrouter", "anthropic/claude-opus-4.6");
expect(model).toBeDefined();
expect(supportsXhigh(model!)).toBe(false);
});
});

View file

@ -0,0 +1,397 @@
import { describe, expect, it } from "vitest";
import { getModel } from "../src/models.js";
import { stream } from "../src/stream.js";
import type { Api, Context, Model, StreamOptions } from "../src/types.js";
type StreamOptionsWithExtras = StreamOptions & Record<string, unknown>;
import {
hasAzureOpenAICredentials,
resolveAzureDeploymentName,
} from "./azure-utils.js";
import { hasBedrockCredentials } from "./bedrock-utils.js";
import { resolveApiKey } from "./oauth.js";
// Resolve OAuth tokens at module level (async, runs before tests)
const oauthTokens = await Promise.all([
resolveApiKey("anthropic"),
resolveApiKey("github-copilot"),
resolveApiKey("google-gemini-cli"),
resolveApiKey("google-antigravity"),
resolveApiKey("openai-codex"),
]);
const [
anthropicOAuthToken,
githubCopilotToken,
geminiCliToken,
antigravityToken,
openaiCodexToken,
] = oauthTokens;
async function testTokensOnAbort<TApi extends Api>(
llm: Model<TApi>,
options: StreamOptionsWithExtras = {},
) {
const context: Context = {
messages: [
{
role: "user",
content:
"Write a long poem with 20 stanzas about the beauty of nature.",
timestamp: Date.now(),
},
],
systemPrompt: "You are a helpful assistant.",
};
const controller = new AbortController();
const response = stream(llm, context, {
...options,
signal: controller.signal,
});
let abortFired = false;
let text = "";
for await (const event of response) {
if (
!abortFired &&
(event.type === "text_delta" || event.type === "thinking_delta")
) {
text += event.delta;
if (text.length >= 1000) {
abortFired = true;
controller.abort();
}
}
}
const msg = await response.result();
expect(msg.stopReason).toBe("aborted");
// OpenAI providers, OpenAI Codex, Gemini CLI, zai, Amazon Bedrock, and the GPT-OSS model on Antigravity only send usage in the final chunk,
// so when aborted they have no token stats. Anthropic and Google send usage information early in the stream.
// MiniMax reports input tokens but not output tokens when aborted.
if (
llm.api === "openai-completions" ||
llm.api === "mistral-conversations" ||
llm.api === "openai-responses" ||
llm.api === "azure-openai-responses" ||
llm.api === "openai-codex-responses" ||
llm.provider === "google-gemini-cli" ||
llm.provider === "zai" ||
llm.provider === "amazon-bedrock" ||
llm.provider === "vercel-ai-gateway" ||
(llm.provider === "google-antigravity" && llm.id.includes("gpt-oss"))
) {
expect(msg.usage.input).toBe(0);
expect(msg.usage.output).toBe(0);
} else if (llm.provider === "minimax") {
// MiniMax reports input tokens early but output tokens only in final chunk
expect(msg.usage.input).toBeGreaterThan(0);
expect(msg.usage.output).toBe(0);
} else {
expect(msg.usage.input).toBeGreaterThan(0);
expect(msg.usage.output).toBeGreaterThan(0);
// Some providers (Antigravity, Copilot) have zero cost rates
if (llm.cost.input > 0) {
expect(msg.usage.cost.input).toBeGreaterThan(0);
expect(msg.usage.cost.total).toBeGreaterThan(0);
}
}
}
describe("Token Statistics on Abort", () => {
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider", () => {
const llm = getModel("google", "gemini-2.5-flash");
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm, { thinking: { enabled: true } });
},
);
});
describe.skipIf(!process.env.OPENAI_API_KEY)(
"OpenAI Completions Provider",
() => {
const { compat: _compat, ...baseModel } = getModel(
"openai",
"gpt-4o-mini",
)!;
void _compat;
const llm: Model<"openai-completions"> = {
...baseModel,
api: "openai-completions",
};
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
},
);
describe.skipIf(!process.env.OPENAI_API_KEY)(
"OpenAI Responses Provider",
() => {
const llm = getModel("openai", "gpt-5-mini");
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
},
);
describe.skipIf(!hasAzureOpenAICredentials())(
"Azure OpenAI Responses Provider",
() => {
const llm = getModel("azure-openai-responses", "gpt-4o-mini");
const azureDeploymentName = resolveAzureDeploymentName(llm.id);
const azureOptions = azureDeploymentName ? { azureDeploymentName } : {};
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm, azureOptions);
},
);
},
);
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider", () => {
const llm = getModel("anthropic", "claude-3-5-haiku-20241022");
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
});
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider", () => {
const llm = getModel("xai", "grok-3-fast");
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
});
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider", () => {
const llm = getModel("groq", "openai/gpt-oss-20b");
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
});
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider", () => {
const llm = getModel("cerebras", "gpt-oss-120b");
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
});
describe.skipIf(!process.env.HF_TOKEN)("Hugging Face Provider", () => {
const llm = getModel("huggingface", "moonshotai/Kimi-K2.5");
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
});
describe.skipIf(!process.env.ZAI_API_KEY)("zAI Provider", () => {
const llm = getModel("zai", "glm-4.5-flash");
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
});
describe.skipIf(!process.env.MISTRAL_API_KEY)("Mistral Provider", () => {
const llm = getModel("mistral", "devstral-medium-latest");
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
});
describe.skipIf(!process.env.MINIMAX_API_KEY)("MiniMax Provider", () => {
const llm = getModel("minimax", "MiniMax-M2.1");
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
});
describe.skipIf(!process.env.KIMI_API_KEY)("Kimi For Coding Provider", () => {
const llm = getModel("kimi-coding", "kimi-k2-thinking");
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
});
describe.skipIf(!process.env.AI_GATEWAY_API_KEY)(
"Vercel AI Gateway Provider",
() => {
const llm = getModel("vercel-ai-gateway", "google/gemini-2.5-flash");
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
},
);
// =========================================================================
// OAuth-based providers (credentials from ~/.pi/agent/oauth.json)
// =========================================================================
describe("Anthropic OAuth Provider", () => {
const llm = getModel("anthropic", "claude-3-5-haiku-20241022");
it.skipIf(!anthropicOAuthToken)(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm, { apiKey: anthropicOAuthToken });
},
);
});
describe("GitHub Copilot Provider", () => {
it.skipIf(!githubCopilotToken)(
"gpt-4o - should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("github-copilot", "gpt-4o");
await testTokensOnAbort(llm, { apiKey: githubCopilotToken });
},
);
it.skipIf(!githubCopilotToken)(
"claude-sonnet-4 - should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("github-copilot", "claude-sonnet-4");
await testTokensOnAbort(llm, { apiKey: githubCopilotToken });
},
);
});
describe("Google Gemini CLI Provider", () => {
it.skipIf(!geminiCliToken)(
"gemini-2.5-flash - should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("google-gemini-cli", "gemini-2.5-flash");
await testTokensOnAbort(llm, { apiKey: geminiCliToken });
},
);
});
describe("Google Antigravity Provider", () => {
it.skipIf(!antigravityToken)(
"gemini-3-flash - should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("google-antigravity", "gemini-3-flash");
await testTokensOnAbort(llm, { apiKey: antigravityToken });
},
);
it.skipIf(!antigravityToken)(
"claude-sonnet-4-5 - should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("google-antigravity", "claude-sonnet-4-5");
await testTokensOnAbort(llm, { apiKey: antigravityToken });
},
);
it.skipIf(!antigravityToken)(
"gpt-oss-120b-medium - should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("google-antigravity", "gpt-oss-120b-medium");
await testTokensOnAbort(llm, { apiKey: antigravityToken });
},
);
});
describe("OpenAI Codex Provider", () => {
it.skipIf(!openaiCodexToken)(
"gpt-5.2-codex - should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
const llm = getModel("openai-codex", "gpt-5.2-codex");
await testTokensOnAbort(llm, { apiKey: openaiCodexToken });
},
);
});
describe.skipIf(!hasBedrockCredentials())("Amazon Bedrock Provider", () => {
const llm = getModel(
"amazon-bedrock",
"global.anthropic.claude-sonnet-4-5-20250929-v1:0",
);
it(
"should include token stats when aborted mid-stream",
{ retry: 3, timeout: 30000 },
async () => {
await testTokensOnAbort(llm);
},
);
});
});

View file

@ -0,0 +1,320 @@
/**
* Tool Call ID Normalization Tests
*
* Tests that tool call IDs from OpenAI Responses API (github-copilot, openai-codex, opencode)
* are properly normalized when sent to other providers.
*
* OpenAI Responses API generates IDs in format: {call_id}|{id}
* where {id} can be 400+ chars with special characters (+, /, =).
*
* Regression test for: https://github.com/badlogic/pi-mono/issues/1022
*/
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import { getModel } from "../src/models.js";
import { completeSimple, getEnvApiKey } from "../src/stream.js";
import type {
AssistantMessage,
Message,
Tool,
ToolResultMessage,
} from "../src/types.js";
import { resolveApiKey } from "./oauth.js";
// Resolve API keys
const copilotToken = await resolveApiKey("github-copilot");
const openrouterKey = getEnvApiKey("openrouter");
const codexToken = await resolveApiKey("openai-codex");
// Simple echo tool for testing
const echoToolSchema = Type.Object({
message: Type.String({ description: "Message to echo back" }),
});
const echoTool: Tool<typeof echoToolSchema> = {
name: "echo",
description: "Echoes the message back",
parameters: echoToolSchema,
};
/**
* Test 1: Live cross-provider handoff
*
* 1. Use github-copilot gpt-5.2-codex to generate a tool call
* 2. Switch to openrouter openai/gpt-5.2-codex and complete
* 3. Switch to openai-codex gpt-5.2-codex and complete
*
* Both should succeed without "call_id too long" errors.
*/
describe("Tool Call ID Normalization - Live Handoff", () => {
it.skipIf(!copilotToken || !openrouterKey)(
"github-copilot -> openrouter should normalize pipe-separated IDs",
async () => {
const copilotModel = getModel("github-copilot", "gpt-5.2-codex");
const openrouterModel = getModel("openrouter", "openai/gpt-5.2-codex");
// Step 1: Generate tool call with github-copilot
const userMessage: Message = {
role: "user",
content: "Use the echo tool to echo 'hello world'",
timestamp: Date.now(),
};
const assistantResponse = await completeSimple(
copilotModel,
{
systemPrompt:
"You are a helpful assistant. Use the echo tool when asked.",
messages: [userMessage],
tools: [echoTool],
},
{ apiKey: copilotToken },
);
expect(
assistantResponse.stopReason,
`Copilot error: ${assistantResponse.errorMessage}`,
).toBe("toolUse");
const toolCall = assistantResponse.content.find(
(c) => c.type === "toolCall",
);
expect(toolCall).toBeDefined();
expect(toolCall!.type).toBe("toolCall");
// Verify it's a pipe-separated ID (OpenAI Responses format)
if (toolCall?.type === "toolCall") {
expect(toolCall.id).toContain("|");
console.log(
`Tool call ID from github-copilot: ${toolCall.id.slice(0, 80)}...`,
);
}
// Create tool result
const toolResult: ToolResultMessage = {
role: "toolResult",
toolCallId: (toolCall as any).id,
toolName: "echo",
content: [{ type: "text", text: "hello world" }],
isError: false,
timestamp: Date.now(),
};
// Step 2: Complete with openrouter (uses openai-completions API)
const openrouterResponse = await completeSimple(
openrouterModel,
{
systemPrompt: "You are a helpful assistant.",
messages: [
userMessage,
assistantResponse,
toolResult,
{ role: "user", content: "Say hi", timestamp: Date.now() },
],
tools: [echoTool],
},
{ apiKey: openrouterKey },
);
// Should NOT fail with "call_id too long" error
expect(
openrouterResponse.stopReason,
`OpenRouter error: ${openrouterResponse.errorMessage}`,
).not.toBe("error");
expect(openrouterResponse.errorMessage).toBeUndefined();
},
60000,
);
it.skipIf(!copilotToken || !codexToken)(
"github-copilot -> openai-codex should normalize pipe-separated IDs",
async () => {
const copilotModel = getModel("github-copilot", "gpt-5.2-codex");
const codexModel = getModel("openai-codex", "gpt-5.2-codex");
// Step 1: Generate tool call with github-copilot
const userMessage: Message = {
role: "user",
content: "Use the echo tool to echo 'test message'",
timestamp: Date.now(),
};
const assistantResponse = await completeSimple(
copilotModel,
{
systemPrompt:
"You are a helpful assistant. Use the echo tool when asked.",
messages: [userMessage],
tools: [echoTool],
},
{ apiKey: copilotToken },
);
expect(
assistantResponse.stopReason,
`Copilot error: ${assistantResponse.errorMessage}`,
).toBe("toolUse");
const toolCall = assistantResponse.content.find(
(c) => c.type === "toolCall",
);
expect(toolCall).toBeDefined();
// Create tool result
const toolResult: ToolResultMessage = {
role: "toolResult",
toolCallId: (toolCall as any).id,
toolName: "echo",
content: [{ type: "text", text: "test message" }],
isError: false,
timestamp: Date.now(),
};
// Step 2: Complete with openai-codex (uses openai-codex-responses API)
const codexResponse = await completeSimple(
codexModel,
{
systemPrompt: "You are a helpful assistant.",
messages: [
userMessage,
assistantResponse,
toolResult,
{ role: "user", content: "Say hi", timestamp: Date.now() },
],
tools: [echoTool],
},
{ apiKey: codexToken },
);
// Should NOT fail with ID validation error
expect(
codexResponse.stopReason,
`Codex error: ${codexResponse.errorMessage}`,
).not.toBe("error");
expect(codexResponse.errorMessage).toBeUndefined();
},
60000,
);
});
/**
* Test 2: Prefilled context with exact failing IDs from issue #1022
*
* Uses the exact tool call ID format that caused the error:
* "call_xxx|very_long_base64_with_special_chars+/="
*/
describe("Tool Call ID Normalization - Prefilled Context", () => {
// Exact tool call ID from issue #1022 JSONL
const FAILING_TOOL_CALL_ID =
"call_pAYbIr76hXIjncD9UE4eGfnS|t5nnb2qYMFWGSsr13fhCd1CaCu3t3qONEPuOudu4HSVEtA8YJSL6FAZUxvoOoD792VIJWl91g87EdqsCWp9krVsdBysQoDaf9lMCLb8BS4EYi4gQd5kBQBYLlgD71PYwvf+TbMD9J9/5OMD42oxSRj8H+vRf78/l2Xla33LWz4nOgsddBlbvabICRs8GHt5C9PK5keFtzyi3lsyVKNlfduK3iphsZqs4MLv4zyGJnvZo/+QzShyk5xnMSQX/f98+aEoNflEApCdEOXipipgeiNWnpFSHbcwmMkZoJhURNu+JEz3xCh1mrXeYoN5o+trLL3IXJacSsLYXDrYTipZZbJFRPAucgbnjYBC+/ZzJOfkwCs+Gkw7EoZR7ZQgJ8ma+9586n4tT4cI8DEhBSZsWMjrCt8dxKg==";
// Build prefilled context with the failing ID
function buildPrefilledMessages(): Message[] {
const userMessage: Message = {
role: "user",
content: "Use the echo tool to echo 'hello'",
timestamp: Date.now() - 2000,
};
const assistantMessage: AssistantMessage = {
role: "assistant",
content: [
{
type: "toolCall",
id: FAILING_TOOL_CALL_ID,
name: "echo",
arguments: { message: "hello" },
},
],
api: "openai-responses",
provider: "github-copilot",
model: "gpt-5.2-codex",
usage: {
input: 100,
output: 50,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 150,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "toolUse",
timestamp: Date.now() - 1500,
};
const toolResult: ToolResultMessage = {
role: "toolResult",
toolCallId: FAILING_TOOL_CALL_ID,
toolName: "echo",
content: [{ type: "text", text: "hello" }],
isError: false,
timestamp: Date.now() - 1000,
};
const followUpUser: Message = {
role: "user",
content: "Say hi",
timestamp: Date.now(),
};
return [userMessage, assistantMessage, toolResult, followUpUser];
}
it.skipIf(!openrouterKey)(
"openrouter should handle prefilled context with long pipe-separated IDs",
async () => {
const model = getModel("openrouter", "openai/gpt-5.2-codex");
const messages = buildPrefilledMessages();
const response = await completeSimple(
model,
{
systemPrompt: "You are a helpful assistant.",
messages,
tools: [echoTool],
},
{ apiKey: openrouterKey },
);
// Should NOT fail with "call_id too long" error
expect(
response.stopReason,
`OpenRouter error: ${response.errorMessage}`,
).not.toBe("error");
if (response.errorMessage) {
expect(response.errorMessage).not.toContain("call_id");
expect(response.errorMessage).not.toContain("too long");
}
},
30000,
);
it.skipIf(!codexToken)(
"openai-codex should handle prefilled context with long pipe-separated IDs",
async () => {
const model = getModel("openai-codex", "gpt-5.2-codex");
const messages = buildPrefilledMessages();
const response = await completeSimple(
model,
{
systemPrompt: "You are a helpful assistant.",
messages,
tools: [echoTool],
},
{ apiKey: codexToken },
);
// Should NOT fail with ID validation error
expect(
response.stopReason,
`Codex error: ${response.errorMessage}`,
).not.toBe("error");
if (response.errorMessage) {
expect(response.errorMessage).not.toContain("id");
expect(response.errorMessage).not.toContain("additional characters");
}
},
30000,
);
});

View file

@ -0,0 +1,412 @@
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import { getModel } from "../src/models.js";
import { complete } from "../src/stream.js";
import type { Api, Context, Model, StreamOptions, Tool } from "../src/types.js";
type StreamOptionsWithExtras = StreamOptions & Record<string, unknown>;
import {
hasAzureOpenAICredentials,
resolveAzureDeploymentName,
} from "./azure-utils.js";
import { hasBedrockCredentials } from "./bedrock-utils.js";
import { resolveApiKey } from "./oauth.js";
// Resolve OAuth tokens at module level (async, runs before tests)
const oauthTokens = await Promise.all([
resolveApiKey("anthropic"),
resolveApiKey("github-copilot"),
resolveApiKey("google-gemini-cli"),
resolveApiKey("google-antigravity"),
resolveApiKey("openai-codex"),
]);
const [
anthropicOAuthToken,
githubCopilotToken,
geminiCliToken,
antigravityToken,
openaiCodexToken,
] = oauthTokens;
// Simple calculate tool
const calculateSchema = Type.Object({
expression: Type.String({
description: "The mathematical expression to evaluate",
}),
});
const calculateTool: Tool = {
name: "calculate",
description: "Evaluate mathematical expressions",
parameters: calculateSchema,
};
async function testToolCallWithoutResult<TApi extends Api>(
model: Model<TApi>,
options: StreamOptionsWithExtras = {},
) {
// Step 1: Create context with the calculate tool
const context: Context = {
systemPrompt:
"You are a helpful assistant. Use the calculate tool when asked to perform calculations.",
messages: [],
tools: [calculateTool],
};
// Step 2: Ask the LLM to make a tool call
context.messages.push({
role: "user",
content: "Please calculate 25 * 18 using the calculate tool.",
timestamp: Date.now(),
});
// Step 3: Get the assistant's response (should contain a tool call)
const firstResponse = await complete(model, context, options);
context.messages.push(firstResponse);
console.log("First response:", JSON.stringify(firstResponse, null, 2));
// Verify the response contains a tool call
const hasToolCall = firstResponse.content.some(
(block) => block.type === "toolCall",
);
expect(hasToolCall).toBe(true);
if (!hasToolCall) {
throw new Error(
"Expected assistant to make a tool call, but none was found",
);
}
// Step 4: Send a user message WITHOUT providing tool result
// This simulates the scenario where a tool call was aborted/cancelled
context.messages.push({
role: "user",
content: "Never mind, just tell me what is 2+2?",
timestamp: Date.now(),
});
// Step 5: The fix should filter out the orphaned tool call, and the request should succeed
const secondResponse = await complete(model, context, options);
console.log("Second response:", JSON.stringify(secondResponse, null, 2));
// The request should succeed (not error) - that's the main thing we're testing
expect(secondResponse.stopReason).not.toBe("error");
// Should have some content in the response
expect(secondResponse.content.length).toBeGreaterThan(0);
// The LLM may choose to answer directly or make a new tool call - either is fine
// The important thing is it didn't fail with the orphaned tool call error
const textContent = secondResponse.content
.filter((block) => block.type === "text")
.map((block) => (block.type === "text" ? block.text : ""))
.join(" ");
const toolCalls = secondResponse.content.filter(
(block) => block.type === "toolCall",
).length;
expect(toolCalls || textContent.length).toBeGreaterThan(0);
console.log("Answer:", textContent);
// Verify the stop reason is either "stop" or "toolUse" (new tool call)
expect(["stop", "toolUse"]).toContain(secondResponse.stopReason);
}
describe("Tool Call Without Result Tests", () => {
// =========================================================================
// API Key-based providers
// =========================================================================
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider", () => {
const model = getModel("google", "gemini-2.5-flash");
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
});
describe.skipIf(!process.env.OPENAI_API_KEY)(
"OpenAI Completions Provider",
() => {
const { compat: _compat, ...baseModel } = getModel(
"openai",
"gpt-4o-mini",
)!;
void _compat;
const model: Model<"openai-completions"> = {
...baseModel,
api: "openai-completions",
};
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
},
);
describe.skipIf(!process.env.OPENAI_API_KEY)(
"OpenAI Responses Provider",
() => {
const model = getModel("openai", "gpt-5-mini");
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
},
);
describe.skipIf(!hasAzureOpenAICredentials())(
"Azure OpenAI Responses Provider",
() => {
const model = getModel("azure-openai-responses", "gpt-4o-mini");
const azureDeploymentName = resolveAzureDeploymentName(model.id);
const azureOptions = azureDeploymentName ? { azureDeploymentName } : {};
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model, azureOptions);
},
);
},
);
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider", () => {
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
});
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider", () => {
const model = getModel("xai", "grok-3-fast");
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
});
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider", () => {
const model = getModel("groq", "openai/gpt-oss-20b");
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
});
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider", () => {
const model = getModel("cerebras", "gpt-oss-120b");
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
});
describe.skipIf(!process.env.HF_TOKEN)("Hugging Face Provider", () => {
const model = getModel("huggingface", "moonshotai/Kimi-K2.5");
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
});
describe.skipIf(!process.env.ZAI_API_KEY)("zAI Provider", () => {
const model = getModel("zai", "glm-4.5-flash");
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
});
describe.skipIf(!process.env.MISTRAL_API_KEY)("Mistral Provider", () => {
const model = getModel("mistral", "devstral-medium-latest");
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
});
describe.skipIf(!process.env.MINIMAX_API_KEY)("MiniMax Provider", () => {
const model = getModel("minimax", "MiniMax-M2.1");
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
});
describe.skipIf(!process.env.KIMI_API_KEY)("Kimi For Coding Provider", () => {
const model = getModel("kimi-coding", "kimi-k2-thinking");
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
});
describe.skipIf(!process.env.AI_GATEWAY_API_KEY)(
"Vercel AI Gateway Provider",
() => {
const model = getModel("vercel-ai-gateway", "google/gemini-2.5-flash");
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
},
);
describe.skipIf(!hasBedrockCredentials())("Amazon Bedrock Provider", () => {
const model = getModel(
"amazon-bedrock",
"global.anthropic.claude-sonnet-4-5-20250929-v1:0",
);
it(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model);
},
);
});
// =========================================================================
// OAuth-based providers (credentials from ~/.pi/agent/oauth.json)
// =========================================================================
describe("Anthropic OAuth Provider", () => {
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
it.skipIf(!anthropicOAuthToken)(
"should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
await testToolCallWithoutResult(model, { apiKey: anthropicOAuthToken });
},
);
});
describe("GitHub Copilot Provider", () => {
it.skipIf(!githubCopilotToken)(
"gpt-4o - should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
const model = getModel("github-copilot", "gpt-4o");
await testToolCallWithoutResult(model, { apiKey: githubCopilotToken });
},
);
it.skipIf(!githubCopilotToken)(
"claude-sonnet-4 - should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
const model = getModel("github-copilot", "claude-sonnet-4");
await testToolCallWithoutResult(model, { apiKey: githubCopilotToken });
},
);
});
describe("Google Gemini CLI Provider", () => {
it.skipIf(!geminiCliToken)(
"gemini-2.5-flash - should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
const model = getModel("google-gemini-cli", "gemini-2.5-flash");
await testToolCallWithoutResult(model, { apiKey: geminiCliToken });
},
);
});
describe("Google Antigravity Provider", () => {
it.skipIf(!antigravityToken)(
"gemini-3-flash - should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
const model = getModel("google-antigravity", "gemini-3-flash");
await testToolCallWithoutResult(model, { apiKey: antigravityToken });
},
);
it.skipIf(!antigravityToken)(
"claude-sonnet-4-5 - should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
const model = getModel("google-antigravity", "claude-sonnet-4-5");
await testToolCallWithoutResult(model, { apiKey: antigravityToken });
},
);
it.skipIf(!antigravityToken)(
"gpt-oss-120b-medium - should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
const model = getModel("google-antigravity", "gpt-oss-120b-medium");
await testToolCallWithoutResult(model, { apiKey: antigravityToken });
},
);
});
describe("OpenAI Codex Provider", () => {
it.skipIf(!openaiCodexToken)(
"gpt-5.2-codex - should filter out tool calls without corresponding tool results",
{ retry: 3, timeout: 30000 },
async () => {
const model = getModel("openai-codex", "gpt-5.2-codex");
await testToolCallWithoutResult(model, { apiKey: openaiCodexToken });
},
);
});
});

View file

@ -0,0 +1,785 @@
/**
* Test totalTokens field across all providers.
*
* totalTokens represents the total number of tokens processed by the LLM,
* including input (with cache) and output (with thinking). This is the
* base for calculating context size for the next request.
*
* - OpenAI Completions: Uses native total_tokens field
* - OpenAI Responses: Uses native total_tokens field
* - Google: Uses native totalTokenCount field
* - Anthropic: Computed as input + output + cacheRead + cacheWrite
* - Other OpenAI-compatible providers: Uses native total_tokens field
*/
import { describe, expect, it } from "vitest";
import { getModel } from "../src/models.js";
import { complete } from "../src/stream.js";
import type {
Api,
Context,
Model,
StreamOptions,
Usage,
} from "../src/types.js";
type StreamOptionsWithExtras = StreamOptions & Record<string, unknown>;
import {
hasAzureOpenAICredentials,
resolveAzureDeploymentName,
} from "./azure-utils.js";
import { hasBedrockCredentials } from "./bedrock-utils.js";
import { resolveApiKey } from "./oauth.js";
// Resolve OAuth tokens at module level (async, runs before tests)
const oauthTokens = await Promise.all([
resolveApiKey("anthropic"),
resolveApiKey("github-copilot"),
resolveApiKey("google-gemini-cli"),
resolveApiKey("google-antigravity"),
resolveApiKey("openai-codex"),
]);
const [
anthropicOAuthToken,
githubCopilotToken,
geminiCliToken,
antigravityToken,
openaiCodexToken,
] = oauthTokens;
// Generate a long system prompt to trigger caching (>2k bytes for most providers)
const LONG_SYSTEM_PROMPT = `You are a helpful assistant. Be concise in your responses.
Here is some additional context that makes this system prompt long enough to trigger caching:
${Array(50)
.fill(
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
)
.join("\n\n")}
Remember: Always be helpful and concise.`;
async function testTotalTokensWithCache<TApi extends Api>(
llm: Model<TApi>,
options: StreamOptionsWithExtras = {},
): Promise<{ first: Usage; second: Usage }> {
// First request - no cache
const context1: Context = {
systemPrompt: LONG_SYSTEM_PROMPT,
messages: [
{
role: "user",
content: "What is 2 + 2? Reply with just the number.",
timestamp: Date.now(),
},
],
};
const response1 = await complete(llm, context1, options);
expect(response1.stopReason).toBe("stop");
// Second request - should trigger cache read (same system prompt, add conversation)
const context2: Context = {
systemPrompt: LONG_SYSTEM_PROMPT,
messages: [
...context1.messages,
response1, // Include previous assistant response
{
role: "user",
content: "What is 3 + 3? Reply with just the number.",
timestamp: Date.now(),
},
],
};
const response2 = await complete(llm, context2, options);
expect(response2.stopReason).toBe("stop");
return { first: response1.usage, second: response2.usage };
}
function logUsage(label: string, usage: Usage) {
const computed =
usage.input + usage.output + usage.cacheRead + usage.cacheWrite;
console.log(` ${label}:`);
console.log(
` input: ${usage.input}, output: ${usage.output}, cacheRead: ${usage.cacheRead}, cacheWrite: ${usage.cacheWrite}`,
);
console.log(` totalTokens: ${usage.totalTokens}, computed: ${computed}`);
}
function assertTotalTokensEqualsComponents(usage: Usage) {
const computed =
usage.input + usage.output + usage.cacheRead + usage.cacheWrite;
expect(usage.totalTokens).toBe(computed);
}
describe("totalTokens field", () => {
// =========================================================================
// Anthropic
// =========================================================================
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic (API Key)", () => {
it(
"claude-3-5-haiku - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("anthropic", "claude-3-5-haiku-20241022");
console.log(`\nAnthropic / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.ANTHROPIC_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
// Anthropic should have cache activity
const hasCache =
second.cacheRead > 0 || second.cacheWrite > 0 || first.cacheWrite > 0;
expect(hasCache).toBe(true);
},
);
});
describe("Anthropic (OAuth)", () => {
it.skipIf(!anthropicOAuthToken)(
"claude-sonnet-4 - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("anthropic", "claude-sonnet-4-20250514");
console.log(`\nAnthropic OAuth / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: anthropicOAuthToken,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
// Anthropic should have cache activity
const hasCache =
second.cacheRead > 0 || second.cacheWrite > 0 || first.cacheWrite > 0;
expect(hasCache).toBe(true);
},
);
});
// =========================================================================
// OpenAI
// =========================================================================
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions", () => {
it(
"gpt-4o-mini - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const { compat: _compat, ...baseModel } = getModel(
"openai",
"gpt-4o-mini",
)!;
void _compat;
const llm: Model<"openai-completions"> = {
...baseModel,
api: "openai-completions",
};
console.log(`\nOpenAI Completions / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm);
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses", () => {
it(
"gpt-4o - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("openai", "gpt-4o");
console.log(`\nOpenAI Responses / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm);
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
describe.skipIf(!hasAzureOpenAICredentials())(
"Azure OpenAI Responses",
() => {
it(
"gpt-4o-mini - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("azure-openai-responses", "gpt-4o-mini");
const azureDeploymentName = resolveAzureDeploymentName(llm.id);
const azureOptions = azureDeploymentName
? { azureDeploymentName }
: {};
console.log(`\nAzure OpenAI Responses / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(
llm,
azureOptions,
);
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
},
);
// =========================================================================
// Google
// =========================================================================
describe.skipIf(!process.env.GEMINI_API_KEY)("Google", () => {
it(
"gemini-2.0-flash - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("google", "gemini-2.0-flash");
console.log(`\nGoogle / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm);
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// xAI
// =========================================================================
describe.skipIf(!process.env.XAI_API_KEY)("xAI", () => {
it(
"grok-3-fast - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("xai", "grok-3-fast");
console.log(`\nxAI / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.XAI_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// Groq
// =========================================================================
describe.skipIf(!process.env.GROQ_API_KEY)("Groq", () => {
it(
"openai/gpt-oss-120b - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("groq", "openai/gpt-oss-120b");
console.log(`\nGroq / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.GROQ_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// Cerebras
// =========================================================================
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras", () => {
it(
"gpt-oss-120b - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("cerebras", "gpt-oss-120b");
console.log(`\nCerebras / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.CEREBRAS_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// Hugging Face
// =========================================================================
describe.skipIf(!process.env.HF_TOKEN)("Hugging Face", () => {
it(
"Kimi-K2.5 - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("huggingface", "moonshotai/Kimi-K2.5");
console.log(`\nHugging Face / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.HF_TOKEN,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// z.ai
// =========================================================================
describe.skipIf(!process.env.ZAI_API_KEY)("z.ai", () => {
it(
"glm-4.5-flash - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("zai", "glm-4.5-flash");
console.log(`\nz.ai / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.ZAI_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// Mistral
// =========================================================================
describe.skipIf(!process.env.MISTRAL_API_KEY)("Mistral", () => {
it(
"devstral-medium-latest - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("mistral", "devstral-medium-latest");
console.log(`\nMistral / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.MISTRAL_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// MiniMax
// =========================================================================
describe.skipIf(!process.env.MINIMAX_API_KEY)("MiniMax", () => {
it(
"MiniMax-M2.1 - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("minimax", "MiniMax-M2.1");
console.log(`\nMiniMax / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.MINIMAX_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// Kimi For Coding
// =========================================================================
describe.skipIf(!process.env.KIMI_API_KEY)("Kimi For Coding", () => {
it(
"kimi-k2-thinking - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("kimi-coding", "kimi-k2-thinking");
console.log(`\nKimi For Coding / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.KIMI_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// Vercel AI Gateway
// =========================================================================
describe.skipIf(!process.env.AI_GATEWAY_API_KEY)("Vercel AI Gateway", () => {
it(
"google/gemini-2.5-flash - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("vercel-ai-gateway", "google/gemini-2.5-flash");
console.log(`\nVercel AI Gateway / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.AI_GATEWAY_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// OpenRouter - Multiple backend providers
// =========================================================================
describe.skipIf(!process.env.OPENROUTER_API_KEY)("OpenRouter", () => {
it(
"anthropic/claude-sonnet-4 - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("openrouter", "anthropic/claude-sonnet-4");
console.log(`\nOpenRouter / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.OPENROUTER_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
it(
"deepseek/deepseek-chat - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("openrouter", "deepseek/deepseek-chat");
console.log(`\nOpenRouter / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.OPENROUTER_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
it(
"mistralai/mistral-small-3.2-24b-instruct - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel(
"openrouter",
"mistralai/mistral-small-3.2-24b-instruct",
);
console.log(`\nOpenRouter / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.OPENROUTER_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
it(
"google/gemini-2.0-flash-001 - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("openrouter", "google/gemini-2.0-flash-001");
console.log(`\nOpenRouter / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.OPENROUTER_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
it(
"meta-llama/llama-4-maverick - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("openrouter", "meta-llama/llama-4-maverick");
console.log(`\nOpenRouter / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: process.env.OPENROUTER_API_KEY,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// GitHub Copilot (OAuth)
// =========================================================================
describe("GitHub Copilot (OAuth)", () => {
it.skipIf(!githubCopilotToken)(
"gpt-4o - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("github-copilot", "gpt-4o");
console.log(`\nGitHub Copilot / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: githubCopilotToken,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
it.skipIf(!githubCopilotToken)(
"claude-sonnet-4 - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("github-copilot", "claude-sonnet-4");
console.log(`\nGitHub Copilot / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: githubCopilotToken,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// Google Gemini CLI (OAuth)
// =========================================================================
describe("Google Gemini CLI (OAuth)", () => {
it.skipIf(!geminiCliToken)(
"gemini-2.5-flash - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("google-gemini-cli", "gemini-2.5-flash");
console.log(`\nGoogle Gemini CLI / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: geminiCliToken,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// Google Antigravity (OAuth)
// =========================================================================
describe("Google Antigravity (OAuth)", () => {
it.skipIf(!antigravityToken)(
"gemini-3-flash - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("google-antigravity", "gemini-3-flash");
console.log(`\nGoogle Antigravity / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: antigravityToken,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
it.skipIf(!antigravityToken)(
"claude-sonnet-4-5 - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("google-antigravity", "claude-sonnet-4-5");
console.log(`\nGoogle Antigravity / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: antigravityToken,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
it.skipIf(!antigravityToken)(
"gpt-oss-120b-medium - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("google-antigravity", "gpt-oss-120b-medium");
console.log(`\nGoogle Antigravity / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: antigravityToken,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
describe.skipIf(!hasBedrockCredentials())("Amazon Bedrock", () => {
it(
"claude-sonnet-4-5 - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel(
"amazon-bedrock",
"global.anthropic.claude-sonnet-4-5-20250929-v1:0",
);
console.log(`\nAmazon Bedrock / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm);
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
// =========================================================================
// OpenAI Codex (OAuth)
// =========================================================================
describe("OpenAI Codex (OAuth)", () => {
it.skipIf(!openaiCodexToken)(
"gpt-5.2-codex - should return totalTokens equal to sum of components",
{ retry: 3, timeout: 60000 },
async () => {
const llm = getModel("openai-codex", "gpt-5.2-codex");
console.log(`\nOpenAI Codex / ${llm.id}:`);
const { first, second } = await testTotalTokensWithCache(llm, {
apiKey: openaiCodexToken,
});
logUsage("First request", first);
logUsage("Second request", second);
assertTotalTokensEqualsComponents(first);
assertTotalTokensEqualsComponents(second);
},
);
});
});

View file

@ -0,0 +1,140 @@
import { describe, expect, it } from "vitest";
import { transformMessages } from "../src/providers/transform-messages.js";
import type {
AssistantMessage,
Message,
Model,
ToolCall,
} from "../src/types.js";
// Normalize function matching what anthropic.ts uses
function anthropicNormalizeToolCallId(
id: string,
_model: Model<"anthropic-messages">,
_source: AssistantMessage,
): string {
return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
}
function makeCopilotClaudeModel(): Model<"anthropic-messages"> {
return {
id: "claude-sonnet-4",
name: "Claude Sonnet 4",
api: "anthropic-messages",
provider: "github-copilot",
baseUrl: "https://api.individual.githubcopilot.com",
reasoning: true,
input: ["text", "image"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 128000,
maxTokens: 16000,
};
}
describe("OpenAI to Anthropic session migration for Copilot Claude", () => {
it("converts thinking blocks to plain text when source model differs", () => {
const model = makeCopilotClaudeModel();
const messages: Message[] = [
{ role: "user", content: "hello", timestamp: Date.now() },
{
role: "assistant",
content: [
{
type: "thinking",
thinking: "Let me think about this...",
thinkingSignature: "reasoning_content",
},
{ type: "text", text: "Hi there!" },
],
api: "openai-completions",
provider: "github-copilot",
model: "gpt-4o",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
},
];
const result = transformMessages(
messages,
model,
anthropicNormalizeToolCallId,
);
const assistantMsg = result.find(
(m) => m.role === "assistant",
) as AssistantMessage;
// Thinking block should be converted to text since models differ
const textBlocks = assistantMsg.content.filter((b) => b.type === "text");
const thinkingBlocks = assistantMsg.content.filter(
(b) => b.type === "thinking",
);
expect(thinkingBlocks).toHaveLength(0);
expect(textBlocks.length).toBeGreaterThanOrEqual(2);
});
it("removes thoughtSignature from tool calls when migrating between models", () => {
const model = makeCopilotClaudeModel();
const messages: Message[] = [
{ role: "user", content: "run a command", timestamp: Date.now() },
{
role: "assistant",
content: [
{
type: "toolCall",
id: "call_123",
name: "bash",
arguments: { command: "ls" },
thoughtSignature: JSON.stringify({
type: "reasoning.encrypted",
id: "call_123",
data: "encrypted",
}),
},
],
api: "openai-responses",
provider: "github-copilot",
model: "gpt-5",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "toolUse",
timestamp: Date.now(),
},
{
role: "toolResult",
toolCallId: "call_123",
toolName: "bash",
content: [{ type: "text", text: "output" }],
isError: false,
timestamp: Date.now(),
},
];
const result = transformMessages(
messages,
model,
anthropicNormalizeToolCallId,
);
const assistantMsg = result.find(
(m) => m.role === "assistant",
) as AssistantMessage;
const toolCall = assistantMsg.content.find(
(b) => b.type === "toolCall",
) as ToolCall;
expect(toolCall.thoughtSignature).toBeUndefined();
});
});

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,81 @@
import { describe, expect, it } from "vitest";
import { getModel } from "../src/models.js";
import { stream } from "../src/stream.js";
import type { Context, Model } from "../src/types.js";
function makeContext(): Context {
return {
messages: [
{
role: "user",
content: `What is ${(Math.random() * 100) | 0} + ${(Math.random() * 100) | 0}? Think step by step.`,
timestamp: Date.now(),
},
],
};
}
describe.skipIf(!process.env.OPENAI_API_KEY)("xhigh reasoning", () => {
describe("codex-max (supports xhigh)", () => {
// Note: codex models only support the responses API, not chat completions
it("should work with openai-responses", async () => {
const model = getModel("openai", "gpt-5.1-codex-max");
const s = stream(model, makeContext(), { reasoningEffort: "xhigh" });
let hasThinking = false;
for await (const event of s) {
if (
event.type === "thinking_start" ||
event.type === "thinking_delta"
) {
hasThinking = true;
}
}
const response = await s.result();
expect(response.stopReason, `Error: ${response.errorMessage}`).toBe(
"stop",
);
expect(response.content.some((b) => b.type === "text")).toBe(true);
expect(
hasThinking || response.content.some((b) => b.type === "thinking"),
).toBe(true);
});
});
describe("gpt-5-mini (does not support xhigh)", () => {
it("should error with openai-responses when using xhigh", async () => {
const model = getModel("openai", "gpt-5-mini");
const s = stream(model, makeContext(), { reasoningEffort: "xhigh" });
for await (const _ of s) {
// drain events
}
const response = await s.result();
expect(response.stopReason).toBe("error");
expect(response.errorMessage).toContain("xhigh");
});
it("should error with openai-completions when using xhigh", async () => {
const { compat: _compat, ...baseModel } = getModel(
"openai",
"gpt-5-mini",
);
void _compat;
const model: Model<"openai-completions"> = {
...baseModel,
api: "openai-completions",
};
const s = stream(model, makeContext(), { reasoningEffort: "xhigh" });
for await (const _ of s) {
// drain events
}
const response = await s.result();
expect(response.stopReason).toBe("error");
expect(response.errorMessage).toContain("xhigh");
});
});
});

View file

@ -0,0 +1,30 @@
import { describe, expect, it } from "vitest";
import { MODELS } from "../src/models.generated.js";
import { complete } from "../src/stream.js";
import type { Model } from "../src/types.js";
describe.skipIf(!process.env.OPENCODE_API_KEY)(
"OpenCode Models Smoke Test",
() => {
const providers = [
{ key: "opencode", label: "OpenCode Zen" },
{ key: "opencode-go", label: "OpenCode Go" },
] as const;
providers.forEach(({ key, label }) => {
const providerModels = Object.values(MODELS[key]);
providerModels.forEach((model) => {
it(`${label}: ${model.id}`, async () => {
const response = await complete(model as Model<any>, {
messages: [
{ role: "user", content: "Say hello.", timestamp: Date.now() },
],
});
expect(response.content).toBeTruthy();
expect(response.stopReason).toBe("stop");
}, 60000);
});
});
},
);

View file

@ -0,0 +1,9 @@
{
"extends": "../../tsconfig.base.json",
"compilerOptions": {
"outDir": "./dist",
"rootDir": "./src"
},
"include": ["src/**/*.ts"],
"exclude": ["node_modules", "dist", "**/*.d.ts", "src/**/*.d.ts"]
}

Some files were not shown because too many files have changed in this diff Show more