diff --git a/packages/coding-agent/docs/custom-provider.md b/packages/coding-agent/docs/custom-provider.md index acb6c78c..4cf12609 100644 --- a/packages/coding-agent/docs/custom-provider.md +++ b/packages/coding-agent/docs/custom-provider.md @@ -14,6 +14,7 @@ Extensions can register custom model providers via `pi.registerProvider()`. This - [Register New Provider](#register-new-provider) - [OAuth Support](#oauth-support) - [Custom Streaming API](#custom-streaming-api) +- [Testing Your Implementation](#testing-your-implementation) - [Config Reference](#config-reference) - [Model Definition Reference](#model-definition-reference) @@ -99,15 +100,6 @@ pi.registerProvider("my-llm", { }, contextWindow: 200000, maxTokens: 16384 - }, - { - id: "my-llm-small", - name: "My LLM Small", - reasoning: false, - input: ["text"], - cost: { input: 0.25, output: 1.25, cacheRead: 0, cacheWrite: 0 }, - contextWindow: 128000, - maxTokens: 8192 } ] }); @@ -171,17 +163,7 @@ import type { OAuthCredentials, OAuthLoginCallbacks } from "@mariozechner/pi-ai" pi.registerProvider("corporate-ai", { baseUrl: "https://ai.corp.com/v1", api: "openai-responses", - models: [ - { - id: "corp-claude", - name: "Corporate Claude", - reasoning: true, - input: ["text", "image"], - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, - contextWindow: 200000, - maxTokens: 16384 - } - ], + models: [...], oauth: { name: "Corporate AI (SSO)", @@ -223,7 +205,6 @@ pi.registerProvider("corporate-ai", { // Optional: modify models based on user's subscription modifyModels(models, credentials) { - // e.g., update baseUrl based on user's region const region = decodeRegionFromToken(credentials.access); return models.map(m => ({ ...m, @@ -267,193 +248,203 @@ interface OAuthCredentials { ## Custom Streaming API -For providers with non-standard APIs, implement `streamSimple`: +For providers with non-standard APIs, implement `streamSimple`. 
Study the existing provider implementations before writing your own: + +**Reference implementations:** +- [anthropic.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/anthropic.ts) - Anthropic Messages API +- [openai-completions.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/openai-completions.ts) - OpenAI Chat Completions +- [openai-responses.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/openai-responses.ts) - OpenAI Responses API +- [google.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/google.ts) - Google Generative AI +- [amazon-bedrock.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/amazon-bedrock.ts) - AWS Bedrock + +### Stream Pattern + +All providers follow the same pattern: ```typescript -import type { - AssistantMessageEventStream, - Context, - Model, - SimpleStreamOptions, - Api +import { + type AssistantMessage, + type AssistantMessageEventStream, + type Context, + type Model, + type SimpleStreamOptions, + calculateCost, + createAssistantMessageEventStream, } from "@mariozechner/pi-ai"; -import { createAssistantMessageEventStream } from "@mariozechner/pi-ai"; -pi.registerProvider("custom-llm", { - baseUrl: "https://api.custom-llm.com", - apiKey: "CUSTOM_LLM_KEY", - api: "custom-llm-api", // your custom API identifier - models: [ - { - id: "custom-model", - name: "Custom Model", - reasoning: false, - input: ["text"], - cost: { input: 1.0, output: 2.0, cacheRead: 0, cacheWrite: 0 }, - contextWindow: 32000, - maxTokens: 4096 - } - ], +function streamMyProvider( + model: Model, + context: Context, + options?: SimpleStreamOptions +): AssistantMessageEventStream { + const stream = createAssistantMessageEventStream(); - streamSimple( - model: Model, - context: Context, - options?: SimpleStreamOptions - ): AssistantMessageEventStream { - return createAssistantMessageEventStream(async function* (signal) { - // Convert context to your API format - const messages = context.messages.map(m => ({ - role: m.role, - content: typeof m.content === "string" - ? m.content - : m.content.filter(c => c.type === "text").map(c => c.text).join("") - })); + (async () => { + // Initialize output message + const output: AssistantMessage = { + role: "assistant", + content: [], + api: model.api, + provider: model.provider, + model: model.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: Date.now(), + }; - // Make streaming request - const response = await fetch(`${model.baseUrl}/chat`, { - method: "POST", - headers: { - "Authorization": `Bearer ${options?.apiKey}`, - "Content-Type": "application/json" - }, - body: JSON.stringify({ - model: model.id, - messages, - stream: true - }), - signal + try { + // Push start event + stream.push({ type: "start", partial: output }); + + // Make API request and process response... + // Push content events as they arrive... + + // Push done event + stream.push({ + type: "done", + reason: output.stopReason as "stop" | "length" | "toolUse", + message: output }); + stream.end(); + } catch (error) { + output.stopReason = options?.signal?.aborted ? "aborted" : "error"; + output.errorMessage = error instanceof Error ? 
error.message : String(error);
+      stream.push({ type: "error", reason: output.stopReason, error: output });
+      stream.end();
+    }
+  })();
 
-      if (!response.ok) {
-        throw new Error(`API error: ${response.status}`);
-      }
-
-      // Yield start event
-      yield { type: "start" };
-
-      // Parse SSE stream
-      const reader = response.body!.getReader();
-      const decoder = new TextDecoder();
-      let buffer = "";
-      let contentIndex = 0;
-      let textStarted = false;
-
-      while (true) {
-        const { done, value } = await reader.read();
-        if (done) break;
-
-        buffer += decoder.decode(value, { stream: true });
-        const lines = buffer.split("\n");
-        buffer = lines.pop() ?? "";
-
-        for (const line of lines) {
-          if (!line.startsWith("data: ")) continue;
-          const data = line.slice(6);
-          if (data === "[DONE]") continue;
-
-          const chunk = JSON.parse(data);
-          const delta = chunk.choices?.[0]?.delta?.content;
-
-          if (delta) {
-            if (!textStarted) {
-              yield { type: "text_start", contentIndex };
-              textStarted = true;
-            }
-            yield { type: "text_delta", contentIndex, delta };
-          }
-        }
-      }
-
-      if (textStarted) {
-        yield { type: "text_end", contentIndex };
-      }
-
-      // Yield usage if available
-      yield {
-        type: "usage",
-        usage: {
-          input: 0, // fill from response if available
-          output: 0,
-          cacheRead: 0,
-          cacheWrite: 0,
-          totalTokens: 0,
-          cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }
-        }
-      };
-
-      // Yield done
-      yield { type: "done", reason: "stop" };
-    });
-  }
-});
+  return stream;
+}
 ```
 
 ### Event Types
 
-Your generator must yield events in this order:
+Push events via `stream.push()` in this order:
 
-1. `{ type: "start" }` - Stream started
-2. Content events (repeatable, in order):
-   - `{ type: "text_start", contentIndex }` - Text block started
-   - `{ type: "text_delta", contentIndex, delta }` - Text chunk
-   - `{ type: "text_end", contentIndex }` - Text block ended
-   - `{ type: "thinking_start", contentIndex }` - Thinking block started
-   - `{ type: "thinking_delta", contentIndex, delta }` - Thinking chunk
-   - `{ type: "thinking_end", contentIndex }` - Thinking block ended
-   - `{ type: "toolcall_start", contentIndex }` - Tool call started
-   - `{ type: "toolcall_delta", contentIndex, delta }` - Tool call JSON chunk
-   - `{ type: "toolcall_end", contentIndex, toolCall }` - Tool call ended
-3. `{ type: "usage", usage }` - Token usage (optional but recommended)
-4. `{ type: "done", reason }` or `{ type: "error", error }` - Stream ended
+1. `{ type: "start", partial: output }` - Stream started
 
-### Reasoning Support
+2. Content events (repeatable, track `contentIndex` for each block):
+   - `{ type: "text_start", contentIndex, partial }` - Text block started
+   - `{ type: "text_delta", contentIndex, delta, partial }` - Text chunk
+   - `{ type: "text_end", contentIndex, content, partial }` - Text block ended
+   - `{ type: "thinking_start", contentIndex, partial }` - Thinking started
+   - `{ type: "thinking_delta", contentIndex, delta, partial }` - Thinking chunk
+   - `{ type: "thinking_end", contentIndex, content, partial }` - Thinking ended
+   - `{ type: "toolcall_start", contentIndex, partial }` - Tool call started
+   - `{ type: "toolcall_delta", contentIndex, delta, partial }` - Tool call JSON chunk
+   - `{ type: "toolcall_end", contentIndex, toolCall, partial }` - Tool call ended
 
-For models with extended thinking, yield thinking events:
+3. `{ type: "done", reason, message }` or `{ type: "error", reason, error }` - Stream ended
+
+The `partial` field in each event contains the current `AssistantMessage` state. Update `output.content` as you receive data, then include `output` as the `partial`.
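+
+For example, a response with a single text block maps to this sequence (a minimal sketch reusing `stream` and `output` from the pattern above; a real provider pushes one `text_delta` per chunk received from its API):
+
+```typescript
+stream.push({ type: "start", partial: output });
+
+// One text content block at index 0
+const block = { type: "text" as const, text: "" };
+output.content.push(block);
+stream.push({ type: "text_start", contentIndex: 0, partial: output });
+
+// Repeated once per delta
+block.text += "Hello!";
+stream.push({ type: "text_delta", contentIndex: 0, delta: "Hello!", partial: output });
+
+stream.push({ type: "text_end", contentIndex: 0, content: block.text, partial: output });
+stream.push({ type: "done", reason: "stop", message: output });
+stream.end();
+```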
+
+### Content Blocks
+
+Add content blocks to `output.content` as they arrive:
+
+```typescript
-if (chunk.thinking) {
-  if (!thinkingStarted) {
-    yield { type: "thinking_start", contentIndex: thinkingIndex };
-    thinkingStarted = true;
-  }
-  yield { type: "thinking_delta", contentIndex: thinkingIndex, delta: chunk.thinking };
-}
+// Text block: append the block, then announce it
+const block = { type: "text" as const, text: "" };
+const contentIndex = output.content.push(block) - 1;
+stream.push({ type: "text_start", contentIndex, partial: output });
+
+// As each text delta arrives from your API
+block.text += delta;
+stream.push({ type: "text_delta", contentIndex, delta, partial: output });
+
+// When the block completes
+stream.push({ type: "text_end", contentIndex, content: block.text, partial: output });
 ```
 
 ### Tool Calls
 
-For function calling support, yield tool call events:
+Tool calls require accumulating the arguments JSON across deltas and parsing it best-effort:
 
 ```typescript
-if (chunk.tool_calls) {
-  for (const tc of chunk.tool_calls) {
-    if (tc.index !== currentToolIndex) {
-      if (currentToolIndex >= 0) {
-        yield {
-          type: "toolcall_end",
-          contentIndex: currentToolIndex,
-          toolCall: {
-            type: "toolCall",
-            id: currentToolId,
-            name: currentToolName,
-            arguments: JSON.parse(currentToolArgs)
-          }
-        };
-      }
-      currentToolIndex = tc.index;
-      currentToolId = tc.id;
-      currentToolName = tc.function.name;
-      currentToolArgs = "";
-      yield { type: "toolcall_start", contentIndex: tc.index };
-    }
-    if (tc.function.arguments) {
-      currentToolArgs += tc.function.arguments;
-      yield { type: "toolcall_delta", contentIndex: tc.index, delta: tc.function.arguments };
-    }
-  }
-}
+// Start tool call: append the block, then announce it
+const block = { type: "toolCall" as const, id: toolCallId, name: toolName, arguments: {} };
+const contentIndex = output.content.push(block) - 1;
+let partialJson = "";
+stream.push({ type: "toolcall_start", contentIndex, partial: output });
+
+// For each arguments delta: accumulate, then parse best-effort
+// (JSON.parse throws on incomplete JSON, so keep the last good parse)
+partialJson += jsonDelta;
+try {
+  block.arguments = JSON.parse(partialJson);
+} catch {}
+stream.push({ type: "toolcall_delta", contentIndex, delta: jsonDelta, partial: output });
+
+// Complete: emit the finished tool call
+stream.push({
+  type: "toolcall_end",
+  contentIndex,
+  toolCall: block,
+  partial: output
+});
 ```
+
+### Usage and Cost
+
+Update usage from the API response, then calculate cost:
+
+```typescript
+// Field names (input_tokens, etc.) depend on your API's response shape
+output.usage.input = response.usage.input_tokens;
+output.usage.output = response.usage.output_tokens;
+output.usage.cacheRead = response.usage.cache_read_tokens ?? 0;
+output.usage.cacheWrite = response.usage.cache_write_tokens ?? 0;
+output.usage.totalTokens = output.usage.input + output.usage.output +
+  output.usage.cacheRead + output.usage.cacheWrite;
+calculateCost(model, output.usage);
+```
+
+### Registration
+
+Register your stream function:
+
+```typescript
+pi.registerProvider("my-provider", {
+  baseUrl: "https://api.example.com",
+  apiKey: "MY_API_KEY",
+  api: "my-custom-api",
+  models: [...],
+  streamSimple: streamMyProvider
+});
+```
+
+## Testing Your Implementation
+
+Test your provider against the same test suites used by built-in providers.
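+
+Before adopting the full suites, a quick smoke test can catch event-ordering mistakes. A sketch, assuming a vitest-style runner, an async-iterable event stream, and a `myModel` placeholder for one of your model definitions:
+
+```typescript
+import { expect, test } from "vitest";
+
+test("emits start first and done last", async () => {
+  const stream = streamMyProvider(myModel, {
+    messages: [{ role: "user", content: "Say hi" }]
+  });
+  const events: string[] = [];
+  for await (const event of stream) {
+    events.push(event.type);
+  }
+  expect(events[0]).toBe("start");
+  expect(events.at(-1)).toBe("done");
+});
+```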
+
+For fuller coverage, copy and adapt the test files in [packages/ai/test/](https://github.com/badlogic/pi-mono/tree/main/packages/ai/test):
+
+| Test | Purpose |
+|------|---------|
+| `stream.test.ts` | Basic streaming, text output |
+| `tokens.test.ts` | Token counting and usage |
+| `abort.test.ts` | AbortSignal handling |
+| `empty.test.ts` | Empty/minimal responses |
+| `context-overflow.test.ts` | Context window limits |
+| `image-limits.test.ts` | Image input handling |
+| `unicode-surrogate.test.ts` | Unicode edge cases |
+| `tool-call-without-result.test.ts` | Tool call edge cases |
+| `image-tool-result.test.ts` | Images in tool results |
+| `total-tokens.test.ts` | Total token calculation |
+| `cross-provider-handoff.test.ts` | Context handoff between providers |
+
+Run each suite with your own provider/model pairs to verify compatibility.
 
 ## Config Reference
 
 ```typescript