mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-15 11:02:17 +00:00
docs: rewrite custom-provider.md with actual stream pattern
- Document stream.push/end pattern instead of yield - Link to existing provider implementations on GitHub - Add testing section with list of test files to run - Include proper content block and tool call examples
This commit is contained in:
parent
177c694406
commit
d4bd1a956b
1 changed files with 173 additions and 182 deletions
|
|
@ -14,6 +14,7 @@ Extensions can register custom model providers via `pi.registerProvider()`. This
|
|||
- [Register New Provider](#register-new-provider)
|
||||
- [OAuth Support](#oauth-support)
|
||||
- [Custom Streaming API](#custom-streaming-api)
|
||||
- [Testing Your Implementation](#testing-your-implementation)
|
||||
- [Config Reference](#config-reference)
|
||||
- [Model Definition Reference](#model-definition-reference)
|
||||
|
||||
|
|
@ -99,15 +100,6 @@ pi.registerProvider("my-llm", {
|
|||
},
|
||||
contextWindow: 200000,
|
||||
maxTokens: 16384
|
||||
},
|
||||
{
|
||||
id: "my-llm-small",
|
||||
name: "My LLM Small",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 0.25, output: 1.25, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 128000,
|
||||
maxTokens: 8192
|
||||
}
|
||||
]
|
||||
});
|
||||
|
|
@ -171,17 +163,7 @@ import type { OAuthCredentials, OAuthLoginCallbacks } from "@mariozechner/pi-ai"
|
|||
pi.registerProvider("corporate-ai", {
|
||||
baseUrl: "https://ai.corp.com/v1",
|
||||
api: "openai-responses",
|
||||
models: [
|
||||
{
|
||||
id: "corp-claude",
|
||||
name: "Corporate Claude",
|
||||
reasoning: true,
|
||||
input: ["text", "image"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 200000,
|
||||
maxTokens: 16384
|
||||
}
|
||||
],
|
||||
models: [...],
|
||||
oauth: {
|
||||
name: "Corporate AI (SSO)",
|
||||
|
||||
|
|
@ -223,7 +205,6 @@ pi.registerProvider("corporate-ai", {
|
|||
|
||||
// Optional: modify models based on user's subscription
|
||||
modifyModels(models, credentials) {
|
||||
// e.g., update baseUrl based on user's region
|
||||
const region = decodeRegionFromToken(credentials.access);
|
||||
return models.map(m => ({
|
||||
...m,
|
||||
|
|
@ -267,193 +248,203 @@ interface OAuthCredentials {
|
|||
|
||||
## Custom Streaming API
|
||||
|
||||
For providers with non-standard APIs, implement `streamSimple`:
|
||||
For providers with non-standard APIs, implement `streamSimple`. Study the existing provider implementations before writing your own:
|
||||
|
||||
**Reference implementations:**
|
||||
- [anthropic.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/anthropic.ts) - Anthropic Messages API
|
||||
- [openai-completions.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/openai-completions.ts) - OpenAI Chat Completions
|
||||
- [openai-responses.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/openai-responses.ts) - OpenAI Responses API
|
||||
- [google.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/google.ts) - Google Generative AI
|
||||
- [amazon-bedrock.ts](https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/providers/amazon-bedrock.ts) - AWS Bedrock
|
||||
|
||||
### Stream Pattern
|
||||
|
||||
All providers follow the same pattern:
|
||||
|
||||
```typescript
|
||||
import type {
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
Api
|
||||
import {
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEventStream,
|
||||
type Context,
|
||||
type Model,
|
||||
type SimpleStreamOptions,
|
||||
calculateCost,
|
||||
createAssistantMessageEventStream,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { createAssistantMessageEventStream } from "@mariozechner/pi-ai";
|
||||
|
||||
pi.registerProvider("custom-llm", {
|
||||
baseUrl: "https://api.custom-llm.com",
|
||||
apiKey: "CUSTOM_LLM_KEY",
|
||||
api: "custom-llm-api", // your custom API identifier
|
||||
models: [
|
||||
{
|
||||
id: "custom-model",
|
||||
name: "Custom Model",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 1.0, output: 2.0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 32000,
|
||||
maxTokens: 4096
|
||||
}
|
||||
],
|
||||
function streamMyProvider(
|
||||
model: Model<any>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions
|
||||
): AssistantMessageEventStream {
|
||||
const stream = createAssistantMessageEventStream();
|
||||
|
||||
streamSimple(
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions
|
||||
): AssistantMessageEventStream {
|
||||
return createAssistantMessageEventStream(async function* (signal) {
|
||||
// Convert context to your API format
|
||||
const messages = context.messages.map(m => ({
|
||||
role: m.role,
|
||||
content: typeof m.content === "string"
|
||||
? m.content
|
||||
: m.content.filter(c => c.type === "text").map(c => c.text).join("")
|
||||
}));
|
||||
(async () => {
|
||||
// Initialize output message
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
// Make streaming request
|
||||
const response = await fetch(`${model.baseUrl}/chat`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Authorization": `Bearer ${options?.apiKey}`,
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: model.id,
|
||||
messages,
|
||||
stream: true
|
||||
}),
|
||||
signal
|
||||
try {
|
||||
// Push start event
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
// Make API request and process response...
|
||||
// Push content events as they arrive...
|
||||
|
||||
// Push done event
|
||||
stream.push({
|
||||
type: "done",
|
||||
reason: output.stopReason as "stop" | "length" | "toolUse",
|
||||
message: output
|
||||
});
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : String(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`API error: ${response.status}`);
|
||||
}
|
||||
|
||||
// Yield start event
|
||||
yield { type: "start" };
|
||||
|
||||
// Parse SSE stream
|
||||
const reader = response.body!.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
let contentIndex = 0;
|
||||
let textStarted = false;
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() ?? "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith("data: ")) continue;
|
||||
const data = line.slice(6);
|
||||
if (data === "[DONE]") continue;
|
||||
|
||||
const chunk = JSON.parse(data);
|
||||
const delta = chunk.choices?.[0]?.delta?.content;
|
||||
|
||||
if (delta) {
|
||||
if (!textStarted) {
|
||||
yield { type: "text_start", contentIndex };
|
||||
textStarted = true;
|
||||
}
|
||||
yield { type: "text_delta", contentIndex, delta };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (textStarted) {
|
||||
yield { type: "text_end", contentIndex };
|
||||
}
|
||||
|
||||
// Yield usage if available
|
||||
yield {
|
||||
type: "usage",
|
||||
usage: {
|
||||
input: 0, // fill from response if available
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }
|
||||
}
|
||||
};
|
||||
|
||||
// Yield done
|
||||
yield { type: "done", reason: "stop" };
|
||||
});
|
||||
}
|
||||
});
|
||||
return stream;
|
||||
}
|
||||
```
|
||||
|
||||
### Event Types
|
||||
|
||||
Your generator must yield events in this order:
|
||||
Push events via `stream.push()` in this order:
|
||||
|
||||
1. `{ type: "start" }` - Stream started
|
||||
2. Content events (repeatable, in order):
|
||||
- `{ type: "text_start", contentIndex }` - Text block started
|
||||
- `{ type: "text_delta", contentIndex, delta }` - Text chunk
|
||||
- `{ type: "text_end", contentIndex }` - Text block ended
|
||||
- `{ type: "thinking_start", contentIndex }` - Thinking block started
|
||||
- `{ type: "thinking_delta", contentIndex, delta }` - Thinking chunk
|
||||
- `{ type: "thinking_end", contentIndex }` - Thinking block ended
|
||||
- `{ type: "toolcall_start", contentIndex }` - Tool call started
|
||||
- `{ type: "toolcall_delta", contentIndex, delta }` - Tool call JSON chunk
|
||||
- `{ type: "toolcall_end", contentIndex, toolCall }` - Tool call ended
|
||||
3. `{ type: "usage", usage }` - Token usage (optional but recommended)
|
||||
4. `{ type: "done", reason }` or `{ type: "error", error }` - Stream ended
|
||||
1. `{ type: "start", partial: output }` - Stream started
|
||||
|
||||
### Reasoning Support
|
||||
2. Content events (repeatable, track `contentIndex` for each block):
|
||||
- `{ type: "text_start", contentIndex, partial }` - Text block started
|
||||
- `{ type: "text_delta", contentIndex, delta, partial }` - Text chunk
|
||||
- `{ type: "text_end", contentIndex, content, partial }` - Text block ended
|
||||
- `{ type: "thinking_start", contentIndex, partial }` - Thinking started
|
||||
- `{ type: "thinking_delta", contentIndex, delta, partial }` - Thinking chunk
|
||||
- `{ type: "thinking_end", contentIndex, content, partial }` - Thinking ended
|
||||
- `{ type: "toolcall_start", contentIndex, partial }` - Tool call started
|
||||
- `{ type: "toolcall_delta", contentIndex, delta, partial }` - Tool call JSON chunk
|
||||
- `{ type: "toolcall_end", contentIndex, toolCall, partial }` - Tool call ended
|
||||
|
||||
For models with extended thinking, yield thinking events:
|
||||
3. `{ type: "done", reason, message }` or `{ type: "error", reason, error }` - Stream ended
|
||||
|
||||
The `partial` field in each event contains the current `AssistantMessage` state. Update `output.content` as you receive data, then include `output` as the `partial`.
|
||||
|
||||
### Content Blocks
|
||||
|
||||
Add content blocks to `output.content` as they arrive:
|
||||
|
||||
```typescript
|
||||
if (chunk.thinking) {
|
||||
if (!thinkingStarted) {
|
||||
yield { type: "thinking_start", contentIndex: thinkingIndex };
|
||||
thinkingStarted = true;
|
||||
}
|
||||
yield { type: "thinking_delta", contentIndex: thinkingIndex, delta: chunk.thinking };
|
||||
// Text block
|
||||
output.content.push({ type: "text", text: "" });
|
||||
stream.push({ type: "text_start", contentIndex: output.content.length - 1, partial: output });
|
||||
|
||||
// As text arrives
|
||||
const block = output.content[contentIndex];
|
||||
if (block.type === "text") {
|
||||
block.text += delta;
|
||||
stream.push({ type: "text_delta", contentIndex, delta, partial: output });
|
||||
}
|
||||
|
||||
// When block completes
|
||||
stream.push({ type: "text_end", contentIndex, content: block.text, partial: output });
|
||||
```
|
||||
|
||||
### Tool Calls
|
||||
|
||||
For function calling support, yield tool call events:
|
||||
Tool calls require accumulating JSON and parsing:
|
||||
|
||||
```typescript
|
||||
if (chunk.tool_calls) {
|
||||
for (const tc of chunk.tool_calls) {
|
||||
if (tc.index !== currentToolIndex) {
|
||||
if (currentToolIndex >= 0) {
|
||||
yield {
|
||||
type: "toolcall_end",
|
||||
contentIndex: currentToolIndex,
|
||||
toolCall: {
|
||||
type: "toolCall",
|
||||
id: currentToolId,
|
||||
name: currentToolName,
|
||||
arguments: JSON.parse(currentToolArgs)
|
||||
}
|
||||
};
|
||||
}
|
||||
currentToolIndex = tc.index;
|
||||
currentToolId = tc.id;
|
||||
currentToolName = tc.function.name;
|
||||
currentToolArgs = "";
|
||||
yield { type: "toolcall_start", contentIndex: tc.index };
|
||||
}
|
||||
if (tc.function.arguments) {
|
||||
currentToolArgs += tc.function.arguments;
|
||||
yield { type: "toolcall_delta", contentIndex: tc.index, delta: tc.function.arguments };
|
||||
}
|
||||
}
|
||||
}
|
||||
// Start tool call
|
||||
output.content.push({
|
||||
type: "toolCall",
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
arguments: {}
|
||||
});
|
||||
stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output });
|
||||
|
||||
// Accumulate JSON
|
||||
let partialJson = "";
|
||||
partialJson += jsonDelta;
|
||||
try {
|
||||
block.arguments = JSON.parse(partialJson);
|
||||
} catch {}
|
||||
stream.push({ type: "toolcall_delta", contentIndex, delta: jsonDelta, partial: output });
|
||||
|
||||
// Complete
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex,
|
||||
toolCall: { type: "toolCall", id, name, arguments: block.arguments },
|
||||
partial: output
|
||||
});
|
||||
```
|
||||
|
||||
### Usage and Cost
|
||||
|
||||
Update usage from API response and calculate cost:
|
||||
|
||||
```typescript
|
||||
output.usage.input = response.usage.input_tokens;
|
||||
output.usage.output = response.usage.output_tokens;
|
||||
output.usage.cacheRead = response.usage.cache_read_tokens ?? 0;
|
||||
output.usage.cacheWrite = response.usage.cache_write_tokens ?? 0;
|
||||
output.usage.totalTokens = output.usage.input + output.usage.output +
|
||||
output.usage.cacheRead + output.usage.cacheWrite;
|
||||
calculateCost(model, output.usage);
|
||||
```
|
||||
|
||||
### Registration
|
||||
|
||||
Register your stream function:
|
||||
|
||||
```typescript
|
||||
pi.registerProvider("my-provider", {
|
||||
baseUrl: "https://api.example.com",
|
||||
apiKey: "MY_API_KEY",
|
||||
api: "my-custom-api",
|
||||
models: [...],
|
||||
streamSimple: streamMyProvider
|
||||
});
|
||||
```
|
||||
|
||||
## Testing Your Implementation
|
||||
|
||||
Test your provider against the same test suites used by built-in providers. Copy and adapt these test files from [packages/ai/test/](https://github.com/badlogic/pi-mono/tree/main/packages/ai/test):
|
||||
|
||||
| Test | Purpose |
|
||||
|------|---------|
|
||||
| `stream.test.ts` | Basic streaming, text output |
|
||||
| `tokens.test.ts` | Token counting and usage |
|
||||
| `abort.test.ts` | AbortSignal handling |
|
||||
| `empty.test.ts` | Empty/minimal responses |
|
||||
| `context-overflow.test.ts` | Context window limits |
|
||||
| `image-limits.test.ts` | Image input handling |
|
||||
| `unicode-surrogate.test.ts` | Unicode edge cases |
|
||||
| `tool-call-without-result.test.ts` | Tool call edge cases |
|
||||
| `image-tool-result.test.ts` | Images in tool results |
|
||||
| `total-tokens.test.ts` | Total token calculation |
|
||||
| `cross-provider-handoff.test.ts` | Context handoff between providers |
|
||||
|
||||
Run tests with your provider/model pairs to verify compatibility.
|
||||
|
||||
## Config Reference
|
||||
|
||||
```typescript
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue