mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-15 07:04:45 +00:00
feat: custom provider support with streamSimple
- Add resetApiProviders() to clear and re-register built-in providers - Add createAssistantMessageEventStream() factory for extensions - Add streamSimple support in ProviderConfig for custom API implementations - Call resetApiProviders() on /reload to clean up extension providers - Add custom-provider.md documentation - Add custom-provider.ts example with full Anthropic implementation - Update extensions.md with streamSimple config option
This commit is contained in:
parent
c06163bc59
commit
177c694406
11 changed files with 1243 additions and 69 deletions
|
|
@ -5,6 +5,8 @@
|
|||
### Added
|
||||
|
||||
- Added `azure-openai-responses` provider support for Azure OpenAI Responses API. ([#890](https://github.com/badlogic/pi-mono/pull/890) by [@markusylisiurunen](https://github.com/markusylisiurunen))
|
||||
- Added `createAssistantMessageEventStream()` factory function for use in extensions.
|
||||
- Added `resetApiProviders()` to clear and re-register built-in API providers.
|
||||
|
||||
### Changed
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ export * from "./providers/google-gemini-cli.js";
|
|||
export * from "./providers/google-vertex.js";
|
||||
export * from "./providers/openai-completions.js";
|
||||
export * from "./providers/openai-responses.js";
|
||||
export * from "./providers/register-builtins.js";
|
||||
export * from "./stream.js";
|
||||
export * from "./types.js";
|
||||
export * from "./utils/event-stream.js";
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { registerApiProvider } from "../api-registry.js";
|
||||
import { clearApiProviders, registerApiProvider } from "../api-registry.js";
|
||||
import { streamBedrock, streamSimpleBedrock } from "./amazon-bedrock.js";
|
||||
import { streamAnthropic, streamSimpleAnthropic } from "./anthropic.js";
|
||||
import { streamAzureOpenAIResponses, streamSimpleAzureOpenAIResponses } from "./azure-openai-responses.js";
|
||||
|
|
@ -9,56 +9,65 @@ import { streamOpenAICodexResponses, streamSimpleOpenAICodexResponses } from "./
|
|||
import { streamOpenAICompletions, streamSimpleOpenAICompletions } from "./openai-completions.js";
|
||||
import { streamOpenAIResponses, streamSimpleOpenAIResponses } from "./openai-responses.js";
|
||||
|
||||
registerApiProvider({
|
||||
api: "anthropic-messages",
|
||||
stream: streamAnthropic,
|
||||
streamSimple: streamSimpleAnthropic,
|
||||
});
|
||||
export function registerBuiltInApiProviders(): void {
|
||||
registerApiProvider({
|
||||
api: "anthropic-messages",
|
||||
stream: streamAnthropic,
|
||||
streamSimple: streamSimpleAnthropic,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "openai-completions",
|
||||
stream: streamOpenAICompletions,
|
||||
streamSimple: streamSimpleOpenAICompletions,
|
||||
});
|
||||
registerApiProvider({
|
||||
api: "openai-completions",
|
||||
stream: streamOpenAICompletions,
|
||||
streamSimple: streamSimpleOpenAICompletions,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "openai-responses",
|
||||
stream: streamOpenAIResponses,
|
||||
streamSimple: streamSimpleOpenAIResponses,
|
||||
});
|
||||
registerApiProvider({
|
||||
api: "openai-responses",
|
||||
stream: streamOpenAIResponses,
|
||||
streamSimple: streamSimpleOpenAIResponses,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "azure-openai-responses",
|
||||
stream: streamAzureOpenAIResponses,
|
||||
streamSimple: streamSimpleAzureOpenAIResponses,
|
||||
});
|
||||
registerApiProvider({
|
||||
api: "azure-openai-responses",
|
||||
stream: streamAzureOpenAIResponses,
|
||||
streamSimple: streamSimpleAzureOpenAIResponses,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "openai-codex-responses",
|
||||
stream: streamOpenAICodexResponses,
|
||||
streamSimple: streamSimpleOpenAICodexResponses,
|
||||
});
|
||||
registerApiProvider({
|
||||
api: "openai-codex-responses",
|
||||
stream: streamOpenAICodexResponses,
|
||||
streamSimple: streamSimpleOpenAICodexResponses,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "google-generative-ai",
|
||||
stream: streamGoogle,
|
||||
streamSimple: streamSimpleGoogle,
|
||||
});
|
||||
registerApiProvider({
|
||||
api: "google-generative-ai",
|
||||
stream: streamGoogle,
|
||||
streamSimple: streamSimpleGoogle,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "google-gemini-cli",
|
||||
stream: streamGoogleGeminiCli,
|
||||
streamSimple: streamSimpleGoogleGeminiCli,
|
||||
});
|
||||
registerApiProvider({
|
||||
api: "google-gemini-cli",
|
||||
stream: streamGoogleGeminiCli,
|
||||
streamSimple: streamSimpleGoogleGeminiCli,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "google-vertex",
|
||||
stream: streamGoogleVertex,
|
||||
streamSimple: streamSimpleGoogleVertex,
|
||||
});
|
||||
registerApiProvider({
|
||||
api: "google-vertex",
|
||||
stream: streamGoogleVertex,
|
||||
streamSimple: streamSimpleGoogleVertex,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "bedrock-converse-stream",
|
||||
stream: streamBedrock,
|
||||
streamSimple: streamSimpleBedrock,
|
||||
});
|
||||
registerApiProvider({
|
||||
api: "bedrock-converse-stream",
|
||||
stream: streamBedrock,
|
||||
streamSimple: streamSimpleBedrock,
|
||||
});
|
||||
}
|
||||
|
||||
export function resetApiProviders(): void {
|
||||
clearApiProviders();
|
||||
registerBuiltInApiProviders();
|
||||
}
|
||||
|
||||
registerBuiltInApiProviders();
|
||||
|
|
|
|||
|
|
@ -80,3 +80,8 @@ export class AssistantMessageEventStream extends EventStream<AssistantMessageEve
|
|||
);
|
||||
}
|
||||
}
|
||||
|
||||
/** Factory function for AssistantMessageEventStream (for use in extensions) */
|
||||
export function createAssistantMessageEventStream(): AssistantMessageEventStream {
|
||||
return new AssistantMessageEventStream();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@
|
|||
- Package deduplication: if same package appears in global and project settings, project wins ([#645](https://github.com/badlogic/pi-mono/issues/645))
|
||||
- Unified collision reporting with `ResourceDiagnostic` type for all resource types ([#645](https://github.com/badlogic/pi-mono/issues/645))
|
||||
- Show provider alongside the model in the footer if multiple providers are available
|
||||
- Custom provider support via `pi.registerProvider()` with `streamSimple` for custom API implementations
|
||||
- Added `custom-provider.ts` example extension demonstrating custom Anthropic provider with OAuth
|
||||
|
||||
### Fixed
|
||||
|
||||
|
|
|
|||
547
packages/coding-agent/docs/custom-provider.md
Normal file
547
packages/coding-agent/docs/custom-provider.md
Normal file
|
|
@ -0,0 +1,547 @@
|
|||
# Custom Providers
|
||||
|
||||
Extensions can register custom model providers via `pi.registerProvider()`. This enables:
|
||||
|
||||
- **Proxies** - Route requests through corporate proxies or API gateways
|
||||
- **Custom endpoints** - Use self-hosted or private model deployments
|
||||
- **OAuth/SSO** - Add authentication flows for enterprise providers
|
||||
- **Custom APIs** - Implement streaming for non-standard LLM APIs
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Quick Reference](#quick-reference)
|
||||
- [Override Existing Provider](#override-existing-provider)
|
||||
- [Register New Provider](#register-new-provider)
|
||||
- [OAuth Support](#oauth-support)
|
||||
- [Custom Streaming API](#custom-streaming-api)
|
||||
- [Config Reference](#config-reference)
|
||||
- [Model Definition Reference](#model-definition-reference)
|
||||
|
||||
## Quick Reference
|
||||
|
||||
```typescript
|
||||
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
|
||||
|
||||
export default function (pi: ExtensionAPI) {
|
||||
// Override baseUrl for existing provider
|
||||
pi.registerProvider("anthropic", {
|
||||
baseUrl: "https://proxy.example.com"
|
||||
});
|
||||
|
||||
// Register new provider with models
|
||||
pi.registerProvider("my-provider", {
|
||||
baseUrl: "https://api.example.com",
|
||||
apiKey: "MY_API_KEY",
|
||||
api: "openai-completions",
|
||||
models: [
|
||||
{
|
||||
id: "my-model",
|
||||
name: "My Model",
|
||||
reasoning: false,
|
||||
input: ["text", "image"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 128000,
|
||||
maxTokens: 4096
|
||||
}
|
||||
]
|
||||
});
|
||||
}
|
||||
```
|
||||
|
||||
## Override Existing Provider
|
||||
|
||||
The simplest use case: redirect an existing provider through a proxy.
|
||||
|
||||
```typescript
|
||||
// All Anthropic requests now go through your proxy
|
||||
pi.registerProvider("anthropic", {
|
||||
baseUrl: "https://proxy.example.com"
|
||||
});
|
||||
|
||||
// Add custom headers to OpenAI requests
|
||||
pi.registerProvider("openai", {
|
||||
headers: {
|
||||
"X-Custom-Header": "value"
|
||||
}
|
||||
});
|
||||
|
||||
// Both baseUrl and headers
|
||||
pi.registerProvider("google", {
|
||||
baseUrl: "https://ai-gateway.corp.com/google",
|
||||
headers: {
|
||||
"X-Corp-Auth": "CORP_AUTH_TOKEN" // env var or literal
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
When only `baseUrl` and/or `headers` are provided (no `models`), all existing models for that provider are preserved with the new endpoint.
|
||||
|
||||
## Register New Provider
|
||||
|
||||
To add a completely new provider, specify `models` along with the required configuration.
|
||||
|
||||
```typescript
|
||||
pi.registerProvider("my-llm", {
|
||||
baseUrl: "https://api.my-llm.com/v1",
|
||||
apiKey: "MY_LLM_API_KEY", // env var name or literal value
|
||||
api: "openai-completions", // which streaming API to use
|
||||
models: [
|
||||
{
|
||||
id: "my-llm-large",
|
||||
name: "My LLM Large",
|
||||
reasoning: true, // supports extended thinking
|
||||
input: ["text", "image"],
|
||||
cost: {
|
||||
input: 3.0, // $/million tokens
|
||||
output: 15.0,
|
||||
cacheRead: 0.3,
|
||||
cacheWrite: 3.75
|
||||
},
|
||||
contextWindow: 200000,
|
||||
maxTokens: 16384
|
||||
},
|
||||
{
|
||||
id: "my-llm-small",
|
||||
name: "My LLM Small",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 0.25, output: 1.25, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 128000,
|
||||
maxTokens: 8192
|
||||
}
|
||||
]
|
||||
});
|
||||
```
|
||||
|
||||
When `models` is provided, it **replaces** all existing models for that provider.
|
||||
|
||||
### API Types
|
||||
|
||||
The `api` field determines which streaming implementation is used:
|
||||
|
||||
| API | Use for |
|
||||
|-----|---------|
|
||||
| `anthropic-messages` | Anthropic Claude API and compatibles |
|
||||
| `openai-completions` | OpenAI Chat Completions API and compatibles |
|
||||
| `openai-responses` | OpenAI Responses API |
|
||||
| `azure-openai-responses` | Azure OpenAI Responses API |
|
||||
| `openai-codex-responses` | OpenAI Codex Responses API |
|
||||
| `google-generative-ai` | Google Generative AI API |
|
||||
| `google-gemini-cli` | Google Cloud Code Assist API |
|
||||
| `google-vertex` | Google Vertex AI API |
|
||||
| `bedrock-converse-stream` | Amazon Bedrock Converse API |
|
||||
|
||||
Most OpenAI-compatible providers work with `openai-completions`. Use `compat` for quirks:
|
||||
|
||||
```typescript
|
||||
models: [{
|
||||
id: "custom-model",
|
||||
// ...
|
||||
compat: {
|
||||
supportsDeveloperRole: false, // use "system" instead of "developer"
|
||||
supportsReasoningEffort: false, // disable reasoning_effort param
|
||||
maxTokensField: "max_tokens", // instead of "max_completion_tokens"
|
||||
requiresToolResultName: true, // tool results need name field
|
||||
requiresMistralToolIds: true // tool IDs must be 9 alphanumeric chars
|
||||
}
|
||||
}]
|
||||
```
|
||||
|
||||
### Auth Header
|
||||
|
||||
If your provider expects `Authorization: Bearer <key>` but doesn't use a standard API, set `authHeader: true`:
|
||||
|
||||
```typescript
|
||||
pi.registerProvider("custom-api", {
|
||||
baseUrl: "https://api.example.com",
|
||||
apiKey: "MY_API_KEY",
|
||||
authHeader: true, // adds Authorization: Bearer header
|
||||
api: "openai-completions",
|
||||
models: [...]
|
||||
});
|
||||
```
|
||||
|
||||
## OAuth Support
|
||||
|
||||
Add OAuth/SSO authentication that integrates with `/login`:
|
||||
|
||||
```typescript
|
||||
import type { OAuthCredentials, OAuthLoginCallbacks } from "@mariozechner/pi-ai";
|
||||
|
||||
pi.registerProvider("corporate-ai", {
|
||||
baseUrl: "https://ai.corp.com/v1",
|
||||
api: "openai-responses",
|
||||
models: [
|
||||
{
|
||||
id: "corp-claude",
|
||||
name: "Corporate Claude",
|
||||
reasoning: true,
|
||||
input: ["text", "image"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 200000,
|
||||
maxTokens: 16384
|
||||
}
|
||||
],
|
||||
oauth: {
|
||||
name: "Corporate AI (SSO)",
|
||||
|
||||
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
// Option 1: Browser-based OAuth
|
||||
callbacks.onAuth({ url: "https://sso.corp.com/authorize?..." });
|
||||
|
||||
// Option 2: Device code flow
|
||||
callbacks.onDeviceCode({
|
||||
userCode: "ABCD-1234",
|
||||
verificationUri: "https://sso.corp.com/device"
|
||||
});
|
||||
|
||||
// Option 3: Prompt for token/code
|
||||
const code = await callbacks.onPrompt({ message: "Enter SSO code:" });
|
||||
|
||||
// Exchange for tokens (your implementation)
|
||||
const tokens = await exchangeCodeForTokens(code);
|
||||
|
||||
return {
|
||||
refresh: tokens.refreshToken,
|
||||
access: tokens.accessToken,
|
||||
expires: Date.now() + tokens.expiresIn * 1000
|
||||
};
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
const tokens = await refreshAccessToken(credentials.refresh);
|
||||
return {
|
||||
refresh: tokens.refreshToken ?? credentials.refresh,
|
||||
access: tokens.accessToken,
|
||||
expires: Date.now() + tokens.expiresIn * 1000
|
||||
};
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
return credentials.access;
|
||||
},
|
||||
|
||||
// Optional: modify models based on user's subscription
|
||||
modifyModels(models, credentials) {
|
||||
// e.g., update baseUrl based on user's region
|
||||
const region = decodeRegionFromToken(credentials.access);
|
||||
return models.map(m => ({
|
||||
...m,
|
||||
baseUrl: `https://${region}.ai.corp.com/v1`
|
||||
}));
|
||||
}
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
After registration, users can authenticate via `/login corporate-ai`.
|
||||
|
||||
### OAuthLoginCallbacks
|
||||
|
||||
The `callbacks` object provides three ways to authenticate:
|
||||
|
||||
```typescript
|
||||
interface OAuthLoginCallbacks {
|
||||
// Open URL in browser (for OAuth redirects)
|
||||
onAuth(params: { url: string }): void;
|
||||
|
||||
// Show device code (for device authorization flow)
|
||||
onDeviceCode(params: { userCode: string; verificationUri: string }): void;
|
||||
|
||||
// Prompt user for input (for manual token entry)
|
||||
onPrompt(params: { message: string }): Promise<string>;
|
||||
}
|
||||
```
|
||||
|
||||
### OAuthCredentials
|
||||
|
||||
Credentials are persisted in `~/.pi/agent/auth.json`:
|
||||
|
||||
```typescript
|
||||
interface OAuthCredentials {
|
||||
refresh: string; // Refresh token (for refreshToken())
|
||||
access: string; // Access token (returned by getApiKey())
|
||||
expires: number; // Expiration timestamp in milliseconds
|
||||
}
|
||||
```
|
||||
|
||||
## Custom Streaming API
|
||||
|
||||
For providers with non-standard APIs, implement `streamSimple`:
|
||||
|
||||
```typescript
|
||||
import type {
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
Api
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { createAssistantMessageEventStream } from "@mariozechner/pi-ai";
|
||||
|
||||
pi.registerProvider("custom-llm", {
|
||||
baseUrl: "https://api.custom-llm.com",
|
||||
apiKey: "CUSTOM_LLM_KEY",
|
||||
api: "custom-llm-api", // your custom API identifier
|
||||
models: [
|
||||
{
|
||||
id: "custom-model",
|
||||
name: "Custom Model",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 1.0, output: 2.0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 32000,
|
||||
maxTokens: 4096
|
||||
}
|
||||
],
|
||||
|
||||
streamSimple(
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions
|
||||
): AssistantMessageEventStream {
|
||||
return createAssistantMessageEventStream(async function* (signal) {
|
||||
// Convert context to your API format
|
||||
const messages = context.messages.map(m => ({
|
||||
role: m.role,
|
||||
content: typeof m.content === "string"
|
||||
? m.content
|
||||
: m.content.filter(c => c.type === "text").map(c => c.text).join("")
|
||||
}));
|
||||
|
||||
// Make streaming request
|
||||
const response = await fetch(`${model.baseUrl}/chat`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Authorization": `Bearer ${options?.apiKey}`,
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: model.id,
|
||||
messages,
|
||||
stream: true
|
||||
}),
|
||||
signal
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`API error: ${response.status}`);
|
||||
}
|
||||
|
||||
// Yield start event
|
||||
yield { type: "start" };
|
||||
|
||||
// Parse SSE stream
|
||||
const reader = response.body!.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
let contentIndex = 0;
|
||||
let textStarted = false;
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() ?? "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith("data: ")) continue;
|
||||
const data = line.slice(6);
|
||||
if (data === "[DONE]") continue;
|
||||
|
||||
const chunk = JSON.parse(data);
|
||||
const delta = chunk.choices?.[0]?.delta?.content;
|
||||
|
||||
if (delta) {
|
||||
if (!textStarted) {
|
||||
yield { type: "text_start", contentIndex };
|
||||
textStarted = true;
|
||||
}
|
||||
yield { type: "text_delta", contentIndex, delta };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (textStarted) {
|
||||
yield { type: "text_end", contentIndex };
|
||||
}
|
||||
|
||||
// Yield usage if available
|
||||
yield {
|
||||
type: "usage",
|
||||
usage: {
|
||||
input: 0, // fill from response if available
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }
|
||||
}
|
||||
};
|
||||
|
||||
// Yield done
|
||||
yield { type: "done", reason: "stop" };
|
||||
});
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
### Event Types
|
||||
|
||||
Your generator must yield events in this order:
|
||||
|
||||
1. `{ type: "start" }` - Stream started
|
||||
2. Content events (repeatable, in order):
|
||||
- `{ type: "text_start", contentIndex }` - Text block started
|
||||
- `{ type: "text_delta", contentIndex, delta }` - Text chunk
|
||||
- `{ type: "text_end", contentIndex }` - Text block ended
|
||||
- `{ type: "thinking_start", contentIndex }` - Thinking block started
|
||||
- `{ type: "thinking_delta", contentIndex, delta }` - Thinking chunk
|
||||
- `{ type: "thinking_end", contentIndex }` - Thinking block ended
|
||||
- `{ type: "toolcall_start", contentIndex }` - Tool call started
|
||||
- `{ type: "toolcall_delta", contentIndex, delta }` - Tool call JSON chunk
|
||||
- `{ type: "toolcall_end", contentIndex, toolCall }` - Tool call ended
|
||||
3. `{ type: "usage", usage }` - Token usage (optional but recommended)
|
||||
4. `{ type: "done", reason }` or `{ type: "error", error }` - Stream ended
|
||||
|
||||
### Reasoning Support
|
||||
|
||||
For models with extended thinking, yield thinking events:
|
||||
|
||||
```typescript
|
||||
if (chunk.thinking) {
|
||||
if (!thinkingStarted) {
|
||||
yield { type: "thinking_start", contentIndex: thinkingIndex };
|
||||
thinkingStarted = true;
|
||||
}
|
||||
yield { type: "thinking_delta", contentIndex: thinkingIndex, delta: chunk.thinking };
|
||||
}
|
||||
```
|
||||
|
||||
### Tool Calls
|
||||
|
||||
For function calling support, yield tool call events:
|
||||
|
||||
```typescript
|
||||
if (chunk.tool_calls) {
|
||||
for (const tc of chunk.tool_calls) {
|
||||
if (tc.index !== currentToolIndex) {
|
||||
if (currentToolIndex >= 0) {
|
||||
yield {
|
||||
type: "toolcall_end",
|
||||
contentIndex: currentToolIndex,
|
||||
toolCall: {
|
||||
type: "toolCall",
|
||||
id: currentToolId,
|
||||
name: currentToolName,
|
||||
arguments: JSON.parse(currentToolArgs)
|
||||
}
|
||||
};
|
||||
}
|
||||
currentToolIndex = tc.index;
|
||||
currentToolId = tc.id;
|
||||
currentToolName = tc.function.name;
|
||||
currentToolArgs = "";
|
||||
yield { type: "toolcall_start", contentIndex: tc.index };
|
||||
}
|
||||
if (tc.function.arguments) {
|
||||
currentToolArgs += tc.function.arguments;
|
||||
yield { type: "toolcall_delta", contentIndex: tc.index, delta: tc.function.arguments };
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Config Reference
|
||||
|
||||
```typescript
|
||||
interface ProviderConfig {
|
||||
/** API endpoint URL. Required when defining models. */
|
||||
baseUrl?: string;
|
||||
|
||||
/** API key or environment variable name. Required when defining models (unless oauth). */
|
||||
apiKey?: string;
|
||||
|
||||
/** API type for streaming. Required at provider or model level when defining models. */
|
||||
api?: Api;
|
||||
|
||||
/** Custom streaming implementation for non-standard APIs. */
|
||||
streamSimple?: (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions
|
||||
) => AssistantMessageEventStream;
|
||||
|
||||
/** Custom headers to include in requests. Values can be env var names. */
|
||||
headers?: Record<string, string>;
|
||||
|
||||
/** If true, adds Authorization: Bearer header with the resolved API key. */
|
||||
authHeader?: boolean;
|
||||
|
||||
/** Models to register. If provided, replaces all existing models for this provider. */
|
||||
models?: ProviderModelConfig[];
|
||||
|
||||
/** OAuth provider for /login support. */
|
||||
oauth?: {
|
||||
name: string;
|
||||
login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
|
||||
refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
|
||||
getApiKey(credentials: OAuthCredentials): string;
|
||||
modifyModels?(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[];
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
## Model Definition Reference
|
||||
|
||||
```typescript
|
||||
interface ProviderModelConfig {
|
||||
/** Model ID (e.g., "claude-sonnet-4-20250514"). */
|
||||
id: string;
|
||||
|
||||
/** Display name (e.g., "Claude 4 Sonnet"). */
|
||||
name: string;
|
||||
|
||||
/** API type override for this specific model. */
|
||||
api?: Api;
|
||||
|
||||
/** Whether the model supports extended thinking. */
|
||||
reasoning: boolean;
|
||||
|
||||
/** Supported input types. */
|
||||
input: ("text" | "image")[];
|
||||
|
||||
/** Cost per million tokens (for usage tracking). */
|
||||
cost: {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
};
|
||||
|
||||
/** Maximum context window size in tokens. */
|
||||
contextWindow: number;
|
||||
|
||||
/** Maximum output tokens. */
|
||||
maxTokens: number;
|
||||
|
||||
/** Custom headers for this specific model. */
|
||||
headers?: Record<string, string>;
|
||||
|
||||
/** OpenAI compatibility settings for openai-completions API. */
|
||||
compat?: {
|
||||
supportsStore?: boolean;
|
||||
supportsDeveloperRole?: boolean;
|
||||
supportsReasoningEffort?: boolean;
|
||||
supportsUsageInStreaming?: boolean;
|
||||
maxTokensField?: "max_completion_tokens" | "max_tokens";
|
||||
requiresToolResultName?: boolean;
|
||||
requiresAssistantAfterToolResult?: boolean;
|
||||
requiresThinkingAsText?: boolean;
|
||||
requiresMistralToolIds?: boolean;
|
||||
thinkingFormat?: "openai" | "zai";
|
||||
};
|
||||
}
|
||||
```
|
||||
|
|
@ -1206,6 +1206,9 @@ pi.registerProvider("corporate-ai", {
|
|||
- `authHeader` - If true, adds `Authorization: Bearer` header automatically.
|
||||
- `models` - Array of model definitions. If provided, replaces all existing models for this provider.
|
||||
- `oauth` - OAuth provider config for `/login` support. When provided, the provider appears in the login menu.
|
||||
- `streamSimple` - Custom streaming implementation for non-standard APIs.
|
||||
|
||||
See [custom-provider.md](custom-provider.md) for advanced topics: custom streaming APIs, OAuth details, model definition reference.
|
||||
|
||||
## State Management
|
||||
|
||||
|
|
|
|||
601
packages/coding-agent/examples/extensions/custom-provider.ts
Normal file
601
packages/coding-agent/examples/extensions/custom-provider.ts
Normal file
|
|
@ -0,0 +1,601 @@
|
|||
/**
|
||||
* Custom Provider Example
|
||||
*
|
||||
* Demonstrates registering a custom provider with:
|
||||
* - Custom API identifier ("custom-anthropic-api")
|
||||
* - Custom streamSimple implementation
|
||||
* - OAuth support for /login
|
||||
* - API key support via environment variable
|
||||
* - Two model definitions
|
||||
*
|
||||
* Usage:
|
||||
* # With OAuth (run /login custom-anthropic first)
|
||||
* pi -e ./custom-provider.ts
|
||||
*
|
||||
* # With API key
|
||||
* CUSTOM_ANTHROPIC_API_KEY=sk-ant-... pi -e ./custom-provider.ts
|
||||
*
|
||||
* Then use /model to select custom-anthropic/claude-sonnet-4-5
|
||||
*/
|
||||
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import type { ContentBlockParam, MessageCreateParamsStreaming } from "@anthropic-ai/sdk/resources/messages.js";
|
||||
import {
|
||||
type Api,
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEventStream,
|
||||
type Context,
|
||||
calculateCost,
|
||||
createAssistantMessageEventStream,
|
||||
type ImageContent,
|
||||
type Message,
|
||||
type Model,
|
||||
type OAuthCredentials,
|
||||
type OAuthLoginCallbacks,
|
||||
type SimpleStreamOptions,
|
||||
type StopReason,
|
||||
type TextContent,
|
||||
type ThinkingContent,
|
||||
type Tool,
|
||||
type ToolCall,
|
||||
type ToolResultMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
|
||||
|
||||
// =============================================================================
|
||||
// OAuth Implementation (copied from packages/ai/src/utils/oauth/anthropic.ts)
|
||||
// =============================================================================
|
||||
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
|
||||
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
|
||||
const TOKEN_URL = "https://console.anthropic.com/v1/oauth/token";
|
||||
const REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback";
|
||||
const SCOPES = "org:create_api_key user:profile user:inference";
|
||||
|
||||
async function generatePKCE(): Promise<{ verifier: string; challenge: string }> {
|
||||
const array = new Uint8Array(32);
|
||||
crypto.getRandomValues(array);
|
||||
const verifier = btoa(String.fromCharCode(...array))
|
||||
.replace(/\+/g, "-")
|
||||
.replace(/\//g, "_")
|
||||
.replace(/=+$/, "");
|
||||
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(verifier);
|
||||
const hash = await crypto.subtle.digest("SHA-256", data);
|
||||
const challenge = btoa(String.fromCharCode(...new Uint8Array(hash)))
|
||||
.replace(/\+/g, "-")
|
||||
.replace(/\//g, "_")
|
||||
.replace(/=+$/, "");
|
||||
|
||||
return { verifier, challenge };
|
||||
}
|
||||
|
||||
async function loginAnthropic(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
|
||||
const authParams = new URLSearchParams({
|
||||
code: "true",
|
||||
client_id: CLIENT_ID,
|
||||
response_type: "code",
|
||||
redirect_uri: REDIRECT_URI,
|
||||
scope: SCOPES,
|
||||
code_challenge: challenge,
|
||||
code_challenge_method: "S256",
|
||||
state: verifier,
|
||||
});
|
||||
|
||||
callbacks.onAuth({ url: `${AUTHORIZE_URL}?${authParams.toString()}` });
|
||||
|
||||
const authCode = await callbacks.onPrompt({ message: "Paste the authorization code:" });
|
||||
const [code, state] = authCode.split("#");
|
||||
|
||||
const tokenResponse = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
grant_type: "authorization_code",
|
||||
client_id: CLIENT_ID,
|
||||
code,
|
||||
state,
|
||||
redirect_uri: REDIRECT_URI,
|
||||
code_verifier: verifier,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!tokenResponse.ok) {
|
||||
throw new Error(`Token exchange failed: ${await tokenResponse.text()}`);
|
||||
}
|
||||
|
||||
const data = (await tokenResponse.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
expires_in: number;
|
||||
};
|
||||
|
||||
return {
|
||||
refresh: data.refresh_token,
|
||||
access: data.access_token,
|
||||
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
||||
};
|
||||
}
|
||||
|
||||
async function refreshAnthropicToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
grant_type: "refresh_token",
|
||||
client_id: CLIENT_ID,
|
||||
refresh_token: credentials.refresh,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Token refresh failed: ${await response.text()}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
expires_in: number;
|
||||
};
|
||||
|
||||
return {
|
||||
refresh: data.refresh_token,
|
||||
access: data.access_token,
|
||||
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
||||
};
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Streaming Implementation (simplified from packages/ai/src/providers/anthropic.ts)
|
||||
// =============================================================================
|
||||
|
||||
// Claude Code tool names for OAuth stealth mode
|
||||
const claudeCodeTools = [
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Bash",
|
||||
"Grep",
|
||||
"Glob",
|
||||
"AskUserQuestion",
|
||||
"TodoWrite",
|
||||
"WebFetch",
|
||||
"WebSearch",
|
||||
];
|
||||
const ccToolLookup = new Map(claudeCodeTools.map((t) => [t.toLowerCase(), t]));
|
||||
const toClaudeCodeName = (name: string) => ccToolLookup.get(name.toLowerCase()) ?? name;
|
||||
const fromClaudeCodeName = (name: string, tools?: Tool[]) => {
|
||||
const lowerName = name.toLowerCase();
|
||||
const matched = tools?.find((t) => t.name.toLowerCase() === lowerName);
|
||||
return matched?.name ?? name;
|
||||
};
|
||||
|
||||
function isOAuthToken(apiKey: string): boolean {
|
||||
return apiKey.includes("sk-ant-oat");
|
||||
}
|
||||
|
||||
function sanitizeSurrogates(text: string): string {
|
||||
return text.replace(/[\uD800-\uDFFF]/g, "\uFFFD");
|
||||
}
|
||||
|
||||
function convertContentBlocks(
|
||||
content: (TextContent | ImageContent)[],
|
||||
): string | Array<{ type: "text"; text: string } | { type: "image"; source: any }> {
|
||||
const hasImages = content.some((c) => c.type === "image");
|
||||
if (!hasImages) {
|
||||
return sanitizeSurrogates(content.map((c) => (c as TextContent).text).join("\n"));
|
||||
}
|
||||
|
||||
const blocks = content.map((block) => {
|
||||
if (block.type === "text") {
|
||||
return { type: "text" as const, text: sanitizeSurrogates(block.text) };
|
||||
}
|
||||
return {
|
||||
type: "image" as const,
|
||||
source: {
|
||||
type: "base64" as const,
|
||||
media_type: block.mimeType,
|
||||
data: block.data,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
if (!blocks.some((b) => b.type === "text")) {
|
||||
blocks.unshift({ type: "text" as const, text: "(see attached image)" });
|
||||
}
|
||||
|
||||
return blocks;
|
||||
}
|
||||
|
||||
function convertMessages(messages: Message[], isOAuth: boolean, _tools?: Tool[]): any[] {
|
||||
const params: any[] = [];
|
||||
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
const msg = messages[i];
|
||||
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
if (msg.content.trim()) {
|
||||
params.push({ role: "user", content: sanitizeSurrogates(msg.content) });
|
||||
}
|
||||
} else {
|
||||
const blocks: ContentBlockParam[] = msg.content.map((item) =>
|
||||
item.type === "text"
|
||||
? { type: "text" as const, text: sanitizeSurrogates(item.text) }
|
||||
: {
|
||||
type: "image" as const,
|
||||
source: { type: "base64" as const, media_type: item.mimeType as any, data: item.data },
|
||||
},
|
||||
);
|
||||
if (blocks.length > 0) {
|
||||
params.push({ role: "user", content: blocks });
|
||||
}
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const blocks: ContentBlockParam[] = [];
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text" && block.text.trim()) {
|
||||
blocks.push({ type: "text", text: sanitizeSurrogates(block.text) });
|
||||
} else if (block.type === "thinking" && block.thinking.trim()) {
|
||||
if ((block as ThinkingContent).thinkingSignature) {
|
||||
blocks.push({
|
||||
type: "thinking" as any,
|
||||
thinking: sanitizeSurrogates(block.thinking),
|
||||
signature: (block as ThinkingContent).thinkingSignature!,
|
||||
});
|
||||
} else {
|
||||
blocks.push({ type: "text", text: sanitizeSurrogates(block.thinking) });
|
||||
}
|
||||
} else if (block.type === "toolCall") {
|
||||
blocks.push({
|
||||
type: "tool_use",
|
||||
id: block.id,
|
||||
name: isOAuth ? toClaudeCodeName(block.name) : block.name,
|
||||
input: block.arguments,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (blocks.length > 0) {
|
||||
params.push({ role: "assistant", content: blocks });
|
||||
}
|
||||
} else if (msg.role === "toolResult") {
|
||||
const toolResults: any[] = [];
|
||||
toolResults.push({
|
||||
type: "tool_result",
|
||||
tool_use_id: msg.toolCallId,
|
||||
content: convertContentBlocks(msg.content),
|
||||
is_error: msg.isError,
|
||||
});
|
||||
|
||||
let j = i + 1;
|
||||
while (j < messages.length && messages[j].role === "toolResult") {
|
||||
const nextMsg = messages[j] as ToolResultMessage;
|
||||
toolResults.push({
|
||||
type: "tool_result",
|
||||
tool_use_id: nextMsg.toolCallId,
|
||||
content: convertContentBlocks(nextMsg.content),
|
||||
is_error: nextMsg.isError,
|
||||
});
|
||||
j++;
|
||||
}
|
||||
i = j - 1;
|
||||
params.push({ role: "user", content: toolResults });
|
||||
}
|
||||
}
|
||||
|
||||
// Add cache control to last user message
|
||||
if (params.length > 0) {
|
||||
const last = params[params.length - 1];
|
||||
if (last.role === "user" && Array.isArray(last.content)) {
|
||||
const lastBlock = last.content[last.content.length - 1];
|
||||
if (lastBlock) {
|
||||
lastBlock.cache_control = { type: "ephemeral" };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
function convertTools(tools: Tool[], isOAuth: boolean): any[] {
|
||||
return tools.map((tool) => ({
|
||||
name: isOAuth ? toClaudeCodeName(tool.name) : tool.name,
|
||||
description: tool.description,
|
||||
input_schema: {
|
||||
type: "object",
|
||||
properties: (tool.parameters as any).properties || {},
|
||||
required: (tool.parameters as any).required || [],
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
function mapStopReason(reason: string): StopReason {
|
||||
switch (reason) {
|
||||
case "end_turn":
|
||||
case "pause_turn":
|
||||
case "stop_sequence":
|
||||
return "stop";
|
||||
case "max_tokens":
|
||||
return "length";
|
||||
case "tool_use":
|
||||
return "toolUse";
|
||||
default:
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
|
||||
function streamCustomAnthropic(
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream {
|
||||
const stream = createAssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
const apiKey = options?.apiKey ?? "";
|
||||
const isOAuth = isOAuthToken(apiKey);
|
||||
|
||||
// Configure client based on auth type
|
||||
const betaFeatures = ["fine-grained-tool-streaming-2025-05-14", "interleaved-thinking-2025-05-14"];
|
||||
const clientOptions: any = {
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
};
|
||||
|
||||
if (isOAuth) {
|
||||
clientOptions.apiKey = null;
|
||||
clientOptions.authToken = apiKey;
|
||||
clientOptions.defaultHeaders = {
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": `claude-code-20250219,oauth-2025-04-20,${betaFeatures.join(",")}`,
|
||||
"user-agent": "claude-cli/2.1.2 (external, cli)",
|
||||
"x-app": "cli",
|
||||
};
|
||||
} else {
|
||||
clientOptions.apiKey = apiKey;
|
||||
clientOptions.defaultHeaders = {
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": betaFeatures.join(","),
|
||||
};
|
||||
}
|
||||
|
||||
const client = new Anthropic(clientOptions);
|
||||
|
||||
// Build request params
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
messages: convertMessages(context.messages, isOAuth, context.tools),
|
||||
max_tokens: options?.maxTokens || Math.floor(model.maxTokens / 3),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// System prompt with Claude Code identity for OAuth
|
||||
if (isOAuth) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
cache_control: { type: "ephemeral" },
|
||||
},
|
||||
];
|
||||
if (context.systemPrompt) {
|
||||
params.system.push({
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(context.systemPrompt),
|
||||
cache_control: { type: "ephemeral" },
|
||||
});
|
||||
}
|
||||
} else if (context.systemPrompt) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(context.systemPrompt),
|
||||
cache_control: { type: "ephemeral" },
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools, isOAuth);
|
||||
}
|
||||
|
||||
// Handle thinking/reasoning
|
||||
if (options?.reasoning && model.reasoning) {
|
||||
const defaultBudgets: Record<string, number> = {
|
||||
minimal: 1024,
|
||||
low: 4096,
|
||||
medium: 10240,
|
||||
high: 20480,
|
||||
};
|
||||
const customBudget = options.thinkingBudgets?.[options.reasoning as keyof typeof options.thinkingBudgets];
|
||||
params.thinking = {
|
||||
type: "enabled",
|
||||
budget_tokens: customBudget ?? defaultBudgets[options.reasoning] ?? 10240,
|
||||
};
|
||||
}
|
||||
|
||||
const anthropicStream = client.messages.stream({ ...params }, { signal: options?.signal });
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
type Block = (ThinkingContent | TextContent | (ToolCall & { partialJson: string })) & { index: number };
|
||||
const blocks = output.content as Block[];
|
||||
|
||||
for await (const event of anthropicStream) {
|
||||
if (event.type === "message_start") {
|
||||
output.usage.input = event.message.usage.input_tokens || 0;
|
||||
output.usage.output = event.message.usage.output_tokens || 0;
|
||||
output.usage.cacheRead = (event.message.usage as any).cache_read_input_tokens || 0;
|
||||
output.usage.cacheWrite = (event.message.usage as any).cache_creation_input_tokens || 0;
|
||||
output.usage.totalTokens =
|
||||
output.usage.input + output.usage.output + output.usage.cacheRead + output.usage.cacheWrite;
|
||||
calculateCost(model, output.usage);
|
||||
} else if (event.type === "content_block_start") {
|
||||
if (event.content_block.type === "text") {
|
||||
output.content.push({ type: "text", text: "", index: event.index } as any);
|
||||
stream.push({ type: "text_start", contentIndex: output.content.length - 1, partial: output });
|
||||
} else if (event.content_block.type === "thinking") {
|
||||
output.content.push({
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: "",
|
||||
index: event.index,
|
||||
} as any);
|
||||
stream.push({ type: "thinking_start", contentIndex: output.content.length - 1, partial: output });
|
||||
} else if (event.content_block.type === "tool_use") {
|
||||
output.content.push({
|
||||
type: "toolCall",
|
||||
id: event.content_block.id,
|
||||
name: isOAuth
|
||||
? fromClaudeCodeName(event.content_block.name, context.tools)
|
||||
: event.content_block.name,
|
||||
arguments: {},
|
||||
partialJson: "",
|
||||
index: event.index,
|
||||
} as any);
|
||||
stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output });
|
||||
}
|
||||
} else if (event.type === "content_block_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (!block) continue;
|
||||
|
||||
if (event.delta.type === "text_delta" && block.type === "text") {
|
||||
block.text += event.delta.text;
|
||||
stream.push({ type: "text_delta", contentIndex: index, delta: event.delta.text, partial: output });
|
||||
} else if (event.delta.type === "thinking_delta" && block.type === "thinking") {
|
||||
block.thinking += event.delta.thinking;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: index,
|
||||
delta: event.delta.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (event.delta.type === "input_json_delta" && block.type === "toolCall") {
|
||||
(block as any).partialJson += event.delta.partial_json;
|
||||
try {
|
||||
block.arguments = JSON.parse((block as any).partialJson);
|
||||
} catch {}
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: index,
|
||||
delta: event.delta.partial_json,
|
||||
partial: output,
|
||||
});
|
||||
} else if (event.delta.type === "signature_delta" && block.type === "thinking") {
|
||||
block.thinkingSignature = (block.thinkingSignature || "") + (event.delta as any).signature;
|
||||
}
|
||||
} else if (event.type === "content_block_stop") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (!block) continue;
|
||||
|
||||
delete (block as any).index;
|
||||
if (block.type === "text") {
|
||||
stream.push({ type: "text_end", contentIndex: index, content: block.text, partial: output });
|
||||
} else if (block.type === "thinking") {
|
||||
stream.push({ type: "thinking_end", contentIndex: index, content: block.thinking, partial: output });
|
||||
} else if (block.type === "toolCall") {
|
||||
try {
|
||||
block.arguments = JSON.parse((block as any).partialJson);
|
||||
} catch {}
|
||||
delete (block as any).partialJson;
|
||||
stream.push({ type: "toolcall_end", contentIndex: index, toolCall: block, partial: output });
|
||||
}
|
||||
} else if (event.type === "message_delta") {
|
||||
if ((event.delta as any).stop_reason) {
|
||||
output.stopReason = mapStopReason((event.delta as any).stop_reason);
|
||||
}
|
||||
output.usage.input = (event.usage as any).input_tokens || 0;
|
||||
output.usage.output = (event.usage as any).output_tokens || 0;
|
||||
output.usage.cacheRead = (event.usage as any).cache_read_input_tokens || 0;
|
||||
output.usage.cacheWrite = (event.usage as any).cache_creation_input_tokens || 0;
|
||||
output.usage.totalTokens =
|
||||
output.usage.input + output.usage.output + output.usage.cacheRead + output.usage.cacheWrite;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason as "stop" | "length" | "toolUse", message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) delete (block as any).index;
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Extension Entry Point
|
||||
// =============================================================================
|
||||
|
||||
export default function (pi: ExtensionAPI) {
|
||||
pi.registerProvider("custom-anthropic", {
|
||||
baseUrl: "https://api.anthropic.com",
|
||||
apiKey: "CUSTOM_ANTHROPIC_API_KEY",
|
||||
api: "custom-anthropic-api",
|
||||
|
||||
models: [
|
||||
{
|
||||
id: "claude-opus-4-5",
|
||||
name: "Claude Opus 4.5 (Custom)",
|
||||
reasoning: true,
|
||||
input: ["text", "image"],
|
||||
cost: { input: 5, output: 25, cacheRead: 0.5, cacheWrite: 6.25 },
|
||||
contextWindow: 200000,
|
||||
maxTokens: 64000,
|
||||
},
|
||||
{
|
||||
id: "claude-sonnet-4-5",
|
||||
name: "Claude Sonnet 4.5 (Custom)",
|
||||
reasoning: true,
|
||||
input: ["text", "image"],
|
||||
cost: { input: 3, output: 15, cacheRead: 0.3, cacheWrite: 3.75 },
|
||||
contextWindow: 200000,
|
||||
maxTokens: 64000,
|
||||
},
|
||||
],
|
||||
|
||||
oauth: {
|
||||
name: "Custom Anthropic (Claude Pro/Max)",
|
||||
login: loginAnthropic,
|
||||
refreshToken: refreshAnthropicToken,
|
||||
getApiKey: (cred) => cred.access,
|
||||
},
|
||||
|
||||
streamSimple: streamCustomAnthropic,
|
||||
});
|
||||
}
|
||||
|
|
@ -23,7 +23,7 @@ import type {
|
|||
ThinkingLevel,
|
||||
} from "@mariozechner/pi-agent-core";
|
||||
import type { AssistantMessage, ImageContent, Message, Model, TextContent } from "@mariozechner/pi-ai";
|
||||
import { isContextOverflow, modelsAreEqual, supportsXhigh } from "@mariozechner/pi-ai";
|
||||
import { isContextOverflow, modelsAreEqual, resetApiProviders, supportsXhigh } from "@mariozechner/pi-ai";
|
||||
import { getAuthPath } from "../config.js";
|
||||
import { theme } from "../modes/interactive/theme/theme.js";
|
||||
import { stripFrontmatter } from "../utils/frontmatter.js";
|
||||
|
|
@ -1832,6 +1832,7 @@ export class AgentSession {
|
|||
async reload(): Promise<void> {
|
||||
const previousFlagValues = this._extensionRunner?.getFlagValues();
|
||||
await this._extensionRunner?.emit({ type: "session_shutdown" });
|
||||
resetApiProviders();
|
||||
await this._resourceLoader.reload();
|
||||
this._buildRuntime({
|
||||
activeToolNames: this.getActiveToolNames(),
|
||||
|
|
|
|||
|
|
@ -16,10 +16,13 @@ import type {
|
|||
} from "@mariozechner/pi-agent-core";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
ImageContent,
|
||||
Model,
|
||||
OAuthCredentials,
|
||||
OAuthLoginCallbacks,
|
||||
SimpleStreamOptions,
|
||||
TextContent,
|
||||
ToolResultMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
|
|
@ -872,6 +875,7 @@ export interface ExtensionAPI {
|
|||
* If `models` is provided: replaces all existing models for this provider.
|
||||
* If only `baseUrl` is provided: overrides the URL for existing models.
|
||||
* If `oauth` is provided: registers OAuth provider for /login support.
|
||||
* If `streamSimple` is provided: registers a custom API stream handler.
|
||||
*
|
||||
* @example
|
||||
* // Register a new provider with custom models
|
||||
|
|
@ -930,6 +934,8 @@ export interface ProviderConfig {
|
|||
apiKey?: string;
|
||||
/** API type. Required at provider or model level when defining models. */
|
||||
api?: Api;
|
||||
/** Optional streamSimple handler for custom APIs. */
|
||||
streamSimple?: (model: Model<Api>, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream;
|
||||
/** Custom headers to include in requests. */
|
||||
headers?: Record<string, string>;
|
||||
/** If true, adds Authorization: Bearer header with the resolved API key. */
|
||||
|
|
|
|||
|
|
@ -4,12 +4,16 @@
|
|||
|
||||
import {
|
||||
type Api,
|
||||
type AssistantMessageEventStream,
|
||||
type Context,
|
||||
getModels,
|
||||
getProviders,
|
||||
type KnownProvider,
|
||||
type Model,
|
||||
type OAuthProviderInterface,
|
||||
registerApiProvider,
|
||||
registerOAuthProvider,
|
||||
type SimpleStreamOptions,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import AjvModule from "ajv";
|
||||
|
|
@ -45,17 +49,7 @@ const OpenAICompatSchema = Type.Union([OpenAICompletionsCompatSchema, OpenAIResp
|
|||
const ModelDefinitionSchema = Type.Object({
|
||||
id: Type.String({ minLength: 1 }),
|
||||
name: Type.String({ minLength: 1 }),
|
||||
api: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai-completions"),
|
||||
Type.Literal("openai-responses"),
|
||||
Type.Literal("azure-openai-responses"),
|
||||
Type.Literal("openai-codex-responses"),
|
||||
Type.Literal("anthropic-messages"),
|
||||
Type.Literal("google-generative-ai"),
|
||||
Type.Literal("bedrock-converse-stream"),
|
||||
]),
|
||||
),
|
||||
api: Type.Optional(Type.String({ minLength: 1 })),
|
||||
reasoning: Type.Boolean(),
|
||||
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
||||
cost: Type.Object({
|
||||
|
|
@ -73,17 +67,7 @@ const ModelDefinitionSchema = Type.Object({
|
|||
const ProviderConfigSchema = Type.Object({
|
||||
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
|
||||
apiKey: Type.Optional(Type.String({ minLength: 1 })),
|
||||
api: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai-completions"),
|
||||
Type.Literal("openai-responses"),
|
||||
Type.Literal("azure-openai-responses"),
|
||||
Type.Literal("openai-codex-responses"),
|
||||
Type.Literal("anthropic-messages"),
|
||||
Type.Literal("google-generative-ai"),
|
||||
Type.Literal("bedrock-converse-stream"),
|
||||
]),
|
||||
),
|
||||
api: Type.Optional(Type.String({ minLength: 1 })),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
authHeader: Type.Optional(Type.Boolean()),
|
||||
models: Type.Optional(Type.Array(ModelDefinitionSchema)),
|
||||
|
|
@ -482,6 +466,18 @@ export class ModelRegistry {
|
|||
registerOAuthProvider(oauthProvider);
|
||||
}
|
||||
|
||||
if (config.streamSimple) {
|
||||
if (!config.api) {
|
||||
throw new Error(`Provider ${providerName}: "api" is required when registering streamSimple.`);
|
||||
}
|
||||
const streamSimple = config.streamSimple;
|
||||
registerApiProvider({
|
||||
api: config.api,
|
||||
stream: (model, context, options) => streamSimple(model, context, options as SimpleStreamOptions),
|
||||
streamSimple,
|
||||
});
|
||||
}
|
||||
|
||||
// Store API key for auth resolution
|
||||
if (config.apiKey) {
|
||||
this.customProviderApiKeys.set(providerName, config.apiKey);
|
||||
|
|
@ -556,6 +552,7 @@ export interface ProviderConfigInput {
|
|||
baseUrl?: string;
|
||||
apiKey?: string;
|
||||
api?: Api;
|
||||
streamSimple?: (model: Model<Api>, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream;
|
||||
headers?: Record<string, string>;
|
||||
authHeader?: boolean;
|
||||
/** OAuth provider for /login support */
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue