mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-21 12:00:15 +00:00
docs: rewrite custom-provider.md with actual stream pattern
- Document stream.push/end pattern instead of yield
- Link to existing provider implementations on GitHub
- Add testing section with list of test files to run
- Include proper content block and tool call examples
This commit is contained in:
parent
177c694406
commit
d4bd1a956b
1 changed file with 173 additions and 182 deletions
|
|
@ -14,6 +14,7 @@ Extensions can register custom model providers via `pi.registerProvider()`. This
|
||||||
- [Register New Provider](#register-new-provider)
|
- [Register New Provider](#register-new-provider)
|
||||||
- [OAuth Support](#oauth-support)
|
- [OAuth Support](#oauth-support)
|
||||||
- [Custom Streaming API](#custom-streaming-api)
|
- [Custom Streaming API](#custom-streaming-api)
|
||||||
|
- [Testing Your Implementation](#testing-your-implementation)
|
||||||
- [Config Reference](#config-reference)
|
- [Config Reference](#config-reference)
|
||||||
- [Model Definition Reference](#model-definition-reference)
|
- [Model Definition Reference](#model-definition-reference)
|
||||||
|
|
||||||
|
|
@ -99,15 +100,6 @@ pi.registerProvider("my-llm", {
|
||||||
},
|
},
|
||||||
contextWindow: 200000,
|
contextWindow: 200000,
|
||||||
maxTokens: 16384
|
maxTokens: 16384
|
||||||
},
|
|
||||||
{
|
|
||||||
id: "my-llm-small",
|
|
||||||
name: "My LLM Small",
|
|
||||||
reasoning: false,
|
|
||||||
input: ["text"],
|
|
||||||
cost: { input: 0.25, output: 1.25, cacheRead: 0, cacheWrite: 0 },
|
|
||||||
contextWindow: 128000,
|
|
||||||
maxTokens: 8192
|
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
});
|
});
|
||||||
|
|
@ -171,17 +163,7 @@ import type { OAuthCredentials, OAuthLoginCallbacks } from "@mariozechner/pi-ai"
|
||||||
pi.registerProvider("corporate-ai", {
|
pi.registerProvider("corporate-ai", {
|
||||||
baseUrl: "https://ai.corp.com/v1",
|
baseUrl: "https://ai.corp.com/v1",
|
||||||
api: "openai-responses",
|
api: "openai-responses",
|
||||||
models: [
|
models: [...],
|
||||||
{
|
|
||||||
id: "corp-claude",
|
|
||||||
name: "Corporate Claude",
|
|
||||||
reasoning: true,
|
|
||||||
input: ["text", "image"],
|
|
||||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
|
||||||
contextWindow: 200000,
|
|
||||||
maxTokens: 16384
|
|
||||||
}
|
|
||||||
],
|
|
||||||
oauth: {
|
oauth: {
|
||||||
name: "Corporate AI (SSO)",
|
name: "Corporate AI (SSO)",
|
||||||
|
|
||||||
|
|
@ -223,7 +205,6 @@ pi.registerProvider("corporate-ai", {
|
||||||
|
|
||||||
// Optional: modify models based on user's subscription
|
// Optional: modify models based on user's subscription
|
||||||
modifyModels(models, credentials) {
|
modifyModels(models, credentials) {
|
||||||
// e.g., update baseUrl based on user's region
|
|
||||||
const region = decodeRegionFromToken(credentials.access);
|
const region = decodeRegionFromToken(credentials.access);
|
||||||
return models.map(m => ({
|
return models.map(m => ({
|
||||||
...m,
|
...m,
|
||||||
|
|
@ -267,193 +248,203 @@ interface OAuthCredentials {
|
||||||
|
|
||||||
## Custom Streaming API
|
## Custom Streaming API
|
||||||
|
|
||||||
For providers with non-standard APIs, implement `streamSimple`:
|
For providers with non-standard APIs, implement `streamSimple`. Study the existing provider implementations before writing your own:
|
||||||
|
|
||||||
|
**Reference implementations:**
|
||||||
|
- [anthropic.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/anthropic.ts) - Anthropic Messages API
|
||||||
|
- [openai-completions.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/openai-completions.ts) - OpenAI Chat Completions
|
||||||
|
- [openai-responses.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/openai-responses.ts) - OpenAI Responses API
|
||||||
|
- [google.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/google.ts) - Google Generative AI
|
||||||
|
- [amazon-bedrock.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/amazon-bedrock.ts) - AWS Bedrock
|
||||||
|
|
||||||
|
### Stream Pattern
|
||||||
|
|
||||||
|
All providers follow the same pattern:
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
import type {
|
import {
|
||||||
AssistantMessageEventStream,
|
type AssistantMessage,
|
||||||
Context,
|
type AssistantMessageEventStream,
|
||||||
Model,
|
type Context,
|
||||||
SimpleStreamOptions,
|
type Model,
|
||||||
Api
|
type SimpleStreamOptions,
|
||||||
|
calculateCost,
|
||||||
|
createAssistantMessageEventStream,
|
||||||
} from "@mariozechner/pi-ai";
|
} from "@mariozechner/pi-ai";
|
||||||
import { createAssistantMessageEventStream } from "@mariozechner/pi-ai";
|
|
||||||
|
|
||||||
pi.registerProvider("custom-llm", {
|
function streamMyProvider(
|
||||||
baseUrl: "https://api.custom-llm.com",
|
model: Model<any>,
|
||||||
apiKey: "CUSTOM_LLM_KEY",
|
|
||||||
api: "custom-llm-api", // your custom API identifier
|
|
||||||
models: [
|
|
||||||
{
|
|
||||||
id: "custom-model",
|
|
||||||
name: "Custom Model",
|
|
||||||
reasoning: false,
|
|
||||||
input: ["text"],
|
|
||||||
cost: { input: 1.0, output: 2.0, cacheRead: 0, cacheWrite: 0 },
|
|
||||||
contextWindow: 32000,
|
|
||||||
maxTokens: 4096
|
|
||||||
}
|
|
||||||
],
|
|
||||||
|
|
||||||
streamSimple(
|
|
||||||
model: Model<Api>,
|
|
||||||
context: Context,
|
context: Context,
|
||||||
options?: SimpleStreamOptions
|
options?: SimpleStreamOptions
|
||||||
): AssistantMessageEventStream {
|
): AssistantMessageEventStream {
|
||||||
return createAssistantMessageEventStream(async function* (signal) {
|
const stream = createAssistantMessageEventStream();
|
||||||
// Convert context to your API format
|
|
||||||
const messages = context.messages.map(m => ({
|
|
||||||
role: m.role,
|
|
||||||
content: typeof m.content === "string"
|
|
||||||
? m.content
|
|
||||||
: m.content.filter(c => c.type === "text").map(c => c.text).join("")
|
|
||||||
}));
|
|
||||||
|
|
||||||
// Make streaming request
|
(async () => {
|
||||||
const response = await fetch(`${model.baseUrl}/chat`, {
|
// Initialize output message
|
||||||
method: "POST",
|
const output: AssistantMessage = {
|
||||||
headers: {
|
role: "assistant",
|
||||||
"Authorization": `Bearer ${options?.apiKey}`,
|
content: [],
|
||||||
"Content-Type": "application/json"
|
api: model.api,
|
||||||
},
|
provider: model.provider,
|
||||||
body: JSON.stringify({
|
|
||||||
model: model.id,
|
model: model.id,
|
||||||
messages,
|
|
||||||
stream: true
|
|
||||||
}),
|
|
||||||
signal
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
throw new Error(`API error: ${response.status}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Yield start event
|
|
||||||
yield { type: "start" };
|
|
||||||
|
|
||||||
// Parse SSE stream
|
|
||||||
const reader = response.body!.getReader();
|
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let buffer = "";
|
|
||||||
let contentIndex = 0;
|
|
||||||
let textStarted = false;
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
const { done, value } = await reader.read();
|
|
||||||
if (done) break;
|
|
||||||
|
|
||||||
buffer += decoder.decode(value, { stream: true });
|
|
||||||
const lines = buffer.split("\n");
|
|
||||||
buffer = lines.pop() ?? "";
|
|
||||||
|
|
||||||
for (const line of lines) {
|
|
||||||
if (!line.startsWith("data: ")) continue;
|
|
||||||
const data = line.slice(6);
|
|
||||||
if (data === "[DONE]") continue;
|
|
||||||
|
|
||||||
const chunk = JSON.parse(data);
|
|
||||||
const delta = chunk.choices?.[0]?.delta?.content;
|
|
||||||
|
|
||||||
if (delta) {
|
|
||||||
if (!textStarted) {
|
|
||||||
yield { type: "text_start", contentIndex };
|
|
||||||
textStarted = true;
|
|
||||||
}
|
|
||||||
yield { type: "text_delta", contentIndex, delta };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (textStarted) {
|
|
||||||
yield { type: "text_end", contentIndex };
|
|
||||||
}
|
|
||||||
|
|
||||||
// Yield usage if available
|
|
||||||
yield {
|
|
||||||
type: "usage",
|
|
||||||
usage: {
|
usage: {
|
||||||
input: 0, // fill from response if available
|
input: 0,
|
||||||
output: 0,
|
output: 0,
|
||||||
cacheRead: 0,
|
cacheRead: 0,
|
||||||
cacheWrite: 0,
|
cacheWrite: 0,
|
||||||
totalTokens: 0,
|
totalTokens: 0,
|
||||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||||
}
|
},
|
||||||
|
stopReason: "stop",
|
||||||
|
timestamp: Date.now(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Yield done
|
try {
|
||||||
yield { type: "done", reason: "stop" };
|
// Push start event
|
||||||
|
stream.push({ type: "start", partial: output });
|
||||||
|
|
||||||
|
// Make API request and process response...
|
||||||
|
// Push content events as they arrive...
|
||||||
|
|
||||||
|
// Push done event
|
||||||
|
stream.push({
|
||||||
|
type: "done",
|
||||||
|
reason: output.stopReason as "stop" | "length" | "toolUse",
|
||||||
|
message: output
|
||||||
});
|
});
|
||||||
|
stream.end();
|
||||||
|
} catch (error) {
|
||||||
|
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||||
|
output.errorMessage = error instanceof Error ? error.message : String(error);
|
||||||
|
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||||
|
stream.end();
|
||||||
}
|
}
|
||||||
});
|
})();
|
||||||
|
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Event Types
|
### Event Types
|
||||||
|
|
||||||
Your generator must yield events in this order:
|
Push events via `stream.push()` in this order:
|
||||||
|
|
||||||
1. `{ type: "start" }` - Stream started
|
1. `{ type: "start", partial: output }` - Stream started
|
||||||
2. Content events (repeatable, in order):
|
|
||||||
- `{ type: "text_start", contentIndex }` - Text block started
|
|
||||||
- `{ type: "text_delta", contentIndex, delta }` - Text chunk
|
|
||||||
- `{ type: "text_end", contentIndex }` - Text block ended
|
|
||||||
- `{ type: "thinking_start", contentIndex }` - Thinking block started
|
|
||||||
- `{ type: "thinking_delta", contentIndex, delta }` - Thinking chunk
|
|
||||||
- `{ type: "thinking_end", contentIndex }` - Thinking block ended
|
|
||||||
- `{ type: "toolcall_start", contentIndex }` - Tool call started
|
|
||||||
- `{ type: "toolcall_delta", contentIndex, delta }` - Tool call JSON chunk
|
|
||||||
- `{ type: "toolcall_end", contentIndex, toolCall }` - Tool call ended
|
|
||||||
3. `{ type: "usage", usage }` - Token usage (optional but recommended)
|
|
||||||
4. `{ type: "done", reason }` or `{ type: "error", error }` - Stream ended
|
|
||||||
|
|
||||||
### Reasoning Support
|
2. Content events (repeatable, track `contentIndex` for each block):
|
||||||
|
- `{ type: "text_start", contentIndex, partial }` - Text block started
|
||||||
|
- `{ type: "text_delta", contentIndex, delta, partial }` - Text chunk
|
||||||
|
- `{ type: "text_end", contentIndex, content, partial }` - Text block ended
|
||||||
|
- `{ type: "thinking_start", contentIndex, partial }` - Thinking started
|
||||||
|
- `{ type: "thinking_delta", contentIndex, delta, partial }` - Thinking chunk
|
||||||
|
- `{ type: "thinking_end", contentIndex, content, partial }` - Thinking ended
|
||||||
|
- `{ type: "toolcall_start", contentIndex, partial }` - Tool call started
|
||||||
|
- `{ type: "toolcall_delta", contentIndex, delta, partial }` - Tool call JSON chunk
|
||||||
|
- `{ type: "toolcall_end", contentIndex, toolCall, partial }` - Tool call ended
|
||||||
|
|
||||||
For models with extended thinking, yield thinking events:
|
3. `{ type: "done", reason, message }` or `{ type: "error", reason, error }` - Stream ended
|
||||||
|
|
||||||
|
The `partial` field in each event contains the current `AssistantMessage` state. Update `output.content` as you receive data, then include `output` as the `partial`.
|
||||||
|
|
||||||
|
### Content Blocks
|
||||||
|
|
||||||
|
Add content blocks to `output.content` as they arrive:
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
if (chunk.thinking) {
|
// Text block
|
||||||
if (!thinkingStarted) {
|
output.content.push({ type: "text", text: "" });
|
||||||
yield { type: "thinking_start", contentIndex: thinkingIndex };
|
stream.push({ type: "text_start", contentIndex: output.content.length - 1, partial: output });
|
||||||
thinkingStarted = true;
|
|
||||||
}
|
// As text arrives
|
||||||
yield { type: "thinking_delta", contentIndex: thinkingIndex, delta: chunk.thinking };
|
const block = output.content[contentIndex];
|
||||||
|
if (block.type === "text") {
|
||||||
|
block.text += delta;
|
||||||
|
stream.push({ type: "text_delta", contentIndex, delta, partial: output });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// When block completes
|
||||||
|
stream.push({ type: "text_end", contentIndex, content: block.text, partial: output });
|
||||||
```
|
```
|
||||||
|
|
||||||
### Tool Calls
|
### Tool Calls
|
||||||
|
|
||||||
For function calling support, yield tool call events:
|
Tool calls require accumulating the streamed JSON argument fragments and parsing them:
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
if (chunk.tool_calls) {
|
// Start tool call
|
||||||
for (const tc of chunk.tool_calls) {
|
output.content.push({
|
||||||
if (tc.index !== currentToolIndex) {
|
|
||||||
if (currentToolIndex >= 0) {
|
|
||||||
yield {
|
|
||||||
type: "toolcall_end",
|
|
||||||
contentIndex: currentToolIndex,
|
|
||||||
toolCall: {
|
|
||||||
type: "toolCall",
|
type: "toolCall",
|
||||||
id: currentToolId,
|
id: toolCallId,
|
||||||
name: currentToolName,
|
name: toolName,
|
||||||
arguments: JSON.parse(currentToolArgs)
|
arguments: {}
|
||||||
}
|
});
|
||||||
};
|
stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output });
|
||||||
}
|
|
||||||
currentToolIndex = tc.index;
|
// Accumulate JSON
|
||||||
currentToolId = tc.id;
|
let partialJson = "";
|
||||||
currentToolName = tc.function.name;
|
partialJson += jsonDelta;
|
||||||
currentToolArgs = "";
|
try {
|
||||||
yield { type: "toolcall_start", contentIndex: tc.index };
|
block.arguments = JSON.parse(partialJson);
|
||||||
}
|
} catch {}
|
||||||
if (tc.function.arguments) {
|
stream.push({ type: "toolcall_delta", contentIndex, delta: jsonDelta, partial: output });
|
||||||
currentToolArgs += tc.function.arguments;
|
|
||||||
yield { type: "toolcall_delta", contentIndex: tc.index, delta: tc.function.arguments };
|
// Complete
|
||||||
}
|
stream.push({
|
||||||
}
|
type: "toolcall_end",
|
||||||
}
|
contentIndex,
|
||||||
|
toolCall: { type: "toolCall", id, name, arguments: block.arguments },
|
||||||
|
partial: output
|
||||||
|
});
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Usage and Cost
|
||||||
|
|
||||||
|
Update usage from the API response and calculate the cost:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
output.usage.input = response.usage.input_tokens;
|
||||||
|
output.usage.output = response.usage.output_tokens;
|
||||||
|
output.usage.cacheRead = response.usage.cache_read_tokens ?? 0;
|
||||||
|
output.usage.cacheWrite = response.usage.cache_write_tokens ?? 0;
|
||||||
|
output.usage.totalTokens = output.usage.input + output.usage.output +
|
||||||
|
output.usage.cacheRead + output.usage.cacheWrite;
|
||||||
|
calculateCost(model, output.usage);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registration
|
||||||
|
|
||||||
|
Register your stream function:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
pi.registerProvider("my-provider", {
|
||||||
|
baseUrl: "https://api.example.com",
|
||||||
|
apiKey: "MY_API_KEY",
|
||||||
|
api: "my-custom-api",
|
||||||
|
models: [...],
|
||||||
|
streamSimple: streamMyProvider
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing Your Implementation
|
||||||
|
|
||||||
|
Test your provider against the same test suites used by built-in providers. Copy and adapt these test files from [packages/ai/test/](https://github.com/badlogic/pi-mono/tree/main/packages/ai/test):
|
||||||
|
|
||||||
|
| Test | Purpose |
|
||||||
|
|------|---------|
|
||||||
|
| `stream.test.ts` | Basic streaming, text output |
|
||||||
|
| `tokens.test.ts` | Token counting and usage |
|
||||||
|
| `abort.test.ts` | AbortSignal handling |
|
||||||
|
| `empty.test.ts` | Empty/minimal responses |
|
||||||
|
| `context-overflow.test.ts` | Context window limits |
|
||||||
|
| `image-limits.test.ts` | Image input handling |
|
||||||
|
| `unicode-surrogate.test.ts` | Unicode edge cases |
|
||||||
|
| `tool-call-without-result.test.ts` | Tool call edge cases |
|
||||||
|
| `image-tool-result.test.ts` | Images in tool results |
|
||||||
|
| `total-tokens.test.ts` | Total token calculation |
|
||||||
|
| `cross-provider-handoff.test.ts` | Context handoff between providers |
|
||||||
|
|
||||||
|
Run tests with your provider/model pairs to verify compatibility.
|
||||||
|
|
||||||
## Config Reference
|
## Config Reference
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue