docs: rewrite custom-provider.md with actual stream pattern

- Document stream.push/end pattern instead of yield
- Link to existing provider implementations on GitHub
- Add testing section with list of test files to run
- Include proper content block and tool call examples
Mario Zechner 2026-01-24 23:18:36 +01:00
parent 177c694406
commit d4bd1a956b


@@ -14,6 +14,7 @@ Extensions can register custom model providers via `pi.registerProvider()`. This
- [Register New Provider](#register-new-provider)
- [OAuth Support](#oauth-support)
- [Custom Streaming API](#custom-streaming-api)
- [Testing Your Implementation](#testing-your-implementation)
- [Config Reference](#config-reference)
- [Model Definition Reference](#model-definition-reference)
@@ -99,15 +100,6 @@ pi.registerProvider("my-llm", {
},
contextWindow: 200000,
maxTokens: 16384
    }
]
});
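
For reference, a complete minimal registration looks like this (a sketch: the provider id, endpoint URL, and pricing are illustrative placeholders; `api` must either name a built-in adapter or be a custom identifier paired with a `streamSimple` implementation as described below):

```typescript
// Sketch of a complete minimal registration — ids, endpoint, and pricing are illustrative.
pi.registerProvider("my-llm", {
  baseUrl: "https://api.my-llm.com/v1",  // assumption: your endpoint
  apiKey: "MY_LLM_KEY",                  // key lookup — see Config Reference
  api: "openai-completions",             // built-in adapter, or a custom id + streamSimple
  models: [
    {
      id: "my-llm-large",
      name: "My LLM Large",
      reasoning: true,
      input: ["text", "image"],
      cost: { input: 1.0, output: 5.0, cacheRead: 0, cacheWrite: 0 },
      contextWindow: 200000,
      maxTokens: 16384
    }
  ]
});
```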
@@ -171,17 +163,7 @@ import type { OAuthCredentials, OAuthLoginCallbacks } from "@mariozechner/pi-ai"
pi.registerProvider("corporate-ai", {
baseUrl: "https://ai.corp.com/v1",
api: "openai-responses",
  models: [...],
oauth: {
name: "Corporate AI (SSO)",
@@ -223,7 +205,6 @@ pi.registerProvider("corporate-ai", {
// Optional: modify models based on user's subscription
modifyModels(models, credentials) {
// e.g., update baseUrl based on user's region
const region = decodeRegionFromToken(credentials.access);
return models.map(m => ({
...m,
@@ -267,193 +248,203 @@ interface OAuthCredentials {
## Custom Streaming API
For providers with non-standard APIs, implement `streamSimple`. Study the existing provider implementations before writing your own:
**Reference implementations:**
- [anthropic.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/anthropic.ts) - Anthropic Messages API
- [openai-completions.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/openai-completions.ts) - OpenAI Chat Completions
- [openai-responses.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/openai-responses.ts) - OpenAI Responses API
- [google.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/google.ts) - Google Generative AI
- [amazon-bedrock.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/amazon-bedrock.ts) - AWS Bedrock
### Stream Pattern
All providers follow the same pattern:
```typescript
import {
  type AssistantMessage,
  type AssistantMessageEventStream,
  type Context,
  type Model,
  type SimpleStreamOptions,
  calculateCost,
  createAssistantMessageEventStream,
} from "@mariozechner/pi-ai";

function streamMyProvider(
  model: Model<any>,
  context: Context,
  options?: SimpleStreamOptions
): AssistantMessageEventStream {
  const stream = createAssistantMessageEventStream();

  (async () => {
    // Initialize output message
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: model.api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };

    try {
      // Push start event
      stream.push({ type: "start", partial: output });

      // Make API request and process response...
      // Push content events as they arrive...

      // Push done event
      stream.push({
        type: "done",
        reason: output.stopReason as "stop" | "length" | "toolUse",
        message: output
      });
      stream.end();
    } catch (error) {
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = error instanceof Error ? error.message : String(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();

  return stream;
}
```
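
The `// Make API request and process response...` placeholder depends entirely on your API. As one concrete sketch, here is how that section could look for a Chat-Completions-style SSE endpoint, adapted from the yield-based example this pattern replaces. The `/chat` path, request payload, and chunk shape are assumptions about your API:

```typescript
// Sketch of the elided request/parse step, assuming an OpenAI-style SSE endpoint.
// Convert context to your API's wire format first.
const messages = context.messages.map((m) => ({
  role: m.role,
  content: typeof m.content === "string"
    ? m.content
    : m.content.filter((c) => c.type === "text").map((c) => c.text).join("")
}));

const response = await fetch(`${model.baseUrl}/chat`, {
  method: "POST",
  headers: {
    "Authorization": `Bearer ${options?.apiKey}`,
    "Content-Type": "application/json"
  },
  body: JSON.stringify({ model: model.id, messages, stream: true }),
  signal: options?.signal
});
if (!response.ok) throw new Error(`API error: ${response.status}`);

const reader = response.body!.getReader();
const decoder = new TextDecoder();
let buffer = "";
let textStarted = false;
const contentIndex = 0;

while (true) {
  const { done, value } = await reader.read();
  if (done) break;
  buffer += decoder.decode(value, { stream: true });
  const lines = buffer.split("\n");
  buffer = lines.pop() ?? "";
  for (const line of lines) {
    if (!line.startsWith("data: ")) continue;
    const data = line.slice(6);
    if (data === "[DONE]") continue;
    const delta = JSON.parse(data).choices?.[0]?.delta?.content;
    if (!delta) continue;
    if (!textStarted) {
      output.content.push({ type: "text", text: "" });
      stream.push({ type: "text_start", contentIndex, partial: output });
      textStarted = true;
    }
    (output.content[contentIndex] as { type: "text"; text: string }).text += delta;
    stream.push({ type: "text_delta", contentIndex, delta, partial: output });
  }
}
if (textStarted) {
  const block = output.content[contentIndex] as { type: "text"; text: string };
  stream.push({ type: "text_end", contentIndex, content: block.text, partial: output });
}
```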
### Event Types
Push events via `stream.push()` in this order:

1. `{ type: "start", partial: output }` - Stream started
2. Content events (repeatable, track `contentIndex` for each block):
   - `{ type: "text_start", contentIndex, partial }` - Text block started
   - `{ type: "text_delta", contentIndex, delta, partial }` - Text chunk
   - `{ type: "text_end", contentIndex, content, partial }` - Text block ended
   - `{ type: "thinking_start", contentIndex, partial }` - Thinking started
   - `{ type: "thinking_delta", contentIndex, delta, partial }` - Thinking chunk
   - `{ type: "thinking_end", contentIndex, content, partial }` - Thinking ended
   - `{ type: "toolcall_start", contentIndex, partial }` - Tool call started
   - `{ type: "toolcall_delta", contentIndex, delta, partial }` - Tool call JSON chunk
   - `{ type: "toolcall_end", contentIndex, toolCall, partial }` - Tool call ended
3. `{ type: "done", reason, message }` or `{ type: "error", reason, error }` - Stream ended

The `partial` field in each event contains the current `AssistantMessage` state. Update `output.content` as you receive data, then include `output` as the `partial`.
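
Putting that order together for a response with a single text block (a minimal sketch; `output` is the message object maintained as in the skeleton above):

```typescript
// Minimal lifecycle for one text block.
stream.push({ type: "start", partial: output });

const block = { type: "text" as const, text: "" };
output.content.push(block);
stream.push({ type: "text_start", contentIndex: 0, partial: output });

block.text += "Hello!";
stream.push({ type: "text_delta", contentIndex: 0, delta: "Hello!", partial: output });
stream.push({ type: "text_end", contentIndex: 0, content: block.text, partial: output });

stream.push({ type: "done", reason: "stop", message: output });
stream.end();
```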
### Content Blocks
Add content blocks to `output.content` as they arrive:
```typescript
// Text block
output.content.push({ type: "text", text: "" });
const contentIndex = output.content.length - 1;
stream.push({ type: "text_start", contentIndex, partial: output });

// As text arrives
const block = output.content[contentIndex];
if (block.type === "text") {
  block.text += delta;
  stream.push({ type: "text_delta", contentIndex, delta, partial: output });
}

// When block completes
stream.push({ type: "text_end", contentIndex, content: block.text, partial: output });
```
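
Thinking blocks (for reasoning models) follow the same start/delta/end shape. A sketch, assuming the content block looks like `{ type: "thinking", thinking: string }` — verify the exact field names against pi-ai's content types:

```typescript
// Sketch: thinking mirrors the text pattern.
// Assumption: block shape is { type: "thinking", thinking: string }.
const thinking = { type: "thinking" as const, thinking: "" };
output.content.push(thinking);
const contentIndex = output.content.length - 1;
stream.push({ type: "thinking_start", contentIndex, partial: output });

thinking.thinking += delta;
stream.push({ type: "thinking_delta", contentIndex, delta, partial: output });

stream.push({ type: "thinking_end", contentIndex, content: thinking.thinking, partial: output });
```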
### Tool Calls
Tool calls require accumulating JSON and parsing:
```typescript
// Start tool call
output.content.push({
  type: "toolCall",
  id: toolCallId,
  name: toolName,
  arguments: {}
});
const contentIndex = output.content.length - 1;
const block = output.content[contentIndex];
stream.push({ type: "toolcall_start", contentIndex, partial: output });

// Accumulate JSON
let partialJson = "";
partialJson += jsonDelta;
try {
  block.arguments = JSON.parse(partialJson);
} catch {}
stream.push({ type: "toolcall_delta", contentIndex, delta: jsonDelta, partial: output });

// Complete
stream.push({
  type: "toolcall_end",
  contentIndex,
  toolCall: { type: "toolCall", id, name, arguments: block.arguments },
  partial: output
});
```
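
The empty `catch` is deliberate: the accumulated JSON is usually invalid until the final delta arrives, so mid-stream parse failures are expected and `arguments` simply keeps the last value that parsed successfully.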
### Usage and Cost
Update usage from API response and calculate cost:
```typescript
output.usage.input = response.usage.input_tokens;
output.usage.output = response.usage.output_tokens;
output.usage.cacheRead = response.usage.cache_read_tokens ?? 0;
output.usage.cacheWrite = response.usage.cache_write_tokens ?? 0;
output.usage.totalTokens = output.usage.input + output.usage.output +
output.usage.cacheRead + output.usage.cacheWrite;
calculateCost(model, output.usage);
```
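
`calculateCost` (imported above from `@mariozechner/pi-ai`) fills in `usage.cost` from the model's per-token `cost` rates, so update the token counts before the final `done` event is pushed.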
### Registration
Register your stream function:
```typescript
pi.registerProvider("my-provider", {
baseUrl: "https://api.example.com",
apiKey: "MY_API_KEY",
api: "my-custom-api",
models: [...],
streamSimple: streamMyProvider
});
```
## Testing Your Implementation
Test your provider against the same test suites used by built-in providers. Copy and adapt these test files from [packages/ai/test/](https://github.com/badlogic/pi-mono/tree/main/packages/ai/test):
| Test | Purpose |
|------|---------|
| `stream.test.ts` | Basic streaming, text output |
| `tokens.test.ts` | Token counting and usage |
| `abort.test.ts` | AbortSignal handling |
| `empty.test.ts` | Empty/minimal responses |
| `context-overflow.test.ts` | Context window limits |
| `image-limits.test.ts` | Image input handling |
| `unicode-surrogate.test.ts` | Unicode edge cases |
| `tool-call-without-result.test.ts` | Tool call edge cases |
| `image-tool-result.test.ts` | Images in tool results |
| `total-tokens.test.ts` | Total token calculation |
| `cross-provider-handoff.test.ts` | Context handoff between providers |
Run tests with your provider/model pairs to verify compatibility.
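
An adapted test can be quite small. The following is a sketch only, not the actual test file: the vitest import and async iteration over the event stream are assumptions — mirror whatever harness the copied files in packages/ai/test/ actually use:

```typescript
// Sketch of a basic streaming test for a custom provider.
// Assumptions: vitest as the runner and an async-iterable event stream.
import { expect, test } from "vitest";

test("streams a text response", async () => {
  const stream = streamMyProvider(myModel, {
    messages: [{ role: "user", content: "Say hello." }]
  });

  let text = "";
  for await (const event of stream) {
    if (event.type === "text_delta") text += event.delta;
    if (event.type === "error") throw new Error(event.error.errorMessage);
  }
  expect(text.length).toBeGreaterThan(0);
});
```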
## Config Reference
```typescript