diff --git a/packages/coding-agent/src/core/compaction.ts b/packages/coding-agent/src/core/compaction.ts index b0156ec5..3fb771dd 100644 --- a/packages/coding-agent/src/core/compaction.ts +++ b/packages/coding-agent/src/core/compaction.ts @@ -9,7 +9,28 @@ import type { AppMessage } from "@mariozechner/pi-agent-core"; import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai"; import { complete } from "@mariozechner/pi-ai"; import { messageTransformer } from "./messages.js"; -import type { CompactionEntry, SessionEntry } from "./session-manager.js"; +import { type CompactionEntry, createSummaryMessage, type SessionEntry } from "./session-manager.js"; + +/** + * Extract AppMessage from an entry if it produces one. + * Returns null for entries that don't contribute to LLM context. + */ +function getMessageFromEntry(entry: SessionEntry): AppMessage | null { + if (entry.type === "message") { + return entry.message; + } + if (entry.type === "custom_message") { + return { + role: "user", + content: entry.content, + timestamp: new Date(entry.timestamp).getTime(), + }; + } + if (entry.type === "branch_summary") { + return createSummaryMessage(entry.summary, entry.timestamp); + } + return null; +} /** Result from compact() - SessionManager adds uuid/parentUuid when saving */ export interface CompactionResult { @@ -157,7 +178,10 @@ function findValidCutPoints(entries: SessionEntry[], startIndex: number, endInde const cutPoints: number[] = []; for (let i = startIndex; i < endIndex; i++) { const entry = entries[i]; - if (entry.type === "message") { + // branch_summary and custom_message are user-role messages, valid cut points + if (entry.type === "branch_summary" || entry.type === "custom_message") { + cutPoints.push(i); + } else if (entry.type === "message") { const role = entry.message.role; // user, assistant, and bashExecution are valid cut points // toolResult must stay with its preceding tool call @@ -177,6 +201,10 @@ function findValidCutPoints(entries: SessionEntry[], startIndex: number, endInde export function findTurnStartIndex(entries: SessionEntry[], entryIndex: number, startIndex: number): number { for (let i = entryIndex; i >= startIndex; i--) { const entry = entries[i]; + // branch_summary and custom_message are user-role messages, can start a turn + if (entry.type === "branch_summary" || entry.type === "custom_message") { + return i; + } if (entry.type === "message") { const role = entry.message.role; if (role === "user" || role === "bashExecution") { @@ -382,19 +410,15 @@ export function prepareCompaction(entries: SessionEntry[], settings: CompactionS // Messages to summarize (will be discarded after summary) const messagesToSummarize: AppMessage[] = []; for (let i = boundaryStart; i < historyEnd; i++) { - const entry = entries[i]; - if (entry.type === "message") { - messagesToSummarize.push(entry.message); - } + const msg = getMessageFromEntry(entries[i]); + if (msg) messagesToSummarize.push(msg); } // Messages to keep (recent turns, kept after summary) const messagesToKeep: AppMessage[] = []; for (let i = cutPoint.firstKeptEntryIndex; i < boundaryEnd; i++) { - const entry = entries[i]; - if (entry.type === "message") { - messagesToKeep.push(entry.message); - } + const msg = getMessageFromEntry(entries[i]); + if (msg) messagesToKeep.push(msg); } return { cutPoint, firstKeptEntryId, messagesToSummarize, messagesToKeep, tokensBefore, boundaryStart }; @@ -460,10 +484,8 @@ export async function compact( const historyEnd = cutResult.isSplitTurn ? cutResult.turnStartIndex : cutResult.firstKeptEntryIndex; const historyMessages: AppMessage[] = []; for (let i = boundaryStart; i < historyEnd; i++) { - const entry = entries[i]; - if (entry.type === "message") { - historyMessages.push(entry.message); - } + const msg = getMessageFromEntry(entries[i]); + if (msg) historyMessages.push(msg); } // Include previous summary if there was a compaction @@ -480,10 +502,8 @@ export async function compact( const turnPrefixMessages: AppMessage[] = []; if (cutResult.isSplitTurn) { for (let i = cutResult.turnStartIndex; i < cutResult.firstKeptEntryIndex; i++) { - const entry = entries[i]; - if (entry.type === "message") { - turnPrefixMessages.push(entry.message); - } + const msg = getMessageFromEntry(entries[i]); + if (msg) turnPrefixMessages.push(msg); } }