feat: custom provider support with streamSimple

- Add resetApiProviders() to clear and re-register built-in providers
- Add createAssistantMessageEventStream() factory for extensions
- Add streamSimple support in ProviderConfig for custom API implementations
- Call resetApiProviders() on /reload to clean up extension providers
- Add custom-provider.md documentation
- Add custom-provider.ts example with full Anthropic implementation
- Update extensions.md with streamSimple config option
This commit is contained in:
Mario Zechner 2026-01-24 23:11:20 +01:00
parent c06163bc59
commit 177c694406
11 changed files with 1243 additions and 69 deletions

View file

@ -5,6 +5,8 @@
### Added
- Added `azure-openai-responses` provider support for Azure OpenAI Responses API. ([#890](https://github.com/badlogic/pi-mono/pull/890) by [@markusylisiurunen](https://github.com/markusylisiurunen))
- Added `createAssistantMessageEventStream()` factory function for use in extensions.
- Added `resetApiProviders()` to clear and re-register built-in API providers.
### Changed

View file

@ -8,6 +8,7 @@ export * from "./providers/google-gemini-cli.js";
export * from "./providers/google-vertex.js";
export * from "./providers/openai-completions.js";
export * from "./providers/openai-responses.js";
export * from "./providers/register-builtins.js";
export * from "./stream.js";
export * from "./types.js";
export * from "./utils/event-stream.js";

View file

@ -1,4 +1,4 @@
import { registerApiProvider } from "../api-registry.js";
import { clearApiProviders, registerApiProvider } from "../api-registry.js";
import { streamBedrock, streamSimpleBedrock } from "./amazon-bedrock.js";
import { streamAnthropic, streamSimpleAnthropic } from "./anthropic.js";
import { streamAzureOpenAIResponses, streamSimpleAzureOpenAIResponses } from "./azure-openai-responses.js";
@ -9,56 +9,65 @@ import { streamOpenAICodexResponses, streamSimpleOpenAICodexResponses } from "./
import { streamOpenAICompletions, streamSimpleOpenAICompletions } from "./openai-completions.js";
import { streamOpenAIResponses, streamSimpleOpenAIResponses } from "./openai-responses.js";
registerApiProvider({
api: "anthropic-messages",
stream: streamAnthropic,
streamSimple: streamSimpleAnthropic,
});
/**
 * Registers the built-in API providers with the API registry.
 *
 * Each `registerApiProvider` call binds one API identifier (for example
 * "anthropic-messages" or "openai-completions") to its `stream` and
 * `streamSimple` implementations.
 */
export function registerBuiltInApiProviders(): void {
registerApiProvider({
api: "anthropic-messages",
stream: streamAnthropic,
streamSimple: streamSimpleAnthropic,
});
registerApiProvider({
api: "openai-completions",
stream: streamOpenAICompletions,
streamSimple: streamSimpleOpenAICompletions,
});
registerApiProvider({
api: "openai-completions",
stream: streamOpenAICompletions,
streamSimple: streamSimpleOpenAICompletions,
});
registerApiProvider({
api: "openai-responses",
stream: streamOpenAIResponses,
streamSimple: streamSimpleOpenAIResponses,
});
registerApiProvider({
api: "openai-responses",
stream: streamOpenAIResponses,
streamSimple: streamSimpleOpenAIResponses,
});
registerApiProvider({
api: "azure-openai-responses",
stream: streamAzureOpenAIResponses,
streamSimple: streamSimpleAzureOpenAIResponses,
});
registerApiProvider({
api: "azure-openai-responses",
stream: streamAzureOpenAIResponses,
streamSimple: streamSimpleAzureOpenAIResponses,
});
registerApiProvider({
api: "openai-codex-responses",
stream: streamOpenAICodexResponses,
streamSimple: streamSimpleOpenAICodexResponses,
});
registerApiProvider({
api: "openai-codex-responses",
stream: streamOpenAICodexResponses,
streamSimple: streamSimpleOpenAICodexResponses,
});
registerApiProvider({
api: "google-generative-ai",
stream: streamGoogle,
streamSimple: streamSimpleGoogle,
});
registerApiProvider({
api: "google-generative-ai",
stream: streamGoogle,
streamSimple: streamSimpleGoogle,
});
registerApiProvider({
api: "google-gemini-cli",
stream: streamGoogleGeminiCli,
streamSimple: streamSimpleGoogleGeminiCli,
});
registerApiProvider({
api: "google-gemini-cli",
stream: streamGoogleGeminiCli,
streamSimple: streamSimpleGoogleGeminiCli,
});
registerApiProvider({
api: "google-vertex",
stream: streamGoogleVertex,
streamSimple: streamSimpleGoogleVertex,
});
registerApiProvider({
api: "google-vertex",
stream: streamGoogleVertex,
streamSimple: streamSimpleGoogleVertex,
});
registerApiProvider({
api: "bedrock-converse-stream",
stream: streamBedrock,
streamSimple: streamSimpleBedrock,
});
registerApiProvider({
api: "bedrock-converse-stream",
stream: streamBedrock,
streamSimple: streamSimpleBedrock,
});
}
/**
 * Resets the API provider registry to the built-in set.
 *
 * Calls `clearApiProviders()` first and then re-registers the built-in
 * providers, so anything registered in addition to the built-ins
 * (for example by an extension) is no longer present afterwards.
 */
export function resetApiProviders(): void {
clearApiProviders();
registerBuiltInApiProviders();
}
registerBuiltInApiProviders();

View file

@ -80,3 +80,8 @@ export class AssistantMessageEventStream extends EventStream<AssistantMessageEve
);
}
}
/**
 * Factory function for AssistantMessageEventStream (for use in extensions).
 *
 * Takes no arguments and returns a newly constructed stream on every call.
 *
 * @returns A new `AssistantMessageEventStream` instance.
 */
export function createAssistantMessageEventStream(): AssistantMessageEventStream {
return new AssistantMessageEventStream();
}

View file

@ -24,6 +24,8 @@
- Package deduplication: if same package appears in global and project settings, project wins ([#645](https://github.com/badlogic/pi-mono/issues/645))
- Unified collision reporting with `ResourceDiagnostic` type for all resource types ([#645](https://github.com/badlogic/pi-mono/issues/645))
- Show provider alongside the model in the footer if multiple providers are available
- Custom provider support via `pi.registerProvider()` with `streamSimple` for custom API implementations
- Added `custom-provider.ts` example extension demonstrating custom Anthropic provider with OAuth
### Fixed

View file

@ -0,0 +1,547 @@
# Custom Providers
Extensions can register custom model providers via `pi.registerProvider()`. This enables:
- **Proxies** - Route requests through corporate proxies or API gateways
- **Custom endpoints** - Use self-hosted or private model deployments
- **OAuth/SSO** - Add authentication flows for enterprise providers
- **Custom APIs** - Implement streaming for non-standard LLM APIs
## Table of Contents
- [Quick Reference](#quick-reference)
- [Override Existing Provider](#override-existing-provider)
- [Register New Provider](#register-new-provider)
- [OAuth Support](#oauth-support)
- [Custom Streaming API](#custom-streaming-api)
- [Config Reference](#config-reference)
- [Model Definition Reference](#model-definition-reference)
## Quick Reference
```typescript
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
export default function (pi: ExtensionAPI) {
// Override baseUrl for existing provider
pi.registerProvider("anthropic", {
baseUrl: "https://proxy.example.com"
});
// Register new provider with models
pi.registerProvider("my-provider", {
baseUrl: "https://api.example.com",
apiKey: "MY_API_KEY",
api: "openai-completions",
models: [
{
id: "my-model",
name: "My Model",
reasoning: false,
input: ["text", "image"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 128000,
maxTokens: 4096
}
]
});
}
```
## Override Existing Provider
The simplest use case: redirect an existing provider through a proxy.
```typescript
// All Anthropic requests now go through your proxy
pi.registerProvider("anthropic", {
baseUrl: "https://proxy.example.com"
});
// Add custom headers to OpenAI requests
pi.registerProvider("openai", {
headers: {
"X-Custom-Header": "value"
}
});
// Both baseUrl and headers
pi.registerProvider("google", {
baseUrl: "https://ai-gateway.corp.com/google",
headers: {
"X-Corp-Auth": "CORP_AUTH_TOKEN" // env var or literal
}
});
```
When only `baseUrl` and/or `headers` are provided (no `models`), all existing models for that provider are preserved with the new endpoint.
## Register New Provider
To add a completely new provider, specify `models` along with the required configuration.
```typescript
pi.registerProvider("my-llm", {
baseUrl: "https://api.my-llm.com/v1",
apiKey: "MY_LLM_API_KEY", // env var name or literal value
api: "openai-completions", // which streaming API to use
models: [
{
id: "my-llm-large",
name: "My LLM Large",
reasoning: true, // supports extended thinking
input: ["text", "image"],
cost: {
input: 3.0, // $/million tokens
output: 15.0,
cacheRead: 0.3,
cacheWrite: 3.75
},
contextWindow: 200000,
maxTokens: 16384
},
{
id: "my-llm-small",
name: "My LLM Small",
reasoning: false,
input: ["text"],
cost: { input: 0.25, output: 1.25, cacheRead: 0, cacheWrite: 0 },
contextWindow: 128000,
maxTokens: 8192
}
]
});
```
When `models` is provided, it **replaces** all existing models for that provider.
### API Types
The `api` field determines which streaming implementation is used:
| API | Use for |
|-----|---------|
| `anthropic-messages` | Anthropic Claude API and compatibles |
| `openai-completions` | OpenAI Chat Completions API and compatibles |
| `openai-responses` | OpenAI Responses API |
| `azure-openai-responses` | Azure OpenAI Responses API |
| `openai-codex-responses` | OpenAI Codex Responses API |
| `google-generative-ai` | Google Generative AI API |
| `google-gemini-cli` | Google Cloud Code Assist API |
| `google-vertex` | Google Vertex AI API |
| `bedrock-converse-stream` | Amazon Bedrock Converse API |
Most OpenAI-compatible providers work with `openai-completions`. Use `compat` for quirks:
```typescript
models: [{
id: "custom-model",
// ...
compat: {
supportsDeveloperRole: false, // use "system" instead of "developer"
supportsReasoningEffort: false, // disable reasoning_effort param
maxTokensField: "max_tokens", // instead of "max_completion_tokens"
requiresToolResultName: true, // tool results need name field
requiresMistralToolIds: true // tool IDs must be 9 alphanumeric chars
}
}]
```
### Auth Header
If your provider expects `Authorization: Bearer <key>` but doesn't use a standard API, set `authHeader: true`:
```typescript
pi.registerProvider("custom-api", {
baseUrl: "https://api.example.com",
apiKey: "MY_API_KEY",
authHeader: true, // adds Authorization: Bearer header
api: "openai-completions",
models: [...]
});
```
## OAuth Support
Add OAuth/SSO authentication that integrates with `/login`:
```typescript
import type { OAuthCredentials, OAuthLoginCallbacks } from "@mariozechner/pi-ai";
pi.registerProvider("corporate-ai", {
baseUrl: "https://ai.corp.com/v1",
api: "openai-responses",
models: [
{
id: "corp-claude",
name: "Corporate Claude",
reasoning: true,
input: ["text", "image"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 200000,
maxTokens: 16384
}
],
oauth: {
name: "Corporate AI (SSO)",
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
// Option 1: Browser-based OAuth
callbacks.onAuth({ url: "https://sso.corp.com/authorize?..." });
// Option 2: Device code flow
callbacks.onDeviceCode({
userCode: "ABCD-1234",
verificationUri: "https://sso.corp.com/device"
});
// Option 3: Prompt for token/code
const code = await callbacks.onPrompt({ message: "Enter SSO code:" });
// Exchange for tokens (your implementation)
const tokens = await exchangeCodeForTokens(code);
return {
refresh: tokens.refreshToken,
access: tokens.accessToken,
expires: Date.now() + tokens.expiresIn * 1000
};
},
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
const tokens = await refreshAccessToken(credentials.refresh);
return {
refresh: tokens.refreshToken ?? credentials.refresh,
access: tokens.accessToken,
expires: Date.now() + tokens.expiresIn * 1000
};
},
getApiKey(credentials: OAuthCredentials): string {
return credentials.access;
},
// Optional: modify models based on user's subscription
modifyModels(models, credentials) {
// e.g., update baseUrl based on user's region
const region = decodeRegionFromToken(credentials.access);
return models.map(m => ({
...m,
baseUrl: `https://${region}.ai.corp.com/v1`
}));
}
}
});
```
After registration, users can authenticate via `/login corporate-ai`.
### OAuthLoginCallbacks
The `callbacks` object provides three ways to authenticate:
```typescript
interface OAuthLoginCallbacks {
// Open URL in browser (for OAuth redirects)
onAuth(params: { url: string }): void;
// Show device code (for device authorization flow)
onDeviceCode(params: { userCode: string; verificationUri: string }): void;
// Prompt user for input (for manual token entry)
onPrompt(params: { message: string }): Promise<string>;
}
```
### OAuthCredentials
Credentials are persisted in `~/.pi/agent/auth.json`:
```typescript
interface OAuthCredentials {
refresh: string; // Refresh token (for refreshToken())
access: string; // Access token (returned by getApiKey())
expires: number; // Expiration timestamp in milliseconds
}
```
## Custom Streaming API
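For providers with non-standard APIs, implement `streamSimple`. It receives the model, the conversation context, and optional stream options, and must return an `AssistantMessageEventStream`. Note that `createAssistantMessageEventStream()` itself takes no arguments: create the stream, push events into it from an async task, and return it right away (see the `custom-provider.ts` example extension for a complete implementation). The snippet below sketches the overall event flow: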
```typescript
import type {
AssistantMessageEventStream,
Context,
Model,
SimpleStreamOptions,
Api
} from "@mariozechner/pi-ai";
import { createAssistantMessageEventStream } from "@mariozechner/pi-ai";
pi.registerProvider("custom-llm", {
baseUrl: "https://api.custom-llm.com",
apiKey: "CUSTOM_LLM_KEY",
api: "custom-llm-api", // your custom API identifier
models: [
{
id: "custom-model",
name: "Custom Model",
reasoning: false,
input: ["text"],
cost: { input: 1.0, output: 2.0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 32000,
maxTokens: 4096
}
],
streamSimple(
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions
): AssistantMessageEventStream {
return createAssistantMessageEventStream(async function* (signal) {
// Convert context to your API format
const messages = context.messages.map(m => ({
role: m.role,
content: typeof m.content === "string"
? m.content
: m.content.filter(c => c.type === "text").map(c => c.text).join("")
}));
// Make streaming request
const response = await fetch(`${model.baseUrl}/chat`, {
method: "POST",
headers: {
"Authorization": `Bearer ${options?.apiKey}`,
"Content-Type": "application/json"
},
body: JSON.stringify({
model: model.id,
messages,
stream: true
}),
signal
});
if (!response.ok) {
throw new Error(`API error: ${response.status}`);
}
// Yield start event
yield { type: "start" };
// Parse SSE stream
const reader = response.body!.getReader();
const decoder = new TextDecoder();
let buffer = "";
let contentIndex = 0;
let textStarted = false;
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n");
buffer = lines.pop() ?? "";
for (const line of lines) {
if (!line.startsWith("data: ")) continue;
const data = line.slice(6);
if (data === "[DONE]") continue;
const chunk = JSON.parse(data);
const delta = chunk.choices?.[0]?.delta?.content;
if (delta) {
if (!textStarted) {
yield { type: "text_start", contentIndex };
textStarted = true;
}
yield { type: "text_delta", contentIndex, delta };
}
}
}
if (textStarted) {
yield { type: "text_end", contentIndex };
}
// Yield usage if available
yield {
type: "usage",
usage: {
input: 0, // fill from response if available
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }
}
};
// Yield done
yield { type: "done", reason: "stop" };
});
}
});
```
### Event Types
Your generator must yield events in this order:
1. `{ type: "start" }` - Stream started
2. Content events (repeatable, in order):
- `{ type: "text_start", contentIndex }` - Text block started
- `{ type: "text_delta", contentIndex, delta }` - Text chunk
- `{ type: "text_end", contentIndex }` - Text block ended
- `{ type: "thinking_start", contentIndex }` - Thinking block started
- `{ type: "thinking_delta", contentIndex, delta }` - Thinking chunk
- `{ type: "thinking_end", contentIndex }` - Thinking block ended
- `{ type: "toolcall_start", contentIndex }` - Tool call started
- `{ type: "toolcall_delta", contentIndex, delta }` - Tool call JSON chunk
- `{ type: "toolcall_end", contentIndex, toolCall }` - Tool call ended
3. `{ type: "usage", usage }` - Token usage (optional but recommended)
4. `{ type: "done", reason }` or `{ type: "error", error }` - Stream ended
### Reasoning Support
For models with extended thinking, yield thinking events:
```typescript
if (chunk.thinking) {
if (!thinkingStarted) {
yield { type: "thinking_start", contentIndex: thinkingIndex };
thinkingStarted = true;
}
yield { type: "thinking_delta", contentIndex: thinkingIndex, delta: chunk.thinking };
}
```
### Tool Calls
For function calling support, yield tool call events:
```typescript
if (chunk.tool_calls) {
for (const tc of chunk.tool_calls) {
if (tc.index !== currentToolIndex) {
if (currentToolIndex >= 0) {
yield {
type: "toolcall_end",
contentIndex: currentToolIndex,
toolCall: {
type: "toolCall",
id: currentToolId,
name: currentToolName,
arguments: JSON.parse(currentToolArgs)
}
};
}
currentToolIndex = tc.index;
currentToolId = tc.id;
currentToolName = tc.function.name;
currentToolArgs = "";
yield { type: "toolcall_start", contentIndex: tc.index };
}
if (tc.function.arguments) {
currentToolArgs += tc.function.arguments;
yield { type: "toolcall_delta", contentIndex: tc.index, delta: tc.function.arguments };
}
}
}
```
## Config Reference
```typescript
interface ProviderConfig {
/** API endpoint URL. Required when defining models. */
baseUrl?: string;
/** API key or environment variable name. Required when defining models (unless oauth). */
apiKey?: string;
/** API type for streaming. Required at provider or model level when defining models. */
api?: Api;
/** Custom streaming implementation for non-standard APIs. */
streamSimple?: (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions
) => AssistantMessageEventStream;
/** Custom headers to include in requests. Values can be env var names. */
headers?: Record<string, string>;
/** If true, adds Authorization: Bearer header with the resolved API key. */
authHeader?: boolean;
/** Models to register. If provided, replaces all existing models for this provider. */
models?: ProviderModelConfig[];
/** OAuth provider for /login support. */
oauth?: {
name: string;
login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
getApiKey(credentials: OAuthCredentials): string;
modifyModels?(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[];
};
}
```
## Model Definition Reference
```typescript
interface ProviderModelConfig {
/** Model ID (e.g., "claude-sonnet-4-20250514"). */
id: string;
/** Display name (e.g., "Claude 4 Sonnet"). */
name: string;
/** API type override for this specific model. */
api?: Api;
/** Whether the model supports extended thinking. */
reasoning: boolean;
/** Supported input types. */
input: ("text" | "image")[];
/** Cost per million tokens (for usage tracking). */
cost: {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
};
/** Maximum context window size in tokens. */
contextWindow: number;
/** Maximum output tokens. */
maxTokens: number;
/** Custom headers for this specific model. */
headers?: Record<string, string>;
/** OpenAI compatibility settings for openai-completions API. */
compat?: {
supportsStore?: boolean;
supportsDeveloperRole?: boolean;
supportsReasoningEffort?: boolean;
supportsUsageInStreaming?: boolean;
maxTokensField?: "max_completion_tokens" | "max_tokens";
requiresToolResultName?: boolean;
requiresAssistantAfterToolResult?: boolean;
requiresThinkingAsText?: boolean;
requiresMistralToolIds?: boolean;
thinkingFormat?: "openai" | "zai";
};
}
```

View file

@ -1206,6 +1206,9 @@ pi.registerProvider("corporate-ai", {
- `authHeader` - If true, adds `Authorization: Bearer` header automatically.
- `models` - Array of model definitions. If provided, replaces all existing models for this provider.
- `oauth` - OAuth provider config for `/login` support. When provided, the provider appears in the login menu.
- `streamSimple` - Custom streaming implementation for non-standard APIs.
See [custom-provider.md](custom-provider.md) for advanced topics: custom streaming APIs, OAuth details, model definition reference.
## State Management

View file

@ -0,0 +1,601 @@
/**
* Custom Provider Example
*
* Demonstrates registering a custom provider with:
* - Custom API identifier ("custom-anthropic-api")
* - Custom streamSimple implementation
* - OAuth support for /login
* - API key support via environment variable
* - Two model definitions
*
* Usage:
* # With OAuth (run /login custom-anthropic first)
* pi -e ./custom-provider.ts
*
* # With API key
* CUSTOM_ANTHROPIC_API_KEY=sk-ant-... pi -e ./custom-provider.ts
*
* Then use /model to select custom-anthropic/claude-sonnet-4-5
*/
import Anthropic from "@anthropic-ai/sdk";
import type { ContentBlockParam, MessageCreateParamsStreaming } from "@anthropic-ai/sdk/resources/messages.js";
import {
type Api,
type AssistantMessage,
type AssistantMessageEventStream,
type Context,
calculateCost,
createAssistantMessageEventStream,
type ImageContent,
type Message,
type Model,
type OAuthCredentials,
type OAuthLoginCallbacks,
type SimpleStreamOptions,
type StopReason,
type TextContent,
type ThinkingContent,
type Tool,
type ToolCall,
type ToolResultMessage,
} from "@mariozechner/pi-ai";
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
// =============================================================================
// OAuth Implementation (copied from packages/ai/src/utils/oauth/anthropic.ts)
// =============================================================================
// Decodes a base64 string via atob().
const decode = (s: string) => atob(s);
// OAuth client ID; stored base64-encoded and decoded once at module load.
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
// URL the user is sent to in order to authorize (see loginAnthropic).
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
// Endpoint used for both the authorization-code exchange and token refresh.
const TOKEN_URL = "https://console.anthropic.com/v1/oauth/token";
// Redirect URI sent with the authorize request and the code exchange.
const REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback";
// Space-separated OAuth scopes requested during authorization.
const SCOPES = "org:create_api_key user:profile user:inference";
/**
 * Creates a PKCE verifier/challenge pair.
 *
 * The verifier is 32 random bytes encoded as unpadded base64url. The challenge
 * is the SHA-256 digest of the verifier string, encoded the same way.
 *
 * @returns The `verifier` and its matching `challenge`.
 */
async function generatePKCE(): Promise<{ verifier: string; challenge: string }> {
	// base64 -> base64url: swap "+" and "/" for URL-safe characters, drop "=" padding.
	const toBase64Url = (bytes: Uint8Array): string =>
		btoa(String.fromCharCode(...bytes))
			.replace(/\+/g, "-")
			.replace(/\//g, "_")
			.replace(/=+$/, "");

	const randomBytes = new Uint8Array(32);
	crypto.getRandomValues(randomBytes);
	const verifier = toBase64Url(randomBytes);

	const verifierBytes = new TextEncoder().encode(verifier);
	const digest = await crypto.subtle.digest("SHA-256", verifierBytes);
	const challenge = toBase64Url(new Uint8Array(digest));

	return { verifier, challenge };
}
/**
 * Runs the interactive OAuth login (authorization code with PKCE).
 *
 * Steps:
 * 1. Build the authorize URL and hand it to `callbacks.onAuth`.
 * 2. Ask the user to paste the authorization code via `callbacks.onPrompt`.
 *    The pasted value is split on "#" into `code` and `state`.
 * 3. Exchange the code at the token endpoint.
 *
 * @param callbacks Login callbacks used to show the URL and prompt for the code.
 * @returns Credentials whose `expires` is five minutes earlier than the
 *          lifetime reported by the token endpoint.
 * @throws Error if the token endpoint answers with a non-OK status.
 */
async function loginAnthropic(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
	const pkce = await generatePKCE();

	const query = new URLSearchParams({
		code: "true",
		client_id: CLIENT_ID,
		response_type: "code",
		redirect_uri: REDIRECT_URI,
		scope: SCOPES,
		code_challenge: pkce.challenge,
		code_challenge_method: "S256",
		// The PKCE verifier doubles as the OAuth state value.
		state: pkce.verifier,
	}).toString();
	callbacks.onAuth({ url: `${AUTHORIZE_URL}?${query}` });

	const pasted = await callbacks.onPrompt({ message: "Paste the authorization code:" });
	const parts = pasted.split("#");

	const exchangeBody = JSON.stringify({
		grant_type: "authorization_code",
		client_id: CLIENT_ID,
		code: parts[0],
		state: parts[1],
		redirect_uri: REDIRECT_URI,
		code_verifier: pkce.verifier,
	});
	const res = await fetch(TOKEN_URL, {
		method: "POST",
		headers: { "Content-Type": "application/json" },
		body: exchangeBody,
	});
	if (!res.ok) {
		const detail = await res.text();
		throw new Error(`Token exchange failed: ${detail}`);
	}

	const tokens = (await res.json()) as {
		access_token: string;
		refresh_token: string;
		expires_in: number;
	};
	// Treat the token as expired five minutes early.
	const earlyExpiryMs = 5 * 60 * 1000;
	return {
		refresh: tokens.refresh_token,
		access: tokens.access_token,
		expires: Date.now() + tokens.expires_in * 1000 - earlyExpiryMs,
	};
}
/**
 * Obtains fresh credentials from the token endpoint using the stored refresh token.
 *
 * @param credentials Current credentials; only `refresh` is read.
 * @returns New credentials whose `expires` is five minutes earlier than the
 *          lifetime reported by the token endpoint.
 * @throws Error if the token endpoint answers with a non-OK status.
 */
async function refreshAnthropicToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
	const requestBody = JSON.stringify({
		grant_type: "refresh_token",
		client_id: CLIENT_ID,
		refresh_token: credentials.refresh,
	});
	const res = await fetch(TOKEN_URL, {
		method: "POST",
		headers: { "Content-Type": "application/json" },
		body: requestBody,
	});
	if (!res.ok) {
		const detail = await res.text();
		throw new Error(`Token refresh failed: ${detail}`);
	}

	const tokens = (await res.json()) as {
		access_token: string;
		refresh_token: string;
		expires_in: number;
	};
	// Treat the token as expired five minutes early.
	const earlyExpiryMs = 5 * 60 * 1000;
	return {
		refresh: tokens.refresh_token,
		access: tokens.access_token,
		expires: Date.now() + tokens.expires_in * 1000 - earlyExpiryMs,
	};
}
// =============================================================================
// Streaming Implementation (simplified from packages/ai/src/providers/anthropic.ts)
// =============================================================================
// Claude Code tool names for OAuth stealth mode
const claudeCodeTools = [
	"Read",
	"Write",
	"Edit",
	"Bash",
	"Grep",
	"Glob",
	"AskUserQuestion",
	"TodoWrite",
	"WebFetch",
	"WebSearch",
];

// Lower-cased tool name -> canonical casing from the list above.
const ccToolLookup = new Map<string, string>();
for (const canonicalName of claudeCodeTools) {
	ccToolLookup.set(canonicalName.toLowerCase(), canonicalName);
}

/**
 * Maps a tool name to the canonical casing in `claudeCodeTools`
 * (case-insensitive match). Names not in the list are returned unchanged.
 */
const toClaudeCodeName = (name: string) => {
	const canonicalName = ccToolLookup.get(name.toLowerCase());
	return canonicalName ?? name;
};
/**
 * Maps a tool name back to the name used by the caller's own tool list.
 *
 * Returns the `name` of the first tool that matches case-insensitively,
 * or the input unchanged when `tools` is missing or nothing matches.
 */
const fromClaudeCodeName = (name: string, tools?: Tool[]) => {
	const wanted = name.toLowerCase();
	for (const tool of tools ?? []) {
		if (tool.name.toLowerCase() === wanted) {
			return tool.name;
		}
	}
	return name;
};
/**
 * Reports whether the key contains the "sk-ant-oat" marker anywhere in the string.
 *
 * @param apiKey The credential string to inspect.
 * @returns `true` if the marker is present, otherwise `false`.
 */
function isOAuthToken(apiKey: string): boolean {
	const oauthMarker = "sk-ant-oat";
	return apiKey.indexOf(oauthMarker) !== -1;
}
/**
 * Replaces unpaired UTF-16 surrogates with U+FFFD (the replacement character).
 *
 * Valid surrogate pairs (emoji and other astral-plane characters) are kept
 * intact. Only a high surrogate that is not followed by a low surrogate, or a
 * low surrogate that is not preceded by a high surrogate, is replaced.
 *
 * @param text Input text, possibly containing lone surrogates.
 * @returns The text with each lone surrogate code unit replaced by U+FFFD.
 */
function sanitizeSurrogates(text: string): string {
	// No `u` flag on purpose: the pattern must match individual UTF-16 code units.
	const loneSurrogate = /[\uD800-\uDBFF](?![\uDC00-\uDFFF])|(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]/g;
	return text.replace(loneSurrogate, "\uFFFD");
}
/**
 * Converts text/image content into either a plain string or a block array.
 *
 * - With no image present, all text parts are joined with "\n" and returned
 *   as one sanitized string.
 * - Otherwise each part becomes a `text` block or a base64 `image` block.
 *   If the result has no text block, a "(see attached image)" text block is
 *   inserted at the front.
 *
 * @param content Text and image parts to convert.
 * @returns A sanitized string, or an array of text/image blocks.
 */
function convertContentBlocks(
	content: (TextContent | ImageContent)[],
): string | Array<{ type: "text"; text: string } | { type: "image"; source: any }> {
	const textOnly = content.every((part) => part.type !== "image");
	if (textOnly) {
		const joined = content.map((part) => (part as TextContent).text).join("\n");
		return sanitizeSurrogates(joined);
	}

	const converted: Array<{ type: "text"; text: string } | { type: "image"; source: any }> = [];
	let sawText = false;
	for (const part of content) {
		if (part.type === "text") {
			sawText = true;
			converted.push({ type: "text" as const, text: sanitizeSurrogates(part.text) });
		} else {
			converted.push({
				type: "image" as const,
				source: {
					type: "base64" as const,
					media_type: part.mimeType,
					data: part.data,
				},
			});
		}
	}

	if (!sawText) {
		converted.unshift({ type: "text" as const, text: "(see attached image)" });
	}
	return converted;
}
/**
 * Converts conversation messages into an array of `{ role, content }` request messages.
 *
 * - user: string content is sanitized and skipped when blank; array content
 *   becomes text/image blocks.
 * - assistant: non-blank text, thinking, and tool calls become content blocks;
 *   an assistant message that yields no blocks is dropped.
 * - toolResult: a run of consecutive tool results is merged into one "user"
 *   message holding several `tool_result` blocks.
 *
 * Finally, `cache_control: { type: "ephemeral" }` is set on the last block of
 * the last message when that message has role "user" and array content.
 *
 * @param messages Conversation messages to convert.
 * @param isOAuth When true, tool-call names are passed through `toClaudeCodeName`.
 * @param _tools Unused in this function.
 * @returns The converted message array.
 */
function convertMessages(messages: Message[], isOAuth: boolean, _tools?: Tool[]): any[] {
const params: any[] = [];
for (let i = 0; i < messages.length; i++) {
const msg = messages[i];
if (msg.role === "user") {
if (typeof msg.content === "string") {
// Whitespace-only user text is dropped entirely.
if (msg.content.trim()) {
params.push({ role: "user", content: sanitizeSurrogates(msg.content) });
}
} else {
const blocks: ContentBlockParam[] = msg.content.map((item) =>
item.type === "text"
? { type: "text" as const, text: sanitizeSurrogates(item.text) }
: {
type: "image" as const,
source: { type: "base64" as const, media_type: item.mimeType as any, data: item.data },
},
);
if (blocks.length > 0) {
params.push({ role: "user", content: blocks });
}
}
} else if (msg.role === "assistant") {
const blocks: ContentBlockParam[] = [];
for (const block of msg.content) {
if (block.type === "text" && block.text.trim()) {
blocks.push({ type: "text", text: sanitizeSurrogates(block.text) });
} else if (block.type === "thinking" && block.thinking.trim()) {
// Thinking with a signature is sent back as a thinking block;
// without one it is downgraded to a plain text block.
if ((block as ThinkingContent).thinkingSignature) {
blocks.push({
type: "thinking" as any,
thinking: sanitizeSurrogates(block.thinking),
signature: (block as ThinkingContent).thinkingSignature!,
});
} else {
blocks.push({ type: "text", text: sanitizeSurrogates(block.thinking) });
}
} else if (block.type === "toolCall") {
blocks.push({
type: "tool_use",
id: block.id,
name: isOAuth ? toClaudeCodeName(block.name) : block.name,
input: block.arguments,
});
}
}
if (blocks.length > 0) {
params.push({ role: "assistant", content: blocks });
}
} else if (msg.role === "toolResult") {
const toolResults: any[] = [];
toolResults.push({
type: "tool_result",
tool_use_id: msg.toolCallId,
content: convertContentBlocks(msg.content),
is_error: msg.isError,
});
// Fold every directly following toolResult message into the same user message.
let j = i + 1;
while (j < messages.length && messages[j].role === "toolResult") {
const nextMsg = messages[j] as ToolResultMessage;
toolResults.push({
type: "tool_result",
tool_use_id: nextMsg.toolCallId,
content: convertContentBlocks(nextMsg.content),
is_error: nextMsg.isError,
});
j++;
}
// Skip the messages consumed above; the loop's i++ then lands on index j.
i = j - 1;
params.push({ role: "user", content: toolResults });
}
}
// Add cache control to last user message
if (params.length > 0) {
const last = params[params.length - 1];
// Only applies when the final message is a user message with block-array content.
if (last.role === "user" && Array.isArray(last.content)) {
const lastBlock = last.content[last.content.length - 1];
if (lastBlock) {
lastBlock.cache_control = { type: "ephemeral" };
}
}
}
return params;
}
/**
 * Converts tool definitions into objects with `name`, `description`, and `input_schema`.
 *
 * `input_schema` is always `type: "object"`. Its `properties` and `required`
 * are taken from the tool's parameters, falling back to `{}` and `[]`
 * when they are missing or falsy.
 *
 * @param tools Tool definitions to convert.
 * @param isOAuth When true, tool names are passed through `toClaudeCodeName`.
 * @returns One converted object per tool, in the same order.
 */
function convertTools(tools: Tool[], isOAuth: boolean): any[] {
	return tools.map((tool) => {
		const name = isOAuth ? toClaudeCodeName(tool.name) : tool.name;
		const description = tool.description;
		const schema = tool.parameters as any;
		return {
			name,
			description,
			input_schema: {
				type: "object",
				properties: schema.properties || {},
				required: schema.required || [],
			},
		};
	});
}
/**
 * Maps a provider stop-reason string to a `StopReason`.
 *
 * - "end_turn", "pause_turn", "stop_sequence" -> "stop"
 * - "max_tokens" -> "length"
 * - "tool_use" -> "toolUse"
 * - anything else -> "error"
 *
 * @param reason Stop reason as reported by the provider.
 * @returns The corresponding `StopReason`.
 */
function mapStopReason(reason: string): StopReason {
	if (reason === "end_turn" || reason === "pause_turn" || reason === "stop_sequence") {
		return "stop";
	}
	if (reason === "max_tokens") {
		return "length";
	}
	if (reason === "tool_use") {
		return "toolUse";
	}
	return "error";
}
function streamCustomAnthropic(
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream {
const stream = createAssistantMessageEventStream();
(async () => {
const output: AssistantMessage = {
role: "assistant",
content: [],
api: model.api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
try {
const apiKey = options?.apiKey ?? "";
const isOAuth = isOAuthToken(apiKey);
// Configure client based on auth type
const betaFeatures = ["fine-grained-tool-streaming-2025-05-14", "interleaved-thinking-2025-05-14"];
const clientOptions: any = {
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
};
if (isOAuth) {
clientOptions.apiKey = null;
clientOptions.authToken = apiKey;
clientOptions.defaultHeaders = {
accept: "application/json",
"anthropic-dangerous-direct-browser-access": "true",
"anthropic-beta": `claude-code-20250219,oauth-2025-04-20,${betaFeatures.join(",")}`,
"user-agent": "claude-cli/2.1.2 (external, cli)",
"x-app": "cli",
};
} else {
clientOptions.apiKey = apiKey;
clientOptions.defaultHeaders = {
accept: "application/json",
"anthropic-dangerous-direct-browser-access": "true",
"anthropic-beta": betaFeatures.join(","),
};
}
const client = new Anthropic(clientOptions);
// Build request params
const params: MessageCreateParamsStreaming = {
model: model.id,
messages: convertMessages(context.messages, isOAuth, context.tools),
max_tokens: options?.maxTokens || Math.floor(model.maxTokens / 3),
stream: true,
};
// System prompt with Claude Code identity for OAuth
if (isOAuth) {
params.system = [
{
type: "text",
text: "You are Claude Code, Anthropic's official CLI for Claude.",
cache_control: { type: "ephemeral" },
},
];
if (context.systemPrompt) {
params.system.push({
type: "text",
text: sanitizeSurrogates(context.systemPrompt),
cache_control: { type: "ephemeral" },
});
}
} else if (context.systemPrompt) {
params.system = [
{
type: "text",
text: sanitizeSurrogates(context.systemPrompt),
cache_control: { type: "ephemeral" },
},
];
}
if (context.tools) {
params.tools = convertTools(context.tools, isOAuth);
}
// Handle thinking/reasoning
if (options?.reasoning && model.reasoning) {
const defaultBudgets: Record<string, number> = {
minimal: 1024,
low: 4096,
medium: 10240,
high: 20480,
};
const customBudget = options.thinkingBudgets?.[options.reasoning as keyof typeof options.thinkingBudgets];
params.thinking = {
type: "enabled",
budget_tokens: customBudget ?? defaultBudgets[options.reasoning] ?? 10240,
};
}
const anthropicStream = client.messages.stream({ ...params }, { signal: options?.signal });
stream.push({ type: "start", partial: output });
type Block = (ThinkingContent | TextContent | (ToolCall & { partialJson: string })) & { index: number };
const blocks = output.content as Block[];
for await (const event of anthropicStream) {
if (event.type === "message_start") {
output.usage.input = event.message.usage.input_tokens || 0;
output.usage.output = event.message.usage.output_tokens || 0;
output.usage.cacheRead = (event.message.usage as any).cache_read_input_tokens || 0;
output.usage.cacheWrite = (event.message.usage as any).cache_creation_input_tokens || 0;
output.usage.totalTokens =
output.usage.input + output.usage.output + output.usage.cacheRead + output.usage.cacheWrite;
calculateCost(model, output.usage);
} else if (event.type === "content_block_start") {
if (event.content_block.type === "text") {
output.content.push({ type: "text", text: "", index: event.index } as any);
stream.push({ type: "text_start", contentIndex: output.content.length - 1, partial: output });
} else if (event.content_block.type === "thinking") {
output.content.push({
type: "thinking",
thinking: "",
thinkingSignature: "",
index: event.index,
} as any);
stream.push({ type: "thinking_start", contentIndex: output.content.length - 1, partial: output });
} else if (event.content_block.type === "tool_use") {
output.content.push({
type: "toolCall",
id: event.content_block.id,
name: isOAuth
? fromClaudeCodeName(event.content_block.name, context.tools)
: event.content_block.name,
arguments: {},
partialJson: "",
index: event.index,
} as any);
stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output });
}
} else if (event.type === "content_block_delta") {
const index = blocks.findIndex((b) => b.index === event.index);
const block = blocks[index];
if (!block) continue;
if (event.delta.type === "text_delta" && block.type === "text") {
block.text += event.delta.text;
stream.push({ type: "text_delta", contentIndex: index, delta: event.delta.text, partial: output });
} else if (event.delta.type === "thinking_delta" && block.type === "thinking") {
block.thinking += event.delta.thinking;
stream.push({
type: "thinking_delta",
contentIndex: index,
delta: event.delta.thinking,
partial: output,
});
} else if (event.delta.type === "input_json_delta" && block.type === "toolCall") {
(block as any).partialJson += event.delta.partial_json;
try {
block.arguments = JSON.parse((block as any).partialJson);
} catch {}
stream.push({
type: "toolcall_delta",
contentIndex: index,
delta: event.delta.partial_json,
partial: output,
});
} else if (event.delta.type === "signature_delta" && block.type === "thinking") {
block.thinkingSignature = (block.thinkingSignature || "") + (event.delta as any).signature;
}
} else if (event.type === "content_block_stop") {
const index = blocks.findIndex((b) => b.index === event.index);
const block = blocks[index];
if (!block) continue;
delete (block as any).index;
if (block.type === "text") {
stream.push({ type: "text_end", contentIndex: index, content: block.text, partial: output });
} else if (block.type === "thinking") {
stream.push({ type: "thinking_end", contentIndex: index, content: block.thinking, partial: output });
} else if (block.type === "toolCall") {
try {
block.arguments = JSON.parse((block as any).partialJson);
} catch {}
delete (block as any).partialJson;
stream.push({ type: "toolcall_end", contentIndex: index, toolCall: block, partial: output });
}
} else if (event.type === "message_delta") {
if ((event.delta as any).stop_reason) {
output.stopReason = mapStopReason((event.delta as any).stop_reason);
}
output.usage.input = (event.usage as any).input_tokens || 0;
output.usage.output = (event.usage as any).output_tokens || 0;
output.usage.cacheRead = (event.usage as any).cache_read_input_tokens || 0;
output.usage.cacheWrite = (event.usage as any).cache_creation_input_tokens || 0;
output.usage.totalTokens =
output.usage.input + output.usage.output + output.usage.cacheRead + output.usage.cacheWrite;
calculateCost(model, output.usage);
}
}
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
stream.push({ type: "done", reason: output.stopReason as "stop" | "length" | "toolUse", message: output });
stream.end();
} catch (error) {
for (const block of output.content) delete (block as any).index;
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", reason: output.stopReason, error: output });
stream.end();
}
})();
return stream;
}
// =============================================================================
// Extension Entry Point
// =============================================================================
/**
 * Extension entry point: registers the "custom-anthropic" provider.
 *
 * A single `registerProvider` call bundles two model definitions, an OAuth
 * configuration (login / token refresh / API-key extraction) and the
 * `streamCustomAnthropic` handler as `streamSimple`, all bound to the custom
 * API id "custom-anthropic-api".
 *
 * @param pi Extension API handle on which the provider is registered.
 */
export default function (pi: ExtensionAPI) {
	const providerId = "custom-anthropic";
	// Custom API identifier the streamSimple handler is registered under.
	const apiId = "custom-anthropic-api";

	// The config literal is kept inline so its properties (e.g. `input`,
	// `getApiKey`) are typed contextually by `registerProvider`.
	pi.registerProvider(providerId, {
		baseUrl: "https://api.anthropic.com",
		// NOTE(review): presumably the name of an environment variable holding the key — confirm.
		apiKey: "CUSTOM_ANTHROPIC_API_KEY",
		api: apiId,
		models: [
			{
				id: "claude-opus-4-5",
				name: "Claude Opus 4.5 (Custom)",
				reasoning: true,
				input: ["text", "image"],
				cost: { input: 5, output: 25, cacheRead: 0.5, cacheWrite: 6.25 },
				contextWindow: 200000,
				maxTokens: 64000,
			},
			{
				id: "claude-sonnet-4-5",
				name: "Claude Sonnet 4.5 (Custom)",
				reasoning: true,
				input: ["text", "image"],
				cost: { input: 3, output: 15, cacheRead: 0.3, cacheWrite: 3.75 },
				contextWindow: 200000,
				maxTokens: 64000,
			},
		],
		oauth: {
			name: "Custom Anthropic (Claude Pro/Max)",
			login: loginAnthropic,
			refreshToken: refreshAnthropicToken,
			// The OAuth access token doubles as the API key.
			getApiKey: (credentials) => credentials.access,
		},
		streamSimple: streamCustomAnthropic,
	});
}

View file

@ -23,7 +23,7 @@ import type {
ThinkingLevel,
} from "@mariozechner/pi-agent-core";
import type { AssistantMessage, ImageContent, Message, Model, TextContent } from "@mariozechner/pi-ai";
import { isContextOverflow, modelsAreEqual, supportsXhigh } from "@mariozechner/pi-ai";
import { isContextOverflow, modelsAreEqual, resetApiProviders, supportsXhigh } from "@mariozechner/pi-ai";
import { getAuthPath } from "../config.js";
import { theme } from "../modes/interactive/theme/theme.js";
import { stripFrontmatter } from "../utils/frontmatter.js";
@ -1832,6 +1832,7 @@ export class AgentSession {
async reload(): Promise<void> {
const previousFlagValues = this._extensionRunner?.getFlagValues();
await this._extensionRunner?.emit({ type: "session_shutdown" });
resetApiProviders();
await this._resourceLoader.reload();
this._buildRuntime({
activeToolNames: this.getActiveToolNames(),

View file

@ -16,10 +16,13 @@ import type {
} from "@mariozechner/pi-agent-core";
import type {
Api,
AssistantMessageEventStream,
Context,
ImageContent,
Model,
OAuthCredentials,
OAuthLoginCallbacks,
SimpleStreamOptions,
TextContent,
ToolResultMessage,
} from "@mariozechner/pi-ai";
@ -872,6 +875,7 @@ export interface ExtensionAPI {
* If `models` is provided: replaces all existing models for this provider.
* If only `baseUrl` is provided: overrides the URL for existing models.
* If `oauth` is provided: registers OAuth provider for /login support.
* If `streamSimple` is provided: registers a custom API stream handler.
*
* @example
* // Register a new provider with custom models
@ -930,6 +934,8 @@ export interface ProviderConfig {
apiKey?: string;
/** API type. Required at provider or model level when defining models. */
api?: Api;
/** Optional streamSimple handler for custom APIs. */
streamSimple?: (model: Model<Api>, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream;
/** Custom headers to include in requests. */
headers?: Record<string, string>;
/** If true, adds Authorization: Bearer header with the resolved API key. */

View file

@ -4,12 +4,16 @@
import {
type Api,
type AssistantMessageEventStream,
type Context,
getModels,
getProviders,
type KnownProvider,
type Model,
type OAuthProviderInterface,
registerApiProvider,
registerOAuthProvider,
type SimpleStreamOptions,
} from "@mariozechner/pi-ai";
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
@ -45,17 +49,7 @@ const OpenAICompatSchema = Type.Union([OpenAICompletionsCompatSchema, OpenAIResp
const ModelDefinitionSchema = Type.Object({
id: Type.String({ minLength: 1 }),
name: Type.String({ minLength: 1 }),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
Type.Literal("openai-responses"),
Type.Literal("azure-openai-responses"),
Type.Literal("openai-codex-responses"),
Type.Literal("anthropic-messages"),
Type.Literal("google-generative-ai"),
Type.Literal("bedrock-converse-stream"),
]),
),
api: Type.Optional(Type.String({ minLength: 1 })),
reasoning: Type.Boolean(),
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
cost: Type.Object({
@ -73,17 +67,7 @@ const ModelDefinitionSchema = Type.Object({
const ProviderConfigSchema = Type.Object({
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
apiKey: Type.Optional(Type.String({ minLength: 1 })),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
Type.Literal("openai-responses"),
Type.Literal("azure-openai-responses"),
Type.Literal("openai-codex-responses"),
Type.Literal("anthropic-messages"),
Type.Literal("google-generative-ai"),
Type.Literal("bedrock-converse-stream"),
]),
),
api: Type.Optional(Type.String({ minLength: 1 })),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
authHeader: Type.Optional(Type.Boolean()),
models: Type.Optional(Type.Array(ModelDefinitionSchema)),
@ -482,6 +466,18 @@ export class ModelRegistry {
registerOAuthProvider(oauthProvider);
}
if (config.streamSimple) {
if (!config.api) {
throw new Error(`Provider ${providerName}: "api" is required when registering streamSimple.`);
}
const streamSimple = config.streamSimple;
registerApiProvider({
api: config.api,
stream: (model, context, options) => streamSimple(model, context, options as SimpleStreamOptions),
streamSimple,
});
}
// Store API key for auth resolution
if (config.apiKey) {
this.customProviderApiKeys.set(providerName, config.apiKey);
@ -556,6 +552,7 @@ export interface ProviderConfigInput {
baseUrl?: string;
apiKey?: string;
api?: Api;
streamSimple?: (model: Model<Api>, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream;
headers?: Record<string, string>;
authHeader?: boolean;
/** OAuth provider for /login support */