Mirror of https://github.com/getcompanion-ai/co-mono.git, synced 2026-04-20 15:01:24 +00:00
refactor(ai): Update API to support partial results on abort
- Anthropic, Google, and OpenAI Responses providers now return partial results when aborted
- Restructured streaming to accumulate content blocks incrementally
- Prevents submission of thinking/toolCall blocks from aborted completions in multi-turn conversations
- Makes UI development easier by providing partial content even when requests are interrupted
parent 5d5cd7955b
commit bf1f410c2b
4 changed files with 244 additions and 280 deletions
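The heart of the change, visible in the Google provider hunks below, is that the assistant message is now created before streaming begins and mutated in place as chunks arrive, so an abort or error at any point can still return whatever blocks have accumulated. A minimal, self-contained sketch of that pattern follows; the simplified types, the function name, and the "aborted" stop reason are illustrative stand-ins, not the repository's actual definitions.

// Simplified stand-ins for the repository's AssistantMessage/TextContent types.
type TextContent = { type: "text"; text: string };
type PartialResult = {
  role: "assistant";
  content: TextContent[];
  stopReason: "stop" | "aborted" | "error";
};

async function completeIncrementally(
  chunks: AsyncIterable<string>,
  signal?: AbortSignal,
): Promise<PartialResult> {
  // The result object exists before the first chunk arrives...
  const output: PartialResult = { role: "assistant", content: [], stopReason: "stop" };
  const block: TextContent = { type: "text", text: "" };
  output.content.push(block);

  try {
    for await (const delta of chunks) {
      if (signal?.aborted) {
        // ...so an abort mid-stream simply returns whatever has accumulated.
        output.stopReason = "aborted";
        return output;
      }
      block.text += delta;
    }
    return output;
  } catch {
    // Errors likewise surface the partial content instead of discarding it.
    output.stopReason = "error";
    return output;
  }
}

Any AsyncIterable of text deltas plus an AbortController's signal can drive the sketch; the real providers apply the same idea to their SDK streams and to richer content blocks (text, thinking, tool calls).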
@@ -20,7 +20,6 @@ import type {
   ThinkingContent,
   Tool,
   ToolCall,
-  Usage,
 } from "../types.js";
 
 export interface GoogleLLMOptions extends LLMOptions {
@@ -33,7 +32,7 @@ export interface GoogleLLMOptions extends LLMOptions {
 
 export class GoogleLLM implements LLM<GoogleLLMOptions> {
   private client: GoogleGenAI;
-  private model: Model;
+  private modelInfo: Model;
 
   constructor(model: Model, apiKey?: string) {
     if (!apiKey) {
@@ -45,14 +44,28 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
       apiKey = process.env.GEMINI_API_KEY;
     }
     this.client = new GoogleGenAI({ apiKey });
-    this.model = model;
+    this.modelInfo = model;
   }
 
   getModel(): Model {
-    return this.model;
+    return this.modelInfo;
   }
 
   async complete(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
+    const output: AssistantMessage = {
+      role: "assistant",
+      content: [],
+      provider: this.modelInfo.provider,
+      model: this.modelInfo.id,
+      usage: {
+        input: 0,
+        output: 0,
+        cacheRead: 0,
+        cacheWrite: 0,
+        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
+      },
+      stopReason: "stop",
+    };
     try {
       const contents = this.convertMessages(context.messages);
 
@@ -82,7 +95,7 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
       }
 
       // Add thinking config if enabled and model supports it
-      if (options?.thinking?.enabled && this.model.reasoning) {
+      if (options?.thinking?.enabled && this.modelInfo.reasoning) {
         config.thinkingConfig = {
           includeThoughts: true,
           ...(options.thinking.budgetTokens !== undefined && { thinkingBudget: options.thinking.budgetTokens }),
@@ -99,27 +112,15 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
 
       // Build the request parameters
       const params: GenerateContentParameters = {
-        model: this.model.id,
+        model: this.modelInfo.id,
         contents,
         config,
       };
 
       const stream = await this.client.models.generateContentStream(params);
 
-      options?.onEvent?.({ type: "start", model: this.model.id, provider: this.model.provider });
-      const blocks: AssistantMessage["content"] = [];
+      options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
       let currentBlock: TextContent | ThinkingContent | null = null;
-      let usage: Usage = {
-        input: 0,
-        output: 0,
-        cacheRead: 0,
-        cacheWrite: 0,
-        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
-      };
-      let stopReason: StopReason = "stop";
 
       // Process the stream
       for await (const chunk of stream) {
         // Extract parts from the chunk
         const candidate = chunk.candidates?.[0];
 
@@ -134,14 +135,12 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
             (isThinking && currentBlock.type !== "thinking") ||
             (!isThinking && currentBlock.type !== "text")
           ) {
             // Save and finalize current block
             if (currentBlock) {
               if (currentBlock.type === "text") {
                 options?.onEvent?.({ type: "text_end", content: currentBlock.text });
               } else {
                 options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
               }
-              blocks.push(currentBlock);
             }
 
             // Start new block
@@ -152,6 +151,7 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
               currentBlock = { type: "text", text: "" };
               options?.onEvent?.({ type: "text_start" });
             }
+            output.content.push(currentBlock);
           }
 
           // Append content to current block
@@ -171,14 +171,12 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
 
           // Handle function calls
           if (part.functionCall) {
             // Save current block if exists
             if (currentBlock) {
               if (currentBlock.type === "text") {
                 options?.onEvent?.({ type: "text_end", content: currentBlock.text });
               } else {
                 options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
               }
-              blocks.push(currentBlock);
               currentBlock = null;
             }
 
@@ -190,7 +188,7 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
               name: part.functionCall.name || "",
               arguments: part.functionCall.args as Record<string, any>,
             };
-            blocks.push(toolCall);
+            output.content.push(toolCall);
             options?.onEvent?.({ type: "toolCall", toolCall });
           }
         }
@@ -198,16 +196,16 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
 
         // Map finish reason
         if (candidate?.finishReason) {
-          stopReason = this.mapStopReason(candidate.finishReason);
+          output.stopReason = this.mapStopReason(candidate.finishReason);
           // Check if we have tool calls in blocks
-          if (blocks.some((b) => b.type === "toolCall")) {
-            stopReason = "toolUse";
+          if (output.content.some((b) => b.type === "toolCall")) {
+            output.stopReason = "toolUse";
           }
         }
 
         // Capture usage metadata if available
         if (chunk.usageMetadata) {
-          usage = {
+          output.usage = {
            input: chunk.usageMetadata.promptTokenCount || 0,
            output:
              (chunk.usageMetadata.candidatesTokenCount || 0) + (chunk.usageMetadata.thoughtsTokenCount || 0),
@@ -221,47 +219,15 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
              total: 0,
            },
          };
+          calculateCost(this.modelInfo, output.usage);
        }
      }
 
-      // Save final block if exists
-      if (currentBlock) {
-        if (currentBlock.type === "text") {
-          options?.onEvent?.({ type: "text_end", content: currentBlock.text });
-        } else {
-          options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
-        }
-        blocks.push(currentBlock);
-      }
 
-      calculateCost(this.model, usage);
 
-      const output = {
-        role: "assistant",
-        content: blocks,
-        provider: this.model.provider,
-        model: this.model.id,
-        usage,
-        stopReason,
-      } satisfies AssistantMessage;
-      options?.onEvent?.({ type: "done", reason: stopReason, message: output });
+      options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
       return output;
     } catch (error) {
-      const output = {
-        role: "assistant",
-        content: [],
-        provider: this.model.provider,
-        model: this.model.id,
-        usage: {
-          input: 0,
-          output: 0,
-          cacheRead: 0,
-          cacheWrite: 0,
-          cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
-        },
-        stopReason: "error",
-        error: error instanceof Error ? error.message : JSON.stringify(error),
-      } satisfies AssistantMessage;
+      output.stopReason = "error";
+      output.error = error instanceof Error ? error.message : JSON.stringify(error);
       options?.onEvent?.({ type: "error", error: output.error });
       return output;
     }
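The hunks above cover only the Google provider; the guard that keeps thinking/toolCall blocks from an aborted completion out of follow-up turns lives elsewhere in the commit and is not shown here. As a rough illustration of that idea, with an invented helper name and simplified types rather than the repository's code:

// Hypothetical helper; names and types are stand-ins, not the repository's code.
type Block =
  | { type: "text"; text: string }
  | { type: "thinking"; thinking: string }
  | { type: "toolCall"; name: string; arguments: Record<string, unknown> };

type PartialAssistantMessage = { role: "assistant"; content: Block[]; stopReason: string };

function sanitizeForHistory(message: PartialAssistantMessage): PartialAssistantMessage {
  // A turn that finished normally (or stopped to call tools) can be resubmitted as-is.
  if (message.stopReason === "stop" || message.stopReason === "toolUse") return message;
  // An aborted or errored turn keeps only its text blocks: partial thinking or
  // tool calls would leave the next request with calls it cannot resolve.
  return { ...message, content: message.content.filter((b) => b.type === "text") };
}

Text blocks from the interrupted turn remain useful as conversation history, while dangling reasoning or tool calls are dropped before the next request is sent.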