mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-21 05:02:14 +00:00
Fix: compaction now handles branch_summary and custom_message entries
- Add getMessageFromEntry helper to extract AppMessage from any context-producing entry - Update findValidCutPoints to treat branch_summary/custom_message as valid cut points - Update findTurnStartIndex to recognize branch_summary/custom_message as turn starters - Update all message extraction loops to use getMessageFromEntry
This commit is contained in:
parent
754f55e3f6
commit
beb804cda0
1 changed files with 38 additions and 18 deletions
|
|
@ -9,7 +9,28 @@ import type { AppMessage } from "@mariozechner/pi-agent-core";
|
||||||
import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai";
|
import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai";
|
||||||
import { complete } from "@mariozechner/pi-ai";
|
import { complete } from "@mariozechner/pi-ai";
|
||||||
import { messageTransformer } from "./messages.js";
|
import { messageTransformer } from "./messages.js";
|
||||||
import type { CompactionEntry, SessionEntry } from "./session-manager.js";
|
import { type CompactionEntry, createSummaryMessage, type SessionEntry } from "./session-manager.js";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract AppMessage from an entry if it produces one.
|
||||||
|
* Returns null for entries that don't contribute to LLM context.
|
||||||
|
*/
|
||||||
|
function getMessageFromEntry(entry: SessionEntry): AppMessage | null {
|
||||||
|
if (entry.type === "message") {
|
||||||
|
return entry.message;
|
||||||
|
}
|
||||||
|
if (entry.type === "custom_message") {
|
||||||
|
return {
|
||||||
|
role: "user",
|
||||||
|
content: entry.content,
|
||||||
|
timestamp: new Date(entry.timestamp).getTime(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if (entry.type === "branch_summary") {
|
||||||
|
return createSummaryMessage(entry.summary, entry.timestamp);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
/** Result from compact() - SessionManager adds uuid/parentUuid when saving */
|
/** Result from compact() - SessionManager adds uuid/parentUuid when saving */
|
||||||
export interface CompactionResult<T = unknown> {
|
export interface CompactionResult<T = unknown> {
|
||||||
|
|
@ -157,7 +178,10 @@ function findValidCutPoints(entries: SessionEntry[], startIndex: number, endInde
|
||||||
const cutPoints: number[] = [];
|
const cutPoints: number[] = [];
|
||||||
for (let i = startIndex; i < endIndex; i++) {
|
for (let i = startIndex; i < endIndex; i++) {
|
||||||
const entry = entries[i];
|
const entry = entries[i];
|
||||||
if (entry.type === "message") {
|
// branch_summary and custom_message are user-role messages, valid cut points
|
||||||
|
if (entry.type === "branch_summary" || entry.type === "custom_message") {
|
||||||
|
cutPoints.push(i);
|
||||||
|
} else if (entry.type === "message") {
|
||||||
const role = entry.message.role;
|
const role = entry.message.role;
|
||||||
// user, assistant, and bashExecution are valid cut points
|
// user, assistant, and bashExecution are valid cut points
|
||||||
// toolResult must stay with its preceding tool call
|
// toolResult must stay with its preceding tool call
|
||||||
|
|
@ -177,6 +201,10 @@ function findValidCutPoints(entries: SessionEntry[], startIndex: number, endInde
|
||||||
export function findTurnStartIndex(entries: SessionEntry[], entryIndex: number, startIndex: number): number {
|
export function findTurnStartIndex(entries: SessionEntry[], entryIndex: number, startIndex: number): number {
|
||||||
for (let i = entryIndex; i >= startIndex; i--) {
|
for (let i = entryIndex; i >= startIndex; i--) {
|
||||||
const entry = entries[i];
|
const entry = entries[i];
|
||||||
|
// branch_summary and custom_message are user-role messages, can start a turn
|
||||||
|
if (entry.type === "branch_summary" || entry.type === "custom_message") {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
if (entry.type === "message") {
|
if (entry.type === "message") {
|
||||||
const role = entry.message.role;
|
const role = entry.message.role;
|
||||||
if (role === "user" || role === "bashExecution") {
|
if (role === "user" || role === "bashExecution") {
|
||||||
|
|
@ -382,19 +410,15 @@ export function prepareCompaction(entries: SessionEntry[], settings: CompactionS
|
||||||
// Messages to summarize (will be discarded after summary)
|
// Messages to summarize (will be discarded after summary)
|
||||||
const messagesToSummarize: AppMessage[] = [];
|
const messagesToSummarize: AppMessage[] = [];
|
||||||
for (let i = boundaryStart; i < historyEnd; i++) {
|
for (let i = boundaryStart; i < historyEnd; i++) {
|
||||||
const entry = entries[i];
|
const msg = getMessageFromEntry(entries[i]);
|
||||||
if (entry.type === "message") {
|
if (msg) messagesToSummarize.push(msg);
|
||||||
messagesToSummarize.push(entry.message);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Messages to keep (recent turns, kept after summary)
|
// Messages to keep (recent turns, kept after summary)
|
||||||
const messagesToKeep: AppMessage[] = [];
|
const messagesToKeep: AppMessage[] = [];
|
||||||
for (let i = cutPoint.firstKeptEntryIndex; i < boundaryEnd; i++) {
|
for (let i = cutPoint.firstKeptEntryIndex; i < boundaryEnd; i++) {
|
||||||
const entry = entries[i];
|
const msg = getMessageFromEntry(entries[i]);
|
||||||
if (entry.type === "message") {
|
if (msg) messagesToKeep.push(msg);
|
||||||
messagesToKeep.push(entry.message);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return { cutPoint, firstKeptEntryId, messagesToSummarize, messagesToKeep, tokensBefore, boundaryStart };
|
return { cutPoint, firstKeptEntryId, messagesToSummarize, messagesToKeep, tokensBefore, boundaryStart };
|
||||||
|
|
@ -460,10 +484,8 @@ export async function compact(
|
||||||
const historyEnd = cutResult.isSplitTurn ? cutResult.turnStartIndex : cutResult.firstKeptEntryIndex;
|
const historyEnd = cutResult.isSplitTurn ? cutResult.turnStartIndex : cutResult.firstKeptEntryIndex;
|
||||||
const historyMessages: AppMessage[] = [];
|
const historyMessages: AppMessage[] = [];
|
||||||
for (let i = boundaryStart; i < historyEnd; i++) {
|
for (let i = boundaryStart; i < historyEnd; i++) {
|
||||||
const entry = entries[i];
|
const msg = getMessageFromEntry(entries[i]);
|
||||||
if (entry.type === "message") {
|
if (msg) historyMessages.push(msg);
|
||||||
historyMessages.push(entry.message);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Include previous summary if there was a compaction
|
// Include previous summary if there was a compaction
|
||||||
|
|
@ -480,10 +502,8 @@ export async function compact(
|
||||||
const turnPrefixMessages: AppMessage[] = [];
|
const turnPrefixMessages: AppMessage[] = [];
|
||||||
if (cutResult.isSplitTurn) {
|
if (cutResult.isSplitTurn) {
|
||||||
for (let i = cutResult.turnStartIndex; i < cutResult.firstKeptEntryIndex; i++) {
|
for (let i = cutResult.turnStartIndex; i < cutResult.firstKeptEntryIndex; i++) {
|
||||||
const entry = entries[i];
|
const msg = getMessageFromEntry(entries[i]);
|
||||||
if (entry.type === "message") {
|
if (msg) turnPrefixMessages.push(msg);
|
||||||
turnPrefixMessages.push(entry.message);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue