From 177c694406e0f722ca8121fd9abcd01475a57463 Mon Sep 17 00:00:00 2001 From: Mario Zechner Date: Sat, 24 Jan 2026 23:11:20 +0100 Subject: [PATCH] feat: custom provider support with streamSimple - Add resetApiProviders() to clear and re-register built-in providers - Add createAssistantMessageEventStream() factory for extensions - Add streamSimple support in ProviderConfig for custom API implementations - Call resetApiProviders() on /reload to clean up extension providers - Add custom-provider.md documentation - Add custom-provider.ts example with full Anthropic implementation - Update extensions.md with streamSimple config option --- packages/ai/CHANGELOG.md | 2 + packages/ai/src/index.ts | 1 + .../ai/src/providers/register-builtins.ts | 101 +-- packages/ai/src/utils/event-stream.ts | 5 + packages/coding-agent/CHANGELOG.md | 2 + packages/coding-agent/docs/custom-provider.md | 547 ++++++++++++++++ packages/coding-agent/docs/extensions.md | 3 + .../examples/extensions/custom-provider.ts | 601 ++++++++++++++++++ .../coding-agent/src/core/agent-session.ts | 3 +- .../coding-agent/src/core/extensions/types.ts | 6 + .../coding-agent/src/core/model-registry.ts | 41 +- 11 files changed, 1243 insertions(+), 69 deletions(-) create mode 100644 packages/coding-agent/docs/custom-provider.md create mode 100644 packages/coding-agent/examples/extensions/custom-provider.ts diff --git a/packages/ai/CHANGELOG.md b/packages/ai/CHANGELOG.md index 6aaa7bbb..17c1768c 100644 --- a/packages/ai/CHANGELOG.md +++ b/packages/ai/CHANGELOG.md @@ -5,6 +5,8 @@ ### Added - Added `azure-openai-responses` provider support for Azure OpenAI Responses API. ([#890](https://github.com/badlogic/pi-mono/pull/890) by [@markusylisiurunen](https://github.com/markusylisiurunen)) +- Added `createAssistantMessageEventStream()` factory function for use in extensions. +- Added `resetApiProviders()` to clear and re-register built-in API providers. 
### Changed diff --git a/packages/ai/src/index.ts b/packages/ai/src/index.ts index d600021a..46e1c470 100644 --- a/packages/ai/src/index.ts +++ b/packages/ai/src/index.ts @@ -8,6 +8,7 @@ export * from "./providers/google-gemini-cli.js"; export * from "./providers/google-vertex.js"; export * from "./providers/openai-completions.js"; export * from "./providers/openai-responses.js"; +export * from "./providers/register-builtins.js"; export * from "./stream.js"; export * from "./types.js"; export * from "./utils/event-stream.js"; diff --git a/packages/ai/src/providers/register-builtins.ts b/packages/ai/src/providers/register-builtins.ts index 3ad51add..1645f21f 100644 --- a/packages/ai/src/providers/register-builtins.ts +++ b/packages/ai/src/providers/register-builtins.ts @@ -1,4 +1,4 @@ -import { registerApiProvider } from "../api-registry.js"; +import { clearApiProviders, registerApiProvider } from "../api-registry.js"; import { streamBedrock, streamSimpleBedrock } from "./amazon-bedrock.js"; import { streamAnthropic, streamSimpleAnthropic } from "./anthropic.js"; import { streamAzureOpenAIResponses, streamSimpleAzureOpenAIResponses } from "./azure-openai-responses.js"; @@ -9,56 +9,65 @@ import { streamOpenAICodexResponses, streamSimpleOpenAICodexResponses } from "./ import { streamOpenAICompletions, streamSimpleOpenAICompletions } from "./openai-completions.js"; import { streamOpenAIResponses, streamSimpleOpenAIResponses } from "./openai-responses.js"; -registerApiProvider({ - api: "anthropic-messages", - stream: streamAnthropic, - streamSimple: streamSimpleAnthropic, -}); +export function registerBuiltInApiProviders(): void { + registerApiProvider({ + api: "anthropic-messages", + stream: streamAnthropic, + streamSimple: streamSimpleAnthropic, + }); -registerApiProvider({ - api: "openai-completions", - stream: streamOpenAICompletions, - streamSimple: streamSimpleOpenAICompletions, -}); + registerApiProvider({ + api: "openai-completions", + stream: 
streamOpenAICompletions, + streamSimple: streamSimpleOpenAICompletions, + }); -registerApiProvider({ - api: "openai-responses", - stream: streamOpenAIResponses, - streamSimple: streamSimpleOpenAIResponses, -}); + registerApiProvider({ + api: "openai-responses", + stream: streamOpenAIResponses, + streamSimple: streamSimpleOpenAIResponses, + }); -registerApiProvider({ - api: "azure-openai-responses", - stream: streamAzureOpenAIResponses, - streamSimple: streamSimpleAzureOpenAIResponses, -}); + registerApiProvider({ + api: "azure-openai-responses", + stream: streamAzureOpenAIResponses, + streamSimple: streamSimpleAzureOpenAIResponses, + }); -registerApiProvider({ - api: "openai-codex-responses", - stream: streamOpenAICodexResponses, - streamSimple: streamSimpleOpenAICodexResponses, -}); + registerApiProvider({ + api: "openai-codex-responses", + stream: streamOpenAICodexResponses, + streamSimple: streamSimpleOpenAICodexResponses, + }); -registerApiProvider({ - api: "google-generative-ai", - stream: streamGoogle, - streamSimple: streamSimpleGoogle, -}); + registerApiProvider({ + api: "google-generative-ai", + stream: streamGoogle, + streamSimple: streamSimpleGoogle, + }); -registerApiProvider({ - api: "google-gemini-cli", - stream: streamGoogleGeminiCli, - streamSimple: streamSimpleGoogleGeminiCli, -}); + registerApiProvider({ + api: "google-gemini-cli", + stream: streamGoogleGeminiCli, + streamSimple: streamSimpleGoogleGeminiCli, + }); -registerApiProvider({ - api: "google-vertex", - stream: streamGoogleVertex, - streamSimple: streamSimpleGoogleVertex, -}); + registerApiProvider({ + api: "google-vertex", + stream: streamGoogleVertex, + streamSimple: streamSimpleGoogleVertex, + }); -registerApiProvider({ - api: "bedrock-converse-stream", - stream: streamBedrock, - streamSimple: streamSimpleBedrock, -}); + registerApiProvider({ + api: "bedrock-converse-stream", + stream: streamBedrock, + streamSimple: streamSimpleBedrock, + }); +} + +export function resetApiProviders(): 
void { + clearApiProviders(); + registerBuiltInApiProviders(); +} + +registerBuiltInApiProviders(); diff --git a/packages/ai/src/utils/event-stream.ts b/packages/ai/src/utils/event-stream.ts index 74947477..f4a7ceba 100644 --- a/packages/ai/src/utils/event-stream.ts +++ b/packages/ai/src/utils/event-stream.ts @@ -80,3 +80,8 @@ export class AssistantMessageEventStream extends EventStream` but doesn't use a standard API, set `authHeader: true`: + +```typescript +pi.registerProvider("custom-api", { + baseUrl: "https://api.example.com", + apiKey: "MY_API_KEY", + authHeader: true, // adds Authorization: Bearer header + api: "openai-completions", + models: [...] +}); +``` + +## OAuth Support + +Add OAuth/SSO authentication that integrates with `/login`: + +```typescript +import type { OAuthCredentials, OAuthLoginCallbacks } from "@mariozechner/pi-ai"; + +pi.registerProvider("corporate-ai", { + baseUrl: "https://ai.corp.com/v1", + api: "openai-responses", + models: [ + { + id: "corp-claude", + name: "Corporate Claude", + reasoning: true, + input: ["text", "image"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 200000, + maxTokens: 16384 + } + ], + oauth: { + name: "Corporate AI (SSO)", + + async login(callbacks: OAuthLoginCallbacks): Promise { + // Option 1: Browser-based OAuth + callbacks.onAuth({ url: "https://sso.corp.com/authorize?..." 
}); + + // Option 2: Device code flow + callbacks.onDeviceCode({ + userCode: "ABCD-1234", + verificationUri: "https://sso.corp.com/device" + }); + + // Option 3: Prompt for token/code + const code = await callbacks.onPrompt({ message: "Enter SSO code:" }); + + // Exchange for tokens (your implementation) + const tokens = await exchangeCodeForTokens(code); + + return { + refresh: tokens.refreshToken, + access: tokens.accessToken, + expires: Date.now() + tokens.expiresIn * 1000 + }; + }, + + async refreshToken(credentials: OAuthCredentials): Promise { + const tokens = await refreshAccessToken(credentials.refresh); + return { + refresh: tokens.refreshToken ?? credentials.refresh, + access: tokens.accessToken, + expires: Date.now() + tokens.expiresIn * 1000 + }; + }, + + getApiKey(credentials: OAuthCredentials): string { + return credentials.access; + }, + + // Optional: modify models based on user's subscription + modifyModels(models, credentials) { + // e.g., update baseUrl based on user's region + const region = decodeRegionFromToken(credentials.access); + return models.map(m => ({ + ...m, + baseUrl: `https://${region}.ai.corp.com/v1` + })); + } + } +}); +``` + +After registration, users can authenticate via `/login corporate-ai`. 
+ +### OAuthLoginCallbacks + +The `callbacks` object provides three ways to authenticate: + +```typescript +interface OAuthLoginCallbacks { + // Open URL in browser (for OAuth redirects) + onAuth(params: { url: string }): void; + + // Show device code (for device authorization flow) + onDeviceCode(params: { userCode: string; verificationUri: string }): void; + + // Prompt user for input (for manual token entry) + onPrompt(params: { message: string }): Promise; +} +``` + +### OAuthCredentials + +Credentials are persisted in `~/.pi/agent/auth.json`: + +```typescript +interface OAuthCredentials { + refresh: string; // Refresh token (for refreshToken()) + access: string; // Access token (returned by getApiKey()) + expires: number; // Expiration timestamp in milliseconds +} +``` + +## Custom Streaming API + +For providers with non-standard APIs, implement `streamSimple`: + +```typescript +import type { + AssistantMessageEventStream, + Context, + Model, + SimpleStreamOptions, + Api +} from "@mariozechner/pi-ai"; +import { createAssistantMessageEventStream } from "@mariozechner/pi-ai"; + +pi.registerProvider("custom-llm", { + baseUrl: "https://api.custom-llm.com", + apiKey: "CUSTOM_LLM_KEY", + api: "custom-llm-api", // your custom API identifier + models: [ + { + id: "custom-model", + name: "Custom Model", + reasoning: false, + input: ["text"], + cost: { input: 1.0, output: 2.0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 32000, + maxTokens: 4096 + } + ], + + streamSimple( + model: Model, + context: Context, + options?: SimpleStreamOptions + ): AssistantMessageEventStream { + return createAssistantMessageEventStream(async function* (signal) { + // Convert context to your API format + const messages = context.messages.map(m => ({ + role: m.role, + content: typeof m.content === "string" + ? 
m.content + : m.content.filter(c => c.type === "text").map(c => c.text).join("") + })); + + // Make streaming request + const response = await fetch(`${model.baseUrl}/chat`, { + method: "POST", + headers: { + "Authorization": `Bearer ${options?.apiKey}`, + "Content-Type": "application/json" + }, + body: JSON.stringify({ + model: model.id, + messages, + stream: true + }), + signal + }); + + if (!response.ok) { + throw new Error(`API error: ${response.status}`); + } + + // Yield start event + yield { type: "start" }; + + // Parse SSE stream + const reader = response.body!.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + let contentIndex = 0; + let textStarted = false; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() ?? ""; + + for (const line of lines) { + if (!line.startsWith("data: ")) continue; + const data = line.slice(6); + if (data === "[DONE]") continue; + + const chunk = JSON.parse(data); + const delta = chunk.choices?.[0]?.delta?.content; + + if (delta) { + if (!textStarted) { + yield { type: "text_start", contentIndex }; + textStarted = true; + } + yield { type: "text_delta", contentIndex, delta }; + } + } + } + + if (textStarted) { + yield { type: "text_end", contentIndex }; + } + + // Yield usage if available + yield { + type: "usage", + usage: { + input: 0, // fill from response if available + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } + } + }; + + // Yield done + yield { type: "done", reason: "stop" }; + }); + } +}); +``` + +### Event Types + +Your generator must yield events in this order: + +1. `{ type: "start" }` - Stream started +2. 
Content events (repeatable, in order): + - `{ type: "text_start", contentIndex }` - Text block started + - `{ type: "text_delta", contentIndex, delta }` - Text chunk + - `{ type: "text_end", contentIndex }` - Text block ended + - `{ type: "thinking_start", contentIndex }` - Thinking block started + - `{ type: "thinking_delta", contentIndex, delta }` - Thinking chunk + - `{ type: "thinking_end", contentIndex }` - Thinking block ended + - `{ type: "toolcall_start", contentIndex }` - Tool call started + - `{ type: "toolcall_delta", contentIndex, delta }` - Tool call JSON chunk + - `{ type: "toolcall_end", contentIndex, toolCall }` - Tool call ended +3. `{ type: "usage", usage }` - Token usage (optional but recommended) +4. `{ type: "done", reason }` or `{ type: "error", error }` - Stream ended + +### Reasoning Support + +For models with extended thinking, yield thinking events: + +```typescript +if (chunk.thinking) { + if (!thinkingStarted) { + yield { type: "thinking_start", contentIndex: thinkingIndex }; + thinkingStarted = true; + } + yield { type: "thinking_delta", contentIndex: thinkingIndex, delta: chunk.thinking }; +} +``` + +### Tool Calls + +For function calling support, yield tool call events: + +```typescript +if (chunk.tool_calls) { + for (const tc of chunk.tool_calls) { + if (tc.index !== currentToolIndex) { + if (currentToolIndex >= 0) { + yield { + type: "toolcall_end", + contentIndex: currentToolIndex, + toolCall: { + type: "toolCall", + id: currentToolId, + name: currentToolName, + arguments: JSON.parse(currentToolArgs) + } + }; + } + currentToolIndex = tc.index; + currentToolId = tc.id; + currentToolName = tc.function.name; + currentToolArgs = ""; + yield { type: "toolcall_start", contentIndex: tc.index }; + } + if (tc.function.arguments) { + currentToolArgs += tc.function.arguments; + yield { type: "toolcall_delta", contentIndex: tc.index, delta: tc.function.arguments }; + } + } +} +``` + +## Config Reference + +```typescript +interface 
ProviderConfig { + /** API endpoint URL. Required when defining models. */ + baseUrl?: string; + + /** API key or environment variable name. Required when defining models (unless oauth). */ + apiKey?: string; + + /** API type for streaming. Required at provider or model level when defining models. */ + api?: Api; + + /** Custom streaming implementation for non-standard APIs. */ + streamSimple?: ( + model: Model, + context: Context, + options?: SimpleStreamOptions + ) => AssistantMessageEventStream; + + /** Custom headers to include in requests. Values can be env var names. */ + headers?: Record; + + /** If true, adds Authorization: Bearer header with the resolved API key. */ + authHeader?: boolean; + + /** Models to register. If provided, replaces all existing models for this provider. */ + models?: ProviderModelConfig[]; + + /** OAuth provider for /login support. */ + oauth?: { + name: string; + login(callbacks: OAuthLoginCallbacks): Promise; + refreshToken(credentials: OAuthCredentials): Promise; + getApiKey(credentials: OAuthCredentials): string; + modifyModels?(models: Model[], credentials: OAuthCredentials): Model[]; + }; +} +``` + +## Model Definition Reference + +```typescript +interface ProviderModelConfig { + /** Model ID (e.g., "claude-sonnet-4-20250514"). */ + id: string; + + /** Display name (e.g., "Claude 4 Sonnet"). */ + name: string; + + /** API type override for this specific model. */ + api?: Api; + + /** Whether the model supports extended thinking. */ + reasoning: boolean; + + /** Supported input types. */ + input: ("text" | "image")[]; + + /** Cost per million tokens (for usage tracking). */ + cost: { + input: number; + output: number; + cacheRead: number; + cacheWrite: number; + }; + + /** Maximum context window size in tokens. */ + contextWindow: number; + + /** Maximum output tokens. */ + maxTokens: number; + + /** Custom headers for this specific model. */ + headers?: Record; + + /** OpenAI compatibility settings for openai-completions API. 
*/ + compat?: { + supportsStore?: boolean; + supportsDeveloperRole?: boolean; + supportsReasoningEffort?: boolean; + supportsUsageInStreaming?: boolean; + maxTokensField?: "max_completion_tokens" | "max_tokens"; + requiresToolResultName?: boolean; + requiresAssistantAfterToolResult?: boolean; + requiresThinkingAsText?: boolean; + requiresMistralToolIds?: boolean; + thinkingFormat?: "openai" | "zai"; + }; +} +``` diff --git a/packages/coding-agent/docs/extensions.md b/packages/coding-agent/docs/extensions.md index 879d461c..098004fb 100644 --- a/packages/coding-agent/docs/extensions.md +++ b/packages/coding-agent/docs/extensions.md @@ -1206,6 +1206,9 @@ pi.registerProvider("corporate-ai", { - `authHeader` - If true, adds `Authorization: Bearer` header automatically. - `models` - Array of model definitions. If provided, replaces all existing models for this provider. - `oauth` - OAuth provider config for `/login` support. When provided, the provider appears in the login menu. +- `streamSimple` - Custom streaming implementation for non-standard APIs. + +See [custom-provider.md](custom-provider.md) for advanced topics: custom streaming APIs, OAuth details, model definition reference. ## State Management diff --git a/packages/coding-agent/examples/extensions/custom-provider.ts b/packages/coding-agent/examples/extensions/custom-provider.ts new file mode 100644 index 00000000..19019d5e --- /dev/null +++ b/packages/coding-agent/examples/extensions/custom-provider.ts @@ -0,0 +1,601 @@ +/** + * Custom Provider Example + * + * Demonstrates registering a custom provider with: + * - Custom API identifier ("custom-anthropic-api") + * - Custom streamSimple implementation + * - OAuth support for /login + * - API key support via environment variable + * - Two model definitions + * + * Usage: + * # With OAuth (run /login custom-anthropic first) + * pi -e ./custom-provider.ts + * + * # With API key + * CUSTOM_ANTHROPIC_API_KEY=sk-ant-... 
pi -e ./custom-provider.ts + * + * Then use /model to select custom-anthropic/claude-sonnet-4-5 + */ + +import Anthropic from "@anthropic-ai/sdk"; +import type { ContentBlockParam, MessageCreateParamsStreaming } from "@anthropic-ai/sdk/resources/messages.js"; +import { + type Api, + type AssistantMessage, + type AssistantMessageEventStream, + type Context, + calculateCost, + createAssistantMessageEventStream, + type ImageContent, + type Message, + type Model, + type OAuthCredentials, + type OAuthLoginCallbacks, + type SimpleStreamOptions, + type StopReason, + type TextContent, + type ThinkingContent, + type Tool, + type ToolCall, + type ToolResultMessage, +} from "@mariozechner/pi-ai"; +import type { ExtensionAPI } from "@mariozechner/pi-coding-agent"; + +// ============================================================================= +// OAuth Implementation (copied from packages/ai/src/utils/oauth/anthropic.ts) +// ============================================================================= + +const decode = (s: string) => atob(s); +const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl"); +const AUTHORIZE_URL = "https://claude.ai/oauth/authorize"; +const TOKEN_URL = "https://console.anthropic.com/v1/oauth/token"; +const REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback"; +const SCOPES = "org:create_api_key user:profile user:inference"; + +async function generatePKCE(): Promise<{ verifier: string; challenge: string }> { + const array = new Uint8Array(32); + crypto.getRandomValues(array); + const verifier = btoa(String.fromCharCode(...array)) + .replace(/\+/g, "-") + .replace(/\//g, "_") + .replace(/=+$/, ""); + + const encoder = new TextEncoder(); + const data = encoder.encode(verifier); + const hash = await crypto.subtle.digest("SHA-256", data); + const challenge = btoa(String.fromCharCode(...new Uint8Array(hash))) + .replace(/\+/g, "-") + .replace(/\//g, "_") + .replace(/=+$/, ""); + + return { verifier, challenge }; +} + 
+async function loginAnthropic(callbacks: OAuthLoginCallbacks): Promise { + const { verifier, challenge } = await generatePKCE(); + + const authParams = new URLSearchParams({ + code: "true", + client_id: CLIENT_ID, + response_type: "code", + redirect_uri: REDIRECT_URI, + scope: SCOPES, + code_challenge: challenge, + code_challenge_method: "S256", + state: verifier, + }); + + callbacks.onAuth({ url: `${AUTHORIZE_URL}?${authParams.toString()}` }); + + const authCode = await callbacks.onPrompt({ message: "Paste the authorization code:" }); + const [code, state] = authCode.split("#"); + + const tokenResponse = await fetch(TOKEN_URL, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + grant_type: "authorization_code", + client_id: CLIENT_ID, + code, + state, + redirect_uri: REDIRECT_URI, + code_verifier: verifier, + }), + }); + + if (!tokenResponse.ok) { + throw new Error(`Token exchange failed: ${await tokenResponse.text()}`); + } + + const data = (await tokenResponse.json()) as { + access_token: string; + refresh_token: string; + expires_in: number; + }; + + return { + refresh: data.refresh_token, + access: data.access_token, + expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000, + }; +} + +async function refreshAnthropicToken(credentials: OAuthCredentials): Promise { + const response = await fetch(TOKEN_URL, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + grant_type: "refresh_token", + client_id: CLIENT_ID, + refresh_token: credentials.refresh, + }), + }); + + if (!response.ok) { + throw new Error(`Token refresh failed: ${await response.text()}`); + } + + const data = (await response.json()) as { + access_token: string; + refresh_token: string; + expires_in: number; + }; + + return { + refresh: data.refresh_token, + access: data.access_token, + expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000, + }; +} + +// 
============================================================================= +// Streaming Implementation (simplified from packages/ai/src/providers/anthropic.ts) +// ============================================================================= + +// Claude Code tool names for OAuth stealth mode +const claudeCodeTools = [ + "Read", + "Write", + "Edit", + "Bash", + "Grep", + "Glob", + "AskUserQuestion", + "TodoWrite", + "WebFetch", + "WebSearch", +]; +const ccToolLookup = new Map(claudeCodeTools.map((t) => [t.toLowerCase(), t])); +const toClaudeCodeName = (name: string) => ccToolLookup.get(name.toLowerCase()) ?? name; +const fromClaudeCodeName = (name: string, tools?: Tool[]) => { + const lowerName = name.toLowerCase(); + const matched = tools?.find((t) => t.name.toLowerCase() === lowerName); + return matched?.name ?? name; +}; + +function isOAuthToken(apiKey: string): boolean { + return apiKey.includes("sk-ant-oat"); +} + +function sanitizeSurrogates(text: string): string { + return text.replace(/[\uD800-\uDFFF]/g, "\uFFFD"); +} + +function convertContentBlocks( + content: (TextContent | ImageContent)[], +): string | Array<{ type: "text"; text: string } | { type: "image"; source: any }> { + const hasImages = content.some((c) => c.type === "image"); + if (!hasImages) { + return sanitizeSurrogates(content.map((c) => (c as TextContent).text).join("\n")); + } + + const blocks = content.map((block) => { + if (block.type === "text") { + return { type: "text" as const, text: sanitizeSurrogates(block.text) }; + } + return { + type: "image" as const, + source: { + type: "base64" as const, + media_type: block.mimeType, + data: block.data, + }, + }; + }); + + if (!blocks.some((b) => b.type === "text")) { + blocks.unshift({ type: "text" as const, text: "(see attached image)" }); + } + + return blocks; +} + +function convertMessages(messages: Message[], isOAuth: boolean, _tools?: Tool[]): any[] { + const params: any[] = []; + + for (let i = 0; i < messages.length; i++) { + 
const msg = messages[i]; + + if (msg.role === "user") { + if (typeof msg.content === "string") { + if (msg.content.trim()) { + params.push({ role: "user", content: sanitizeSurrogates(msg.content) }); + } + } else { + const blocks: ContentBlockParam[] = msg.content.map((item) => + item.type === "text" + ? { type: "text" as const, text: sanitizeSurrogates(item.text) } + : { + type: "image" as const, + source: { type: "base64" as const, media_type: item.mimeType as any, data: item.data }, + }, + ); + if (blocks.length > 0) { + params.push({ role: "user", content: blocks }); + } + } + } else if (msg.role === "assistant") { + const blocks: ContentBlockParam[] = []; + for (const block of msg.content) { + if (block.type === "text" && block.text.trim()) { + blocks.push({ type: "text", text: sanitizeSurrogates(block.text) }); + } else if (block.type === "thinking" && block.thinking.trim()) { + if ((block as ThinkingContent).thinkingSignature) { + blocks.push({ + type: "thinking" as any, + thinking: sanitizeSurrogates(block.thinking), + signature: (block as ThinkingContent).thinkingSignature!, + }); + } else { + blocks.push({ type: "text", text: sanitizeSurrogates(block.thinking) }); + } + } else if (block.type === "toolCall") { + blocks.push({ + type: "tool_use", + id: block.id, + name: isOAuth ? 
toClaudeCodeName(block.name) : block.name, + input: block.arguments, + }); + } + } + if (blocks.length > 0) { + params.push({ role: "assistant", content: blocks }); + } + } else if (msg.role === "toolResult") { + const toolResults: any[] = []; + toolResults.push({ + type: "tool_result", + tool_use_id: msg.toolCallId, + content: convertContentBlocks(msg.content), + is_error: msg.isError, + }); + + let j = i + 1; + while (j < messages.length && messages[j].role === "toolResult") { + const nextMsg = messages[j] as ToolResultMessage; + toolResults.push({ + type: "tool_result", + tool_use_id: nextMsg.toolCallId, + content: convertContentBlocks(nextMsg.content), + is_error: nextMsg.isError, + }); + j++; + } + i = j - 1; + params.push({ role: "user", content: toolResults }); + } + } + + // Add cache control to last user message + if (params.length > 0) { + const last = params[params.length - 1]; + if (last.role === "user" && Array.isArray(last.content)) { + const lastBlock = last.content[last.content.length - 1]; + if (lastBlock) { + lastBlock.cache_control = { type: "ephemeral" }; + } + } + } + + return params; +} + +function convertTools(tools: Tool[], isOAuth: boolean): any[] { + return tools.map((tool) => ({ + name: isOAuth ? 
toClaudeCodeName(tool.name) : tool.name, + description: tool.description, + input_schema: { + type: "object", + properties: (tool.parameters as any).properties || {}, + required: (tool.parameters as any).required || [], + }, + })); +} + +function mapStopReason(reason: string): StopReason { + switch (reason) { + case "end_turn": + case "pause_turn": + case "stop_sequence": + return "stop"; + case "max_tokens": + return "length"; + case "tool_use": + return "toolUse"; + default: + return "error"; + } +} + +function streamCustomAnthropic( + model: Model, + context: Context, + options?: SimpleStreamOptions, +): AssistantMessageEventStream { + const stream = createAssistantMessageEventStream(); + + (async () => { + const output: AssistantMessage = { + role: "assistant", + content: [], + api: model.api, + provider: model.provider, + model: model.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: Date.now(), + }; + + try { + const apiKey = options?.apiKey ?? 
""; + const isOAuth = isOAuthToken(apiKey); + + // Configure client based on auth type + const betaFeatures = ["fine-grained-tool-streaming-2025-05-14", "interleaved-thinking-2025-05-14"]; + const clientOptions: any = { + baseURL: model.baseUrl, + dangerouslyAllowBrowser: true, + }; + + if (isOAuth) { + clientOptions.apiKey = null; + clientOptions.authToken = apiKey; + clientOptions.defaultHeaders = { + accept: "application/json", + "anthropic-dangerous-direct-browser-access": "true", + "anthropic-beta": `claude-code-20250219,oauth-2025-04-20,${betaFeatures.join(",")}`, + "user-agent": "claude-cli/2.1.2 (external, cli)", + "x-app": "cli", + }; + } else { + clientOptions.apiKey = apiKey; + clientOptions.defaultHeaders = { + accept: "application/json", + "anthropic-dangerous-direct-browser-access": "true", + "anthropic-beta": betaFeatures.join(","), + }; + } + + const client = new Anthropic(clientOptions); + + // Build request params + const params: MessageCreateParamsStreaming = { + model: model.id, + messages: convertMessages(context.messages, isOAuth, context.tools), + max_tokens: options?.maxTokens || Math.floor(model.maxTokens / 3), + stream: true, + }; + + // System prompt with Claude Code identity for OAuth + if (isOAuth) { + params.system = [ + { + type: "text", + text: "You are Claude Code, Anthropic's official CLI for Claude.", + cache_control: { type: "ephemeral" }, + }, + ]; + if (context.systemPrompt) { + params.system.push({ + type: "text", + text: sanitizeSurrogates(context.systemPrompt), + cache_control: { type: "ephemeral" }, + }); + } + } else if (context.systemPrompt) { + params.system = [ + { + type: "text", + text: sanitizeSurrogates(context.systemPrompt), + cache_control: { type: "ephemeral" }, + }, + ]; + } + + if (context.tools) { + params.tools = convertTools(context.tools, isOAuth); + } + + // Handle thinking/reasoning + if (options?.reasoning && model.reasoning) { + const defaultBudgets: Record = { + minimal: 1024, + low: 4096, + medium: 
10240, + high: 20480, + }; + const customBudget = options.thinkingBudgets?.[options.reasoning as keyof typeof options.thinkingBudgets]; + params.thinking = { + type: "enabled", + budget_tokens: customBudget ?? defaultBudgets[options.reasoning] ?? 10240, + }; + } + + const anthropicStream = client.messages.stream({ ...params }, { signal: options?.signal }); + stream.push({ type: "start", partial: output }); + + type Block = (ThinkingContent | TextContent | (ToolCall & { partialJson: string })) & { index: number }; + const blocks = output.content as Block[]; + + for await (const event of anthropicStream) { + if (event.type === "message_start") { + output.usage.input = event.message.usage.input_tokens || 0; + output.usage.output = event.message.usage.output_tokens || 0; + output.usage.cacheRead = (event.message.usage as any).cache_read_input_tokens || 0; + output.usage.cacheWrite = (event.message.usage as any).cache_creation_input_tokens || 0; + output.usage.totalTokens = + output.usage.input + output.usage.output + output.usage.cacheRead + output.usage.cacheWrite; + calculateCost(model, output.usage); + } else if (event.type === "content_block_start") { + if (event.content_block.type === "text") { + output.content.push({ type: "text", text: "", index: event.index } as any); + stream.push({ type: "text_start", contentIndex: output.content.length - 1, partial: output }); + } else if (event.content_block.type === "thinking") { + output.content.push({ + type: "thinking", + thinking: "", + thinkingSignature: "", + index: event.index, + } as any); + stream.push({ type: "thinking_start", contentIndex: output.content.length - 1, partial: output }); + } else if (event.content_block.type === "tool_use") { + output.content.push({ + type: "toolCall", + id: event.content_block.id, + name: isOAuth + ? 
fromClaudeCodeName(event.content_block.name, context.tools) + : event.content_block.name, + arguments: {}, + partialJson: "", + index: event.index, + } as any); + stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output }); + } + } else if (event.type === "content_block_delta") { + const index = blocks.findIndex((b) => b.index === event.index); + const block = blocks[index]; + if (!block) continue; + + if (event.delta.type === "text_delta" && block.type === "text") { + block.text += event.delta.text; + stream.push({ type: "text_delta", contentIndex: index, delta: event.delta.text, partial: output }); + } else if (event.delta.type === "thinking_delta" && block.type === "thinking") { + block.thinking += event.delta.thinking; + stream.push({ + type: "thinking_delta", + contentIndex: index, + delta: event.delta.thinking, + partial: output, + }); + } else if (event.delta.type === "input_json_delta" && block.type === "toolCall") { + (block as any).partialJson += event.delta.partial_json; + try { + block.arguments = JSON.parse((block as any).partialJson); + } catch {} + stream.push({ + type: "toolcall_delta", + contentIndex: index, + delta: event.delta.partial_json, + partial: output, + }); + } else if (event.delta.type === "signature_delta" && block.type === "thinking") { + block.thinkingSignature = (block.thinkingSignature || "") + (event.delta as any).signature; + } + } else if (event.type === "content_block_stop") { + const index = blocks.findIndex((b) => b.index === event.index); + const block = blocks[index]; + if (!block) continue; + + delete (block as any).index; + if (block.type === "text") { + stream.push({ type: "text_end", contentIndex: index, content: block.text, partial: output }); + } else if (block.type === "thinking") { + stream.push({ type: "thinking_end", contentIndex: index, content: block.thinking, partial: output }); + } else if (block.type === "toolCall") { + try { + block.arguments = JSON.parse((block as 
any).partialJson); + } catch {} + delete (block as any).partialJson; + stream.push({ type: "toolcall_end", contentIndex: index, toolCall: block, partial: output }); + } + } else if (event.type === "message_delta") { + if ((event.delta as any).stop_reason) { + output.stopReason = mapStopReason((event.delta as any).stop_reason); + } + output.usage.input = (event.usage as any).input_tokens || 0; + output.usage.output = (event.usage as any).output_tokens || 0; + output.usage.cacheRead = (event.usage as any).cache_read_input_tokens || 0; + output.usage.cacheWrite = (event.usage as any).cache_creation_input_tokens || 0; + output.usage.totalTokens = + output.usage.input + output.usage.output + output.usage.cacheRead + output.usage.cacheWrite; + calculateCost(model, output.usage); + } + } + + if (options?.signal?.aborted) { + throw new Error("Request was aborted"); + } + + stream.push({ type: "done", reason: output.stopReason as "stop" | "length" | "toolUse", message: output }); + stream.end(); + } catch (error) { + for (const block of output.content) delete (block as any).index; + output.stopReason = options?.signal?.aborted ? "aborted" : "error"; + output.errorMessage = error instanceof Error ? 
error.message : JSON.stringify(error); + stream.push({ type: "error", reason: output.stopReason, error: output }); + stream.end(); + } + })(); + + return stream; +} + +// ============================================================================= +// Extension Entry Point +// ============================================================================= + +export default function (pi: ExtensionAPI) { + pi.registerProvider("custom-anthropic", { + baseUrl: "https://api.anthropic.com", + apiKey: "CUSTOM_ANTHROPIC_API_KEY", + api: "custom-anthropic-api", + + models: [ + { + id: "claude-opus-4-5", + name: "Claude Opus 4.5 (Custom)", + reasoning: true, + input: ["text", "image"], + cost: { input: 5, output: 25, cacheRead: 0.5, cacheWrite: 6.25 }, + contextWindow: 200000, + maxTokens: 64000, + }, + { + id: "claude-sonnet-4-5", + name: "Claude Sonnet 4.5 (Custom)", + reasoning: true, + input: ["text", "image"], + cost: { input: 3, output: 15, cacheRead: 0.3, cacheWrite: 3.75 }, + contextWindow: 200000, + maxTokens: 64000, + }, + ], + + oauth: { + name: "Custom Anthropic (Claude Pro/Max)", + login: loginAnthropic, + refreshToken: refreshAnthropicToken, + getApiKey: (cred) => cred.access, + }, + + streamSimple: streamCustomAnthropic, + }); +} diff --git a/packages/coding-agent/src/core/agent-session.ts b/packages/coding-agent/src/core/agent-session.ts index 27054553..d0e9aaef 100644 --- a/packages/coding-agent/src/core/agent-session.ts +++ b/packages/coding-agent/src/core/agent-session.ts @@ -23,7 +23,7 @@ import type { ThinkingLevel, } from "@mariozechner/pi-agent-core"; import type { AssistantMessage, ImageContent, Message, Model, TextContent } from "@mariozechner/pi-ai"; -import { isContextOverflow, modelsAreEqual, supportsXhigh } from "@mariozechner/pi-ai"; +import { isContextOverflow, modelsAreEqual, resetApiProviders, supportsXhigh } from "@mariozechner/pi-ai"; import { getAuthPath } from "../config.js"; import { theme } from "../modes/interactive/theme/theme.js"; 
import { stripFrontmatter } from "../utils/frontmatter.js"; @@ -1832,6 +1832,7 @@ export class AgentSession { async reload(): Promise<void> { const previousFlagValues = this._extensionRunner?.getFlagValues(); await this._extensionRunner?.emit({ type: "session_shutdown" }); + resetApiProviders(); await this._resourceLoader.reload(); this._buildRuntime({ activeToolNames: this.getActiveToolNames(), diff --git a/packages/coding-agent/src/core/extensions/types.ts b/packages/coding-agent/src/core/extensions/types.ts index 08316693..47e6051a 100644 --- a/packages/coding-agent/src/core/extensions/types.ts +++ b/packages/coding-agent/src/core/extensions/types.ts @@ -16,10 +16,13 @@ import type { } from "@mariozechner/pi-agent-core"; import type { Api, + AssistantMessageEventStream, + Context, ImageContent, Model, OAuthCredentials, OAuthLoginCallbacks, + SimpleStreamOptions, TextContent, ToolResultMessage, } from "@mariozechner/pi-ai"; @@ -872,6 +875,7 @@ export interface ExtensionAPI { * If `models` is provided: replaces all existing models for this provider. * If only `baseUrl` is provided: overrides the URL for existing models. * If `oauth` is provided: registers OAuth provider for /login support. + * If `streamSimple` is provided: registers a custom API stream handler. * * @example * // Register a new provider with custom models @@ -930,6 +934,8 @@ export interface ProviderConfig { apiKey?: string; /** API type. Required at provider or model level when defining models. */ api?: Api; + /** Optional streamSimple handler for custom APIs. */ + streamSimple?: (model: Model, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream; /** Custom headers to include in requests. */ headers?: Record<string, string>; /** If true, adds Authorization: Bearer header with the resolved API key. 
*/ diff --git a/packages/coding-agent/src/core/model-registry.ts b/packages/coding-agent/src/core/model-registry.ts index 77c8c488..61c67c6e 100644 --- a/packages/coding-agent/src/core/model-registry.ts +++ b/packages/coding-agent/src/core/model-registry.ts @@ -4,12 +4,16 @@ import { type Api, + type AssistantMessageEventStream, + type Context, getModels, getProviders, type KnownProvider, type Model, type OAuthProviderInterface, + registerApiProvider, registerOAuthProvider, + type SimpleStreamOptions, } from "@mariozechner/pi-ai"; import { type Static, Type } from "@sinclair/typebox"; import AjvModule from "ajv"; @@ -45,17 +49,7 @@ const OpenAICompatSchema = Type.Union([OpenAICompletionsCompatSchema, OpenAIResp const ModelDefinitionSchema = Type.Object({ id: Type.String({ minLength: 1 }), name: Type.String({ minLength: 1 }), - api: Type.Optional( - Type.Union([ - Type.Literal("openai-completions"), - Type.Literal("openai-responses"), - Type.Literal("azure-openai-responses"), - Type.Literal("openai-codex-responses"), - Type.Literal("anthropic-messages"), - Type.Literal("google-generative-ai"), - Type.Literal("bedrock-converse-stream"), - ]), - ), + api: Type.Optional(Type.String({ minLength: 1 })), reasoning: Type.Boolean(), input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])), cost: Type.Object({ @@ -73,17 +67,7 @@ const ModelDefinitionSchema = Type.Object({ const ProviderConfigSchema = Type.Object({ baseUrl: Type.Optional(Type.String({ minLength: 1 })), apiKey: Type.Optional(Type.String({ minLength: 1 })), - api: Type.Optional( - Type.Union([ - Type.Literal("openai-completions"), - Type.Literal("openai-responses"), - Type.Literal("azure-openai-responses"), - Type.Literal("openai-codex-responses"), - Type.Literal("anthropic-messages"), - Type.Literal("google-generative-ai"), - Type.Literal("bedrock-converse-stream"), - ]), - ), + api: Type.Optional(Type.String({ minLength: 1 })), headers: Type.Optional(Type.Record(Type.String(), 
Type.String())), authHeader: Type.Optional(Type.Boolean()), models: Type.Optional(Type.Array(ModelDefinitionSchema)), @@ -482,6 +466,18 @@ export class ModelRegistry { registerOAuthProvider(oauthProvider); } + if (config.streamSimple) { + if (!config.api) { + throw new Error(`Provider ${providerName}: "api" is required when registering streamSimple.`); + } + const streamSimple = config.streamSimple; + registerApiProvider({ + api: config.api, + stream: (model, context, options) => streamSimple(model, context, options as SimpleStreamOptions), + streamSimple, + }); + } + // Store API key for auth resolution if (config.apiKey) { this.customProviderApiKeys.set(providerName, config.apiKey); @@ -556,6 +552,7 @@ export interface ProviderConfigInput { baseUrl?: string; apiKey?: string; api?: Api; + streamSimple?: (model: Model, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream; headers?: Record<string, string>; authHeader?: boolean; /** OAuth provider for /login support */