mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-17 06:04:51 +00:00
Change branch() to use entryId instead of entryIndex
- AgentSession.branch(entryId: string) now takes entry ID - SessionBeforeBranchEvent.entryId replaces entryIndex - getUserMessagesForBranching() returns entryId - Update RPC types and client - Update UserMessageSelectorComponent - Update hook examples and tests - Update docs (hooks.md, sdk.md)
This commit is contained in:
parent
027d39aa33
commit
8e1e99ca05
12 changed files with 64 additions and 50 deletions
|
|
@ -195,7 +195,7 @@ Fired when branching via `/branch`.
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
pi.on("session_before_branch", async (event, ctx) => {
|
pi.on("session_before_branch", async (event, ctx) => {
|
||||||
// event.entryIndex - entry index being branched from
|
// event.entryId - ID of the entry being branched from
|
||||||
|
|
||||||
return { cancel: true }; // Cancel branch
|
return { cancel: true }; // Cancel branch
|
||||||
// OR
|
// OR
|
||||||
|
|
@ -634,15 +634,23 @@ export default function (pi: HookAPI) {
|
||||||
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
||||||
|
|
||||||
export default function (pi: HookAPI) {
|
export default function (pi: HookAPI) {
|
||||||
const checkpoints = new Map<number, string>();
|
const checkpoints = new Map<string, string>();
|
||||||
|
let currentEntryId: string | undefined;
|
||||||
|
|
||||||
pi.on("turn_start", async (event) => {
|
pi.on("tool_result", async (_event, ctx) => {
|
||||||
|
const leaf = ctx.sessionManager.getLeafEntry();
|
||||||
|
if (leaf) currentEntryId = leaf.id;
|
||||||
|
});
|
||||||
|
|
||||||
|
pi.on("turn_start", async () => {
|
||||||
const { stdout } = await pi.exec("git", ["stash", "create"]);
|
const { stdout } = await pi.exec("git", ["stash", "create"]);
|
||||||
if (stdout.trim()) checkpoints.set(event.turnIndex, stdout.trim());
|
if (stdout.trim() && currentEntryId) {
|
||||||
|
checkpoints.set(currentEntryId, stdout.trim());
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
pi.on("session_before_branch", async (event, ctx) => {
|
pi.on("session_before_branch", async (event, ctx) => {
|
||||||
const ref = checkpoints.get(event.entryIndex);
|
const ref = checkpoints.get(event.entryId);
|
||||||
if (!ref || !ctx.hasUI) return;
|
if (!ref || !ctx.hasUI) return;
|
||||||
|
|
||||||
const ok = await ctx.ui.confirm("Restore?", "Restore code to checkpoint?");
|
const ok = await ctx.ui.confirm("Restore?", "Restore code to checkpoint?");
|
||||||
|
|
|
||||||
|
|
@ -99,11 +99,12 @@ interface AgentSession {
|
||||||
isStreaming: boolean;
|
isStreaming: boolean;
|
||||||
|
|
||||||
// Session management
|
// Session management
|
||||||
newSession(): Promise<boolean>; // Returns false if cancelled by hook
|
reset(): Promise<boolean>; // Returns false if cancelled by hook
|
||||||
switchSession(sessionPath: string): Promise<boolean>;
|
switchSession(sessionPath: string): Promise<boolean>;
|
||||||
|
|
||||||
// Branching (tree-based)
|
// Branching
|
||||||
branch(entryId: string): Promise<{ cancelled: boolean }>;
|
branch(entryId: string): Promise<{ selectedText: string; cancelled: boolean }>; // Creates new session file
|
||||||
|
navigateTree(targetId: string, options?: { summarize?: boolean }): Promise<{ editorText?: string; cancelled: boolean }>; // In-place navigation
|
||||||
|
|
||||||
// Hook message injection
|
// Hook message injection
|
||||||
sendHookMessage(message: HookMessage, triggerTurn?: boolean): void;
|
sendHookMessage(message: HookMessage, triggerTurn?: boolean): void;
|
||||||
|
|
@ -400,10 +401,10 @@ const { session } = await createAgentSession({
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
import { Type } from "@sinclair/typebox";
|
import { Type } from "@sinclair/typebox";
|
||||||
import { createAgentSession, discoverCustomTools, type CustomAgentTool } from "@mariozechner/pi-coding-agent";
|
import { createAgentSession, discoverCustomTools, type CustomTool } from "@mariozechner/pi-coding-agent";
|
||||||
|
|
||||||
// Inline custom tool
|
// Inline custom tool
|
||||||
const myTool: CustomAgentTool = {
|
const myTool: CustomTool = {
|
||||||
name: "my_tool",
|
name: "my_tool",
|
||||||
label: "My Tool",
|
label: "My Tool",
|
||||||
description: "Does something useful",
|
description: "Does something useful",
|
||||||
|
|
@ -793,7 +794,7 @@ import {
|
||||||
readTool,
|
readTool,
|
||||||
bashTool,
|
bashTool,
|
||||||
type HookFactory,
|
type HookFactory,
|
||||||
type CustomAgentTool,
|
type CustomTool,
|
||||||
} from "@mariozechner/pi-coding-agent";
|
} from "@mariozechner/pi-coding-agent";
|
||||||
|
|
||||||
// Set up auth storage (custom location)
|
// Set up auth storage (custom location)
|
||||||
|
|
@ -816,7 +817,7 @@ const auditHook: HookFactory = (api) => {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Inline tool
|
// Inline tool
|
||||||
const statusTool: CustomAgentTool = {
|
const statusTool: CustomTool = {
|
||||||
name: "status",
|
name: "status",
|
||||||
label: "Status",
|
label: "Status",
|
||||||
description: "Get system status",
|
description: "Get system status",
|
||||||
|
|
@ -932,7 +933,7 @@ createGrepTool, createFindTool, createLsTool
|
||||||
// Types
|
// Types
|
||||||
type CreateAgentSessionOptions
|
type CreateAgentSessionOptions
|
||||||
type CreateAgentSessionResult
|
type CreateAgentSessionResult
|
||||||
type CustomAgentTool
|
type CustomTool
|
||||||
type HookFactory
|
type HookFactory
|
||||||
type Skill
|
type Skill
|
||||||
type FileSlashCommand
|
type FileSlashCommand
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ export default function (pi: HookAPI) {
|
||||||
pi.on("session_before_branch", async (event, ctx) => {
|
pi.on("session_before_branch", async (event, ctx) => {
|
||||||
if (!ctx.hasUI) return;
|
if (!ctx.hasUI) return;
|
||||||
|
|
||||||
const choice = await ctx.ui.select(`Branch from turn ${event.entryIndex}?`, [
|
const choice = await ctx.ui.select(`Branch from entry ${event.entryId.slice(0, 8)}?`, [
|
||||||
"Yes, create branch",
|
"Yes, create branch",
|
||||||
"No, stay in current session",
|
"No, stay in current session",
|
||||||
]);
|
]);
|
||||||
|
|
|
||||||
|
|
@ -8,19 +8,26 @@
|
||||||
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
||||||
|
|
||||||
export default function (pi: HookAPI) {
|
export default function (pi: HookAPI) {
|
||||||
const checkpoints = new Map<number, string>();
|
const checkpoints = new Map<string, string>();
|
||||||
|
let currentEntryId: string | undefined;
|
||||||
|
|
||||||
pi.on("turn_start", async (event) => {
|
// Track the current entry ID when user messages are saved
|
||||||
|
pi.on("tool_result", async (_event, ctx) => {
|
||||||
|
const leaf = ctx.sessionManager.getLeafEntry();
|
||||||
|
if (leaf) currentEntryId = leaf.id;
|
||||||
|
});
|
||||||
|
|
||||||
|
pi.on("turn_start", async () => {
|
||||||
// Create a git stash entry before LLM makes changes
|
// Create a git stash entry before LLM makes changes
|
||||||
const { stdout } = await pi.exec("git", ["stash", "create"]);
|
const { stdout } = await pi.exec("git", ["stash", "create"]);
|
||||||
const ref = stdout.trim();
|
const ref = stdout.trim();
|
||||||
if (ref) {
|
if (ref && currentEntryId) {
|
||||||
checkpoints.set(event.turnIndex, ref);
|
checkpoints.set(currentEntryId, ref);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
pi.on("session_before_branch", async (event, ctx) => {
|
pi.on("session_before_branch", async (event, ctx) => {
|
||||||
const ref = checkpoints.get(event.entryIndex);
|
const ref = checkpoints.get(event.entryId);
|
||||||
if (!ref) return;
|
if (!ref) return;
|
||||||
|
|
||||||
if (!ctx.hasUI) {
|
if (!ctx.hasUI) {
|
||||||
|
|
|
||||||
|
|
@ -1498,21 +1498,20 @@ export class AgentSession {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a branch from a specific entry index.
|
* Create a branch from a specific entry.
|
||||||
* Emits before_branch/branch session events to hooks.
|
* Emits before_branch/branch session events to hooks.
|
||||||
*
|
*
|
||||||
* @param entryIndex Index into session entries to branch from
|
* @param entryId ID of the entry to branch from
|
||||||
* @returns Object with:
|
* @returns Object with:
|
||||||
* - selectedText: The text of the selected user message (for editor pre-fill)
|
* - selectedText: The text of the selected user message (for editor pre-fill)
|
||||||
* - cancelled: True if a hook cancelled the branch
|
* - cancelled: True if a hook cancelled the branch
|
||||||
*/
|
*/
|
||||||
async branch(entryIndex: number): Promise<{ selectedText: string; cancelled: boolean }> {
|
async branch(entryId: string): Promise<{ selectedText: string; cancelled: boolean }> {
|
||||||
const previousSessionFile = this.sessionFile;
|
const previousSessionFile = this.sessionFile;
|
||||||
const entries = this.sessionManager.getEntries();
|
const selectedEntry = this.sessionManager.getEntry(entryId);
|
||||||
const selectedEntry = entries[entryIndex];
|
|
||||||
|
|
||||||
if (!selectedEntry || selectedEntry.type !== "message" || selectedEntry.message.role !== "user") {
|
if (!selectedEntry || selectedEntry.type !== "message" || selectedEntry.message.role !== "user") {
|
||||||
throw new Error("Invalid entry index for branching");
|
throw new Error("Invalid entry ID for branching");
|
||||||
}
|
}
|
||||||
|
|
||||||
const selectedText = this._extractUserMessageText(selectedEntry.message.content);
|
const selectedText = this._extractUserMessageText(selectedEntry.message.content);
|
||||||
|
|
@ -1523,7 +1522,7 @@ export class AgentSession {
|
||||||
if (this._hookRunner?.hasHandlers("session_before_branch")) {
|
if (this._hookRunner?.hasHandlers("session_before_branch")) {
|
||||||
const result = (await this._hookRunner.emit({
|
const result = (await this._hookRunner.emit({
|
||||||
type: "session_before_branch",
|
type: "session_before_branch",
|
||||||
entryIndex: entryIndex,
|
entryId,
|
||||||
})) as SessionBeforeBranchResult | undefined;
|
})) as SessionBeforeBranchResult | undefined;
|
||||||
|
|
||||||
if (result?.cancel) {
|
if (result?.cancel) {
|
||||||
|
|
@ -1729,18 +1728,17 @@ export class AgentSession {
|
||||||
/**
|
/**
|
||||||
* Get all user messages from session for branch selector.
|
* Get all user messages from session for branch selector.
|
||||||
*/
|
*/
|
||||||
getUserMessagesForBranching(): Array<{ entryIndex: number; text: string }> {
|
getUserMessagesForBranching(): Array<{ entryId: string; text: string }> {
|
||||||
const entries = this.sessionManager.getEntries();
|
const entries = this.sessionManager.getEntries();
|
||||||
const result: Array<{ entryIndex: number; text: string }> = [];
|
const result: Array<{ entryId: string; text: string }> = [];
|
||||||
|
|
||||||
for (let i = 0; i < entries.length; i++) {
|
for (const entry of entries) {
|
||||||
const entry = entries[i];
|
|
||||||
if (entry.type !== "message") continue;
|
if (entry.type !== "message") continue;
|
||||||
if (entry.message.role !== "user") continue;
|
if (entry.message.role !== "user") continue;
|
||||||
|
|
||||||
const text = this._extractUserMessageText(entry.message.content);
|
const text = this._extractUserMessageText(entry.message.content);
|
||||||
if (text) {
|
if (text) {
|
||||||
result.push({ entryIndex: i, text });
|
result.push({ entryId: entry.id, text });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -121,8 +121,8 @@ export interface SessionNewEvent {
|
||||||
/** Fired before branching a session (can be cancelled) */
|
/** Fired before branching a session (can be cancelled) */
|
||||||
export interface SessionBeforeBranchEvent {
|
export interface SessionBeforeBranchEvent {
|
||||||
type: "session_before_branch";
|
type: "session_before_branch";
|
||||||
/** Index of the entry in the session (SessionManager.getEntries()) to branch from */
|
/** ID of the entry to branch from */
|
||||||
entryIndex: number;
|
entryId: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Fired after branching a session */
|
/** Fired after branching a session */
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ import { theme } from "../theme/theme.js";
|
||||||
import { DynamicBorder } from "./dynamic-border.js";
|
import { DynamicBorder } from "./dynamic-border.js";
|
||||||
|
|
||||||
interface UserMessageItem {
|
interface UserMessageItem {
|
||||||
index: number; // Index in the full messages array
|
id: string; // Entry ID in the session
|
||||||
text: string; // The message text
|
text: string; // The message text
|
||||||
timestamp?: string; // Optional timestamp if available
|
timestamp?: string; // Optional timestamp if available
|
||||||
}
|
}
|
||||||
|
|
@ -25,7 +25,7 @@ interface UserMessageItem {
|
||||||
class UserMessageList implements Component {
|
class UserMessageList implements Component {
|
||||||
private messages: UserMessageItem[] = [];
|
private messages: UserMessageItem[] = [];
|
||||||
private selectedIndex: number = 0;
|
private selectedIndex: number = 0;
|
||||||
public onSelect?: (messageIndex: number) => void;
|
public onSelect?: (entryId: string) => void;
|
||||||
public onCancel?: () => void;
|
public onCancel?: () => void;
|
||||||
private maxVisible: number = 10; // Max messages visible
|
private maxVisible: number = 10; // Max messages visible
|
||||||
|
|
||||||
|
|
@ -101,7 +101,7 @@ class UserMessageList implements Component {
|
||||||
else if (isEnter(keyData)) {
|
else if (isEnter(keyData)) {
|
||||||
const selected = this.messages[this.selectedIndex];
|
const selected = this.messages[this.selectedIndex];
|
||||||
if (selected && this.onSelect) {
|
if (selected && this.onSelect) {
|
||||||
this.onSelect(selected.index);
|
this.onSelect(selected.id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Escape - cancel
|
// Escape - cancel
|
||||||
|
|
@ -125,7 +125,7 @@ class UserMessageList implements Component {
|
||||||
export class UserMessageSelectorComponent extends Container {
|
export class UserMessageSelectorComponent extends Container {
|
||||||
private messageList: UserMessageList;
|
private messageList: UserMessageList;
|
||||||
|
|
||||||
constructor(messages: UserMessageItem[], onSelect: (messageIndex: number) => void, onCancel: () => void) {
|
constructor(messages: UserMessageItem[], onSelect: (entryId: string) => void, onCancel: () => void) {
|
||||||
super();
|
super();
|
||||||
|
|
||||||
// Add header
|
// Add header
|
||||||
|
|
|
||||||
|
|
@ -1570,9 +1570,9 @@ export class InteractiveMode {
|
||||||
|
|
||||||
this.showSelector((done) => {
|
this.showSelector((done) => {
|
||||||
const selector = new UserMessageSelectorComponent(
|
const selector = new UserMessageSelectorComponent(
|
||||||
userMessages.map((m) => ({ index: m.entryIndex, text: m.text })),
|
userMessages.map((m) => ({ id: m.entryId, text: m.text })),
|
||||||
async (entryIndex) => {
|
async (entryId) => {
|
||||||
const result = await this.session.branch(entryIndex);
|
const result = await this.session.branch(entryId);
|
||||||
if (result.cancelled) {
|
if (result.cancelled) {
|
||||||
// Hook cancelled the branch
|
// Hook cancelled the branch
|
||||||
done();
|
done();
|
||||||
|
|
|
||||||
|
|
@ -326,17 +326,17 @@ export class RpcClient {
|
||||||
* Branch from a specific message.
|
* Branch from a specific message.
|
||||||
* @returns Object with `text` (the message text) and `cancelled` (if hook cancelled)
|
* @returns Object with `text` (the message text) and `cancelled` (if hook cancelled)
|
||||||
*/
|
*/
|
||||||
async branch(entryIndex: number): Promise<{ text: string; cancelled: boolean }> {
|
async branch(entryId: string): Promise<{ text: string; cancelled: boolean }> {
|
||||||
const response = await this.send({ type: "branch", entryIndex });
|
const response = await this.send({ type: "branch", entryId });
|
||||||
return this.getData(response);
|
return this.getData(response);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get messages available for branching.
|
* Get messages available for branching.
|
||||||
*/
|
*/
|
||||||
async getBranchMessages(): Promise<Array<{ entryIndex: number; text: string }>> {
|
async getBranchMessages(): Promise<Array<{ entryId: string; text: string }>> {
|
||||||
const response = await this.send({ type: "get_branch_messages" });
|
const response = await this.send({ type: "get_branch_messages" });
|
||||||
return this.getData<{ messages: Array<{ entryIndex: number; text: string }> }>(response).messages;
|
return this.getData<{ messages: Array<{ entryId: string; text: string }> }>(response).messages;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -347,7 +347,7 @@ export async function runRpcMode(session: AgentSession): Promise<never> {
|
||||||
}
|
}
|
||||||
|
|
||||||
case "branch": {
|
case "branch": {
|
||||||
const result = await session.branch(command.entryIndex);
|
const result = await session.branch(command.entryId);
|
||||||
return success(id, "branch", { text: result.selectedText, cancelled: result.cancelled });
|
return success(id, "branch", { text: result.selectedText, cancelled: result.cancelled });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ export type RpcCommand =
|
||||||
| { id?: string; type: "get_session_stats" }
|
| { id?: string; type: "get_session_stats" }
|
||||||
| { id?: string; type: "export_html"; outputPath?: string }
|
| { id?: string; type: "export_html"; outputPath?: string }
|
||||||
| { id?: string; type: "switch_session"; sessionPath: string }
|
| { id?: string; type: "switch_session"; sessionPath: string }
|
||||||
| { id?: string; type: "branch"; entryIndex: number }
|
| { id?: string; type: "branch"; entryId: string }
|
||||||
| { id?: string; type: "get_branch_messages" }
|
| { id?: string; type: "get_branch_messages" }
|
||||||
| { id?: string; type: "get_last_assistant_text" }
|
| { id?: string; type: "get_last_assistant_text" }
|
||||||
|
|
||||||
|
|
@ -150,7 +150,7 @@ export type RpcResponse =
|
||||||
type: "response";
|
type: "response";
|
||||||
command: "get_branch_messages";
|
command: "get_branch_messages";
|
||||||
success: true;
|
success: true;
|
||||||
data: { messages: Array<{ entryIndex: number; text: string }> };
|
data: { messages: Array<{ entryId: string; text: string }> };
|
||||||
}
|
}
|
||||||
| {
|
| {
|
||||||
id?: string;
|
id?: string;
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ describe.skipIf(!API_KEY)("AgentSession branching", () => {
|
||||||
expect(userMessages[0].text).toBe("Say hello");
|
expect(userMessages[0].text).toBe("Say hello");
|
||||||
|
|
||||||
// Branch from the first message
|
// Branch from the first message
|
||||||
const result = await session.branch(userMessages[0].entryIndex);
|
const result = await session.branch(userMessages[0].entryId);
|
||||||
expect(result.selectedText).toBe("Say hello");
|
expect(result.selectedText).toBe("Say hello");
|
||||||
expect(result.cancelled).toBe(false);
|
expect(result.cancelled).toBe(false);
|
||||||
|
|
||||||
|
|
@ -113,7 +113,7 @@ describe.skipIf(!API_KEY)("AgentSession branching", () => {
|
||||||
expect(session.messages.length).toBeGreaterThan(0);
|
expect(session.messages.length).toBeGreaterThan(0);
|
||||||
|
|
||||||
// Branch from the first message
|
// Branch from the first message
|
||||||
const result = await session.branch(userMessages[0].entryIndex);
|
const result = await session.branch(userMessages[0].entryId);
|
||||||
expect(result.selectedText).toBe("Say hi");
|
expect(result.selectedText).toBe("Say hi");
|
||||||
expect(result.cancelled).toBe(false);
|
expect(result.cancelled).toBe(false);
|
||||||
|
|
||||||
|
|
@ -143,7 +143,7 @@ describe.skipIf(!API_KEY)("AgentSession branching", () => {
|
||||||
|
|
||||||
// Branch from second message (keeps first message + response)
|
// Branch from second message (keeps first message + response)
|
||||||
const secondMessage = userMessages[1];
|
const secondMessage = userMessages[1];
|
||||||
const result = await session.branch(secondMessage.entryIndex);
|
const result = await session.branch(secondMessage.entryId);
|
||||||
expect(result.selectedText).toBe("Say two");
|
expect(result.selectedText).toBe("Say two");
|
||||||
|
|
||||||
// After branching, should have first user message + assistant response
|
// After branching, should have first user message + assistant response
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue