mirror of
https://github.com/harivansh-afk/clanker-agent.git
synced 2026-04-16 20:01:23 +00:00
move pi-mono into companion-cloud as apps/companion-os
- Copy all pi-mono source into apps/companion-os/ - Update Dockerfile to COPY pre-built binary instead of downloading from GitHub Releases - Update deploy-staging.yml to build pi from source (bun compile) before Docker build - Add apps/companion-os/** to path triggers - No more cross-repo dispatch needed Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
0250f72976
579 changed files with 206942 additions and 0 deletions
3337
packages/coding-agent/src/core/agent-session.ts
Normal file
3337
packages/coding-agent/src/core/agent-session.ts
Normal file
File diff suppressed because it is too large
Load diff
503
packages/coding-agent/src/core/auth-storage.ts
Normal file
503
packages/coding-agent/src/core/auth-storage.ts
Normal file
|
|
@ -0,0 +1,503 @@
|
|||
/**
|
||||
* Credential storage for API keys and OAuth tokens.
|
||||
* Handles loading, saving, and refreshing credentials from auth.json.
|
||||
*
|
||||
* Uses file locking to prevent race conditions when multiple pi instances
|
||||
* try to refresh tokens simultaneously.
|
||||
*/
|
||||
|
||||
import {
|
||||
getEnvApiKey,
|
||||
type OAuthCredentials,
|
||||
type OAuthLoginCallbacks,
|
||||
type OAuthProviderId,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import {
|
||||
getOAuthApiKey,
|
||||
getOAuthProvider,
|
||||
getOAuthProviders,
|
||||
} from "@mariozechner/pi-ai/oauth";
|
||||
import {
|
||||
chmodSync,
|
||||
existsSync,
|
||||
mkdirSync,
|
||||
readFileSync,
|
||||
writeFileSync,
|
||||
} from "fs";
|
||||
import { dirname, join } from "path";
|
||||
import lockfile from "proper-lockfile";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import { resolveConfigValue } from "./resolve-config-value.js";
|
||||
|
||||
export type ApiKeyCredential = {
|
||||
type: "api_key";
|
||||
key: string;
|
||||
};
|
||||
|
||||
export type OAuthCredential = {
|
||||
type: "oauth";
|
||||
} & OAuthCredentials;
|
||||
|
||||
export type AuthCredential = ApiKeyCredential | OAuthCredential;
|
||||
|
||||
export type AuthStorageData = Record<string, AuthCredential>;
|
||||
|
||||
type LockResult<T> = {
|
||||
result: T;
|
||||
next?: string;
|
||||
};
|
||||
|
||||
export interface AuthStorageBackend {
|
||||
withLock<T>(fn: (current: string | undefined) => LockResult<T>): T;
|
||||
withLockAsync<T>(
|
||||
fn: (current: string | undefined) => Promise<LockResult<T>>,
|
||||
): Promise<T>;
|
||||
}
|
||||
|
||||
export class FileAuthStorageBackend implements AuthStorageBackend {
|
||||
constructor(private authPath: string = join(getAgentDir(), "auth.json")) {}
|
||||
|
||||
private ensureParentDir(): void {
|
||||
const dir = dirname(this.authPath);
|
||||
if (!existsSync(dir)) {
|
||||
mkdirSync(dir, { recursive: true, mode: 0o700 });
|
||||
}
|
||||
}
|
||||
|
||||
private ensureFileExists(): void {
|
||||
if (!existsSync(this.authPath)) {
|
||||
writeFileSync(this.authPath, "{}", "utf-8");
|
||||
chmodSync(this.authPath, 0o600);
|
||||
}
|
||||
}
|
||||
|
||||
withLock<T>(fn: (current: string | undefined) => LockResult<T>): T {
|
||||
this.ensureParentDir();
|
||||
this.ensureFileExists();
|
||||
|
||||
let release: (() => void) | undefined;
|
||||
try {
|
||||
release = lockfile.lockSync(this.authPath, { realpath: false });
|
||||
const current = existsSync(this.authPath)
|
||||
? readFileSync(this.authPath, "utf-8")
|
||||
: undefined;
|
||||
const { result, next } = fn(current);
|
||||
if (next !== undefined) {
|
||||
writeFileSync(this.authPath, next, "utf-8");
|
||||
chmodSync(this.authPath, 0o600);
|
||||
}
|
||||
return result;
|
||||
} finally {
|
||||
if (release) {
|
||||
release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async withLockAsync<T>(
|
||||
fn: (current: string | undefined) => Promise<LockResult<T>>,
|
||||
): Promise<T> {
|
||||
this.ensureParentDir();
|
||||
this.ensureFileExists();
|
||||
|
||||
let release: (() => Promise<void>) | undefined;
|
||||
let lockCompromised = false;
|
||||
let lockCompromisedError: Error | undefined;
|
||||
const throwIfCompromised = () => {
|
||||
if (lockCompromised) {
|
||||
throw (
|
||||
lockCompromisedError ?? new Error("Auth storage lock was compromised")
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
try {
|
||||
release = await lockfile.lock(this.authPath, {
|
||||
retries: {
|
||||
retries: 10,
|
||||
factor: 2,
|
||||
minTimeout: 100,
|
||||
maxTimeout: 10000,
|
||||
randomize: true,
|
||||
},
|
||||
stale: 30000,
|
||||
onCompromised: (err) => {
|
||||
lockCompromised = true;
|
||||
lockCompromisedError = err;
|
||||
},
|
||||
});
|
||||
|
||||
throwIfCompromised();
|
||||
const current = existsSync(this.authPath)
|
||||
? readFileSync(this.authPath, "utf-8")
|
||||
: undefined;
|
||||
const { result, next } = await fn(current);
|
||||
throwIfCompromised();
|
||||
if (next !== undefined) {
|
||||
writeFileSync(this.authPath, next, "utf-8");
|
||||
chmodSync(this.authPath, 0o600);
|
||||
}
|
||||
throwIfCompromised();
|
||||
return result;
|
||||
} finally {
|
||||
if (release) {
|
||||
try {
|
||||
await release();
|
||||
} catch {
|
||||
// Ignore unlock errors when lock is compromised.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class InMemoryAuthStorageBackend implements AuthStorageBackend {
|
||||
private value: string | undefined;
|
||||
|
||||
withLock<T>(fn: (current: string | undefined) => LockResult<T>): T {
|
||||
const { result, next } = fn(this.value);
|
||||
if (next !== undefined) {
|
||||
this.value = next;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
async withLockAsync<T>(
|
||||
fn: (current: string | undefined) => Promise<LockResult<T>>,
|
||||
): Promise<T> {
|
||||
const { result, next } = await fn(this.value);
|
||||
if (next !== undefined) {
|
||||
this.value = next;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Credential storage backed by a JSON file.
|
||||
*/
|
||||
export class AuthStorage {
|
||||
private data: AuthStorageData = {};
|
||||
private runtimeOverrides: Map<string, string> = new Map();
|
||||
private fallbackResolver?: (provider: string) => string | undefined;
|
||||
private loadError: Error | null = null;
|
||||
private errors: Error[] = [];
|
||||
|
||||
private constructor(private storage: AuthStorageBackend) {
|
||||
this.reload();
|
||||
}
|
||||
|
||||
static create(authPath?: string): AuthStorage {
|
||||
return new AuthStorage(
|
||||
new FileAuthStorageBackend(authPath ?? join(getAgentDir(), "auth.json")),
|
||||
);
|
||||
}
|
||||
|
||||
static fromStorage(storage: AuthStorageBackend): AuthStorage {
|
||||
return new AuthStorage(storage);
|
||||
}
|
||||
|
||||
static inMemory(data: AuthStorageData = {}): AuthStorage {
|
||||
const storage = new InMemoryAuthStorageBackend();
|
||||
storage.withLock(() => ({
|
||||
result: undefined,
|
||||
next: JSON.stringify(data, null, 2),
|
||||
}));
|
||||
return AuthStorage.fromStorage(storage);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a runtime API key override (not persisted to disk).
|
||||
* Used for CLI --api-key flag.
|
||||
*/
|
||||
setRuntimeApiKey(provider: string, apiKey: string): void {
|
||||
this.runtimeOverrides.set(provider, apiKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a runtime API key override.
|
||||
*/
|
||||
removeRuntimeApiKey(provider: string): void {
|
||||
this.runtimeOverrides.delete(provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a fallback resolver for API keys not found in auth.json or env vars.
|
||||
* Used for custom provider keys from models.json.
|
||||
*/
|
||||
setFallbackResolver(
|
||||
resolver: (provider: string) => string | undefined,
|
||||
): void {
|
||||
this.fallbackResolver = resolver;
|
||||
}
|
||||
|
||||
private recordError(error: unknown): void {
|
||||
const normalizedError =
|
||||
error instanceof Error ? error : new Error(String(error));
|
||||
this.errors.push(normalizedError);
|
||||
}
|
||||
|
||||
private parseStorageData(content: string | undefined): AuthStorageData {
|
||||
if (!content) {
|
||||
return {};
|
||||
}
|
||||
return JSON.parse(content) as AuthStorageData;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reload credentials from storage.
|
||||
*/
|
||||
reload(): void {
|
||||
let content: string | undefined;
|
||||
try {
|
||||
this.storage.withLock((current) => {
|
||||
content = current;
|
||||
return { result: undefined };
|
||||
});
|
||||
this.data = this.parseStorageData(content);
|
||||
this.loadError = null;
|
||||
} catch (error) {
|
||||
this.loadError = error as Error;
|
||||
this.recordError(error);
|
||||
}
|
||||
}
|
||||
|
||||
private persistProviderChange(
|
||||
provider: string,
|
||||
credential: AuthCredential | undefined,
|
||||
): void {
|
||||
if (this.loadError) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
this.storage.withLock((current) => {
|
||||
const currentData = this.parseStorageData(current);
|
||||
const merged: AuthStorageData = { ...currentData };
|
||||
if (credential) {
|
||||
merged[provider] = credential;
|
||||
} else {
|
||||
delete merged[provider];
|
||||
}
|
||||
return { result: undefined, next: JSON.stringify(merged, null, 2) };
|
||||
});
|
||||
} catch (error) {
|
||||
this.recordError(error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get credential for a provider.
|
||||
*/
|
||||
get(provider: string): AuthCredential | undefined {
|
||||
return this.data[provider] ?? undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set credential for a provider.
|
||||
*/
|
||||
set(provider: string, credential: AuthCredential): void {
|
||||
this.data[provider] = credential;
|
||||
this.persistProviderChange(provider, credential);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove credential for a provider.
|
||||
*/
|
||||
remove(provider: string): void {
|
||||
delete this.data[provider];
|
||||
this.persistProviderChange(provider, undefined);
|
||||
}
|
||||
|
||||
/**
|
||||
* List all providers with credentials.
|
||||
*/
|
||||
list(): string[] {
|
||||
return Object.keys(this.data);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if credentials exist for a provider in auth.json.
|
||||
*/
|
||||
has(provider: string): boolean {
|
||||
return provider in this.data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if any form of auth is configured for a provider.
|
||||
* Unlike getApiKey(), this doesn't refresh OAuth tokens.
|
||||
*/
|
||||
hasAuth(provider: string): boolean {
|
||||
if (this.runtimeOverrides.has(provider)) return true;
|
||||
if (this.data[provider]) return true;
|
||||
if (getEnvApiKey(provider)) return true;
|
||||
if (this.fallbackResolver?.(provider)) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all credentials (for passing to getOAuthApiKey).
|
||||
*/
|
||||
getAll(): AuthStorageData {
|
||||
return { ...this.data };
|
||||
}
|
||||
|
||||
drainErrors(): Error[] {
|
||||
const drained = [...this.errors];
|
||||
this.errors = [];
|
||||
return drained;
|
||||
}
|
||||
|
||||
/**
|
||||
* Login to an OAuth provider.
|
||||
*/
|
||||
async login(
|
||||
providerId: OAuthProviderId,
|
||||
callbacks: OAuthLoginCallbacks,
|
||||
): Promise<void> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`Unknown OAuth provider: ${providerId}`);
|
||||
}
|
||||
|
||||
const credentials = await provider.login(callbacks);
|
||||
this.set(providerId, { type: "oauth", ...credentials });
|
||||
}
|
||||
|
||||
/**
|
||||
* Logout from a provider.
|
||||
*/
|
||||
logout(provider: string): void {
|
||||
this.remove(provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh OAuth token with backend locking to prevent race conditions.
|
||||
* Multiple pi instances may try to refresh simultaneously when tokens expire.
|
||||
*/
|
||||
private async refreshOAuthTokenWithLock(
|
||||
providerId: OAuthProviderId,
|
||||
): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const result = await this.storage.withLockAsync(async (current) => {
|
||||
const currentData = this.parseStorageData(current);
|
||||
this.data = currentData;
|
||||
this.loadError = null;
|
||||
|
||||
const cred = currentData[providerId];
|
||||
if (cred?.type !== "oauth") {
|
||||
return { result: null };
|
||||
}
|
||||
|
||||
if (Date.now() < cred.expires) {
|
||||
return {
|
||||
result: { apiKey: provider.getApiKey(cred), newCredentials: cred },
|
||||
};
|
||||
}
|
||||
|
||||
const oauthCreds: Record<string, OAuthCredentials> = {};
|
||||
for (const [key, value] of Object.entries(currentData)) {
|
||||
if (value.type === "oauth") {
|
||||
oauthCreds[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
const refreshed = await getOAuthApiKey(providerId, oauthCreds);
|
||||
if (!refreshed) {
|
||||
return { result: null };
|
||||
}
|
||||
|
||||
const merged: AuthStorageData = {
|
||||
...currentData,
|
||||
[providerId]: { type: "oauth", ...refreshed.newCredentials },
|
||||
};
|
||||
this.data = merged;
|
||||
this.loadError = null;
|
||||
return { result: refreshed, next: JSON.stringify(merged, null, 2) };
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider.
|
||||
* Priority:
|
||||
* 1. Runtime override (CLI --api-key)
|
||||
* 2. API key from auth.json
|
||||
* 3. OAuth token from auth.json (auto-refreshed with locking)
|
||||
* 4. Environment variable
|
||||
* 5. Fallback resolver (models.json custom providers)
|
||||
*/
|
||||
async getApiKey(providerId: string): Promise<string | undefined> {
|
||||
// Runtime override takes highest priority
|
||||
const runtimeKey = this.runtimeOverrides.get(providerId);
|
||||
if (runtimeKey) {
|
||||
return runtimeKey;
|
||||
}
|
||||
|
||||
const cred = this.data[providerId];
|
||||
|
||||
if (cred?.type === "api_key") {
|
||||
return resolveConfigValue(cred.key);
|
||||
}
|
||||
|
||||
if (cred?.type === "oauth") {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
// Unknown OAuth provider, can't get API key
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Check if token needs refresh
|
||||
const needsRefresh = Date.now() >= cred.expires;
|
||||
|
||||
if (needsRefresh) {
|
||||
// Use locked refresh to prevent race conditions
|
||||
try {
|
||||
const result = await this.refreshOAuthTokenWithLock(providerId);
|
||||
if (result) {
|
||||
return result.apiKey;
|
||||
}
|
||||
} catch (error) {
|
||||
this.recordError(error);
|
||||
// Refresh failed - re-read file to check if another instance succeeded
|
||||
this.reload();
|
||||
const updatedCred = this.data[providerId];
|
||||
|
||||
if (
|
||||
updatedCred?.type === "oauth" &&
|
||||
Date.now() < updatedCred.expires
|
||||
) {
|
||||
// Another instance refreshed successfully, use those credentials
|
||||
return provider.getApiKey(updatedCred);
|
||||
}
|
||||
|
||||
// Refresh truly failed - return undefined so model discovery skips this provider
|
||||
// User can /login to re-authenticate (credentials preserved for retry)
|
||||
return undefined;
|
||||
}
|
||||
} else {
|
||||
// Token not expired, use current access token
|
||||
return provider.getApiKey(cred);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to environment variable
|
||||
const envKey = getEnvApiKey(providerId);
|
||||
if (envKey) return envKey;
|
||||
|
||||
// Fall back to custom resolver (e.g., models.json custom providers)
|
||||
return this.fallbackResolver?.(providerId) ?? undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all registered OAuth providers
|
||||
*/
|
||||
getOAuthProviders() {
|
||||
return getOAuthProviders();
|
||||
}
|
||||
}
|
||||
296
packages/coding-agent/src/core/bash-executor.ts
Normal file
296
packages/coding-agent/src/core/bash-executor.ts
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
/**
|
||||
* Bash command execution with streaming support and cancellation.
|
||||
*
|
||||
* This module provides a unified bash execution implementation used by:
|
||||
* - AgentSession.executeBash() for interactive and RPC modes
|
||||
* - Direct calls from modes that need bash execution
|
||||
*/
|
||||
|
||||
import { randomBytes } from "node:crypto";
|
||||
import { createWriteStream, type WriteStream } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import { type ChildProcess, spawn } from "child_process";
|
||||
import stripAnsi from "strip-ansi";
|
||||
import {
|
||||
getShellConfig,
|
||||
getShellEnv,
|
||||
killProcessTree,
|
||||
sanitizeBinaryOutput,
|
||||
} from "../utils/shell.js";
|
||||
import type { BashOperations } from "./tools/bash.js";
|
||||
import { DEFAULT_MAX_BYTES, truncateTail } from "./tools/truncate.js";
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface BashExecutorOptions {
|
||||
/** Callback for streaming output chunks (already sanitized) */
|
||||
onChunk?: (chunk: string) => void;
|
||||
/** AbortSignal for cancellation */
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface BashResult {
|
||||
/** Combined stdout + stderr output (sanitized, possibly truncated) */
|
||||
output: string;
|
||||
/** Process exit code (undefined if killed/cancelled) */
|
||||
exitCode: number | undefined;
|
||||
/** Whether the command was cancelled via signal */
|
||||
cancelled: boolean;
|
||||
/** Whether the output was truncated */
|
||||
truncated: boolean;
|
||||
/** Path to temp file containing full output (if output exceeded truncation threshold) */
|
||||
fullOutputPath?: string;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Implementation
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Execute a bash command with optional streaming and cancellation support.
|
||||
*
|
||||
* Features:
|
||||
* - Streams sanitized output via onChunk callback
|
||||
* - Writes large output to temp file for later retrieval
|
||||
* - Supports cancellation via AbortSignal
|
||||
* - Sanitizes output (strips ANSI, removes binary garbage, normalizes newlines)
|
||||
* - Truncates output if it exceeds the default max bytes
|
||||
*
|
||||
* @param command - The bash command to execute
|
||||
* @param options - Optional streaming callback and abort signal
|
||||
* @returns Promise resolving to execution result
|
||||
*/
|
||||
export function executeBash(
|
||||
command: string,
|
||||
options?: BashExecutorOptions,
|
||||
): Promise<BashResult> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const { shell, args } = getShellConfig();
|
||||
const child: ChildProcess = spawn(shell, [...args, command], {
|
||||
detached: true,
|
||||
env: getShellEnv(),
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
});
|
||||
|
||||
// Track sanitized output for truncation
|
||||
const outputChunks: string[] = [];
|
||||
let outputBytes = 0;
|
||||
const maxOutputBytes = DEFAULT_MAX_BYTES * 2;
|
||||
|
||||
// Temp file for large output
|
||||
let tempFilePath: string | undefined;
|
||||
let tempFileStream: WriteStream | undefined;
|
||||
let totalBytes = 0;
|
||||
|
||||
// Handle abort signal
|
||||
const abortHandler = () => {
|
||||
if (child.pid) {
|
||||
killProcessTree(child.pid);
|
||||
}
|
||||
};
|
||||
|
||||
if (options?.signal) {
|
||||
if (options.signal.aborted) {
|
||||
// Already aborted, don't even start
|
||||
child.kill();
|
||||
resolve({
|
||||
output: "",
|
||||
exitCode: undefined,
|
||||
cancelled: true,
|
||||
truncated: false,
|
||||
});
|
||||
return;
|
||||
}
|
||||
options.signal.addEventListener("abort", abortHandler, { once: true });
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
const handleData = (data: Buffer) => {
|
||||
totalBytes += data.length;
|
||||
|
||||
// Sanitize once at the source: strip ANSI, replace binary garbage, normalize newlines
|
||||
const text = sanitizeBinaryOutput(
|
||||
stripAnsi(decoder.decode(data, { stream: true })),
|
||||
).replace(/\r/g, "");
|
||||
|
||||
// Start writing to temp file if exceeds threshold
|
||||
if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
|
||||
const id = randomBytes(8).toString("hex");
|
||||
tempFilePath = join(tmpdir(), `pi-bash-${id}.log`);
|
||||
tempFileStream = createWriteStream(tempFilePath);
|
||||
// Write already-buffered chunks to temp file
|
||||
for (const chunk of outputChunks) {
|
||||
tempFileStream.write(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.write(text);
|
||||
}
|
||||
|
||||
// Keep rolling buffer of sanitized text
|
||||
outputChunks.push(text);
|
||||
outputBytes += text.length;
|
||||
while (outputBytes > maxOutputBytes && outputChunks.length > 1) {
|
||||
const removed = outputChunks.shift()!;
|
||||
outputBytes -= removed.length;
|
||||
}
|
||||
|
||||
// Stream to callback if provided
|
||||
if (options?.onChunk) {
|
||||
options.onChunk(text);
|
||||
}
|
||||
};
|
||||
|
||||
child.stdout?.on("data", handleData);
|
||||
child.stderr?.on("data", handleData);
|
||||
|
||||
child.on("close", (code) => {
|
||||
// Clean up abort listener
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
// Combine buffered chunks for truncation (already sanitized)
|
||||
const fullOutput = outputChunks.join("");
|
||||
const truncationResult = truncateTail(fullOutput);
|
||||
|
||||
// code === null means killed (cancelled)
|
||||
const cancelled = code === null;
|
||||
|
||||
resolve({
|
||||
output: truncationResult.truncated
|
||||
? truncationResult.content
|
||||
: fullOutput,
|
||||
exitCode: cancelled ? undefined : code,
|
||||
cancelled,
|
||||
truncated: truncationResult.truncated,
|
||||
fullOutputPath: tempFilePath,
|
||||
});
|
||||
});
|
||||
|
||||
child.on("error", (err) => {
|
||||
// Clean up abort listener
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a bash command using custom BashOperations.
|
||||
* Used for remote execution (SSH, containers, etc.).
|
||||
*/
|
||||
export async function executeBashWithOperations(
|
||||
command: string,
|
||||
cwd: string,
|
||||
operations: BashOperations,
|
||||
options?: BashExecutorOptions,
|
||||
): Promise<BashResult> {
|
||||
const outputChunks: string[] = [];
|
||||
let outputBytes = 0;
|
||||
const maxOutputBytes = DEFAULT_MAX_BYTES * 2;
|
||||
|
||||
let tempFilePath: string | undefined;
|
||||
let tempFileStream: WriteStream | undefined;
|
||||
let totalBytes = 0;
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
const onData = (data: Buffer) => {
|
||||
totalBytes += data.length;
|
||||
|
||||
// Sanitize: strip ANSI, replace binary garbage, normalize newlines
|
||||
const text = sanitizeBinaryOutput(
|
||||
stripAnsi(decoder.decode(data, { stream: true })),
|
||||
).replace(/\r/g, "");
|
||||
|
||||
// Start writing to temp file if exceeds threshold
|
||||
if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
|
||||
const id = randomBytes(8).toString("hex");
|
||||
tempFilePath = join(tmpdir(), `pi-bash-${id}.log`);
|
||||
tempFileStream = createWriteStream(tempFilePath);
|
||||
for (const chunk of outputChunks) {
|
||||
tempFileStream.write(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.write(text);
|
||||
}
|
||||
|
||||
// Keep rolling buffer
|
||||
outputChunks.push(text);
|
||||
outputBytes += text.length;
|
||||
while (outputBytes > maxOutputBytes && outputChunks.length > 1) {
|
||||
const removed = outputChunks.shift()!;
|
||||
outputBytes -= removed.length;
|
||||
}
|
||||
|
||||
// Stream to callback
|
||||
if (options?.onChunk) {
|
||||
options.onChunk(text);
|
||||
}
|
||||
};
|
||||
|
||||
try {
|
||||
const result = await operations.exec(command, cwd, {
|
||||
onData,
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
const fullOutput = outputChunks.join("");
|
||||
const truncationResult = truncateTail(fullOutput);
|
||||
const cancelled = options?.signal?.aborted ?? false;
|
||||
|
||||
return {
|
||||
output: truncationResult.truncated
|
||||
? truncationResult.content
|
||||
: fullOutput,
|
||||
exitCode: cancelled ? undefined : (result.exitCode ?? undefined),
|
||||
cancelled,
|
||||
truncated: truncationResult.truncated,
|
||||
fullOutputPath: tempFilePath,
|
||||
};
|
||||
} catch (err) {
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
// Check if it was an abort
|
||||
if (options?.signal?.aborted) {
|
||||
const fullOutput = outputChunks.join("");
|
||||
const truncationResult = truncateTail(fullOutput);
|
||||
return {
|
||||
output: truncationResult.truncated
|
||||
? truncationResult.content
|
||||
: fullOutput,
|
||||
exitCode: undefined,
|
||||
cancelled: true,
|
||||
truncated: truncationResult.truncated,
|
||||
fullOutputPath: tempFilePath,
|
||||
};
|
||||
}
|
||||
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,382 @@
|
|||
/**
|
||||
* Branch summarization for tree navigation.
|
||||
*
|
||||
* When navigating to a different point in the session tree, this generates
|
||||
* a summary of the branch being left so context isn't lost.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { Model } from "@mariozechner/pi-ai";
|
||||
import { completeSimple } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
convertToLlm,
|
||||
createBranchSummaryMessage,
|
||||
createCompactionSummaryMessage,
|
||||
createCustomMessage,
|
||||
} from "../messages.js";
|
||||
import type {
|
||||
ReadonlySessionManager,
|
||||
SessionEntry,
|
||||
} from "../session-manager.js";
|
||||
import { estimateTokens } from "./compaction.js";
|
||||
import {
|
||||
computeFileLists,
|
||||
createFileOps,
|
||||
extractFileOpsFromMessage,
|
||||
type FileOperations,
|
||||
formatFileOperations,
|
||||
SUMMARIZATION_SYSTEM_PROMPT,
|
||||
serializeConversation,
|
||||
} from "./utils.js";
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface BranchSummaryResult {
|
||||
summary?: string;
|
||||
readFiles?: string[];
|
||||
modifiedFiles?: string[];
|
||||
aborted?: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
/** Details stored in BranchSummaryEntry.details for file tracking */
|
||||
export interface BranchSummaryDetails {
|
||||
readFiles: string[];
|
||||
modifiedFiles: string[];
|
||||
}
|
||||
|
||||
export type { FileOperations } from "./utils.js";
|
||||
|
||||
export interface BranchPreparation {
|
||||
/** Messages extracted for summarization, in chronological order */
|
||||
messages: AgentMessage[];
|
||||
/** File operations extracted from tool calls */
|
||||
fileOps: FileOperations;
|
||||
/** Total estimated tokens in messages */
|
||||
totalTokens: number;
|
||||
}
|
||||
|
||||
export interface CollectEntriesResult {
|
||||
/** Entries to summarize, in chronological order */
|
||||
entries: SessionEntry[];
|
||||
/** Common ancestor between old and new position, if any */
|
||||
commonAncestorId: string | null;
|
||||
}
|
||||
|
||||
export interface GenerateBranchSummaryOptions {
|
||||
/** Model to use for summarization */
|
||||
model: Model<any>;
|
||||
/** API key for the model */
|
||||
apiKey: string;
|
||||
/** Abort signal for cancellation */
|
||||
signal: AbortSignal;
|
||||
/** Optional custom instructions for summarization */
|
||||
customInstructions?: string;
|
||||
/** If true, customInstructions replaces the default prompt instead of being appended */
|
||||
replaceInstructions?: boolean;
|
||||
/** Tokens reserved for prompt + LLM response (default 16384) */
|
||||
reserveTokens?: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Entry Collection
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Collect entries that should be summarized when navigating from one position to another.
|
||||
*
|
||||
* Walks from oldLeafId back to the common ancestor with targetId, collecting entries
|
||||
* along the way. Does NOT stop at compaction boundaries - those are included and their
|
||||
* summaries become context.
|
||||
*
|
||||
* @param session - Session manager (read-only access)
|
||||
* @param oldLeafId - Current position (where we're navigating from)
|
||||
* @param targetId - Target position (where we're navigating to)
|
||||
* @returns Entries to summarize and the common ancestor
|
||||
*/
|
||||
export function collectEntriesForBranchSummary(
|
||||
session: ReadonlySessionManager,
|
||||
oldLeafId: string | null,
|
||||
targetId: string,
|
||||
): CollectEntriesResult {
|
||||
// If no old position, nothing to summarize
|
||||
if (!oldLeafId) {
|
||||
return { entries: [], commonAncestorId: null };
|
||||
}
|
||||
|
||||
// Find common ancestor (deepest node that's on both paths)
|
||||
const oldPath = new Set(session.getBranch(oldLeafId).map((e) => e.id));
|
||||
const targetPath = session.getBranch(targetId);
|
||||
|
||||
// targetPath is root-first, so iterate backwards to find deepest common ancestor
|
||||
let commonAncestorId: string | null = null;
|
||||
for (let i = targetPath.length - 1; i >= 0; i--) {
|
||||
if (oldPath.has(targetPath[i].id)) {
|
||||
commonAncestorId = targetPath[i].id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Collect entries from old leaf back to common ancestor
|
||||
const entries: SessionEntry[] = [];
|
||||
let current: string | null = oldLeafId;
|
||||
|
||||
while (current && current !== commonAncestorId) {
|
||||
const entry = session.getEntry(current);
|
||||
if (!entry) break;
|
||||
entries.push(entry);
|
||||
current = entry.parentId;
|
||||
}
|
||||
|
||||
// Reverse to get chronological order
|
||||
entries.reverse();
|
||||
|
||||
return { entries, commonAncestorId };
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Entry to Message Conversion
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Extract AgentMessage from a session entry.
|
||||
* Similar to getMessageFromEntry in compaction.ts but also handles compaction entries.
|
||||
*/
|
||||
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
|
||||
switch (entry.type) {
|
||||
case "message":
|
||||
// Skip tool results - context is in assistant's tool call
|
||||
if (entry.message.role === "toolResult") return undefined;
|
||||
return entry.message;
|
||||
|
||||
case "custom_message":
|
||||
return createCustomMessage(
|
||||
entry.customType,
|
||||
entry.content,
|
||||
entry.display,
|
||||
entry.details,
|
||||
entry.timestamp,
|
||||
);
|
||||
|
||||
case "branch_summary":
|
||||
return createBranchSummaryMessage(
|
||||
entry.summary,
|
||||
entry.fromId,
|
||||
entry.timestamp,
|
||||
);
|
||||
|
||||
case "compaction":
|
||||
return createCompactionSummaryMessage(
|
||||
entry.summary,
|
||||
entry.tokensBefore,
|
||||
entry.timestamp,
|
||||
);
|
||||
|
||||
// These don't contribute to conversation content
|
||||
case "thinking_level_change":
|
||||
case "model_change":
|
||||
case "custom":
|
||||
case "label":
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepare entries for summarization with token budget.
|
||||
*
|
||||
* Walks entries from NEWEST to OLDEST, adding messages until we hit the token budget.
|
||||
* This ensures we keep the most recent context when the branch is too long.
|
||||
*
|
||||
* Also collects file operations from:
|
||||
* - Tool calls in assistant messages
|
||||
* - Existing branch_summary entries' details (for cumulative tracking)
|
||||
*
|
||||
* @param entries - Entries in chronological order
|
||||
* @param tokenBudget - Maximum tokens to include (0 = no limit)
|
||||
*/
|
||||
export function prepareBranchEntries(
|
||||
entries: SessionEntry[],
|
||||
tokenBudget: number = 0,
|
||||
): BranchPreparation {
|
||||
const messages: AgentMessage[] = [];
|
||||
const fileOps = createFileOps();
|
||||
let totalTokens = 0;
|
||||
|
||||
// First pass: collect file ops from ALL entries (even if they don't fit in token budget)
|
||||
// This ensures we capture cumulative file tracking from nested branch summaries
|
||||
// Only extract from pi-generated summaries (fromHook !== true), not extension-generated ones
|
||||
for (const entry of entries) {
|
||||
if (entry.type === "branch_summary" && !entry.fromHook && entry.details) {
|
||||
const details = entry.details as BranchSummaryDetails;
|
||||
if (Array.isArray(details.readFiles)) {
|
||||
for (const f of details.readFiles) fileOps.read.add(f);
|
||||
}
|
||||
if (Array.isArray(details.modifiedFiles)) {
|
||||
// Modified files go into both edited and written for proper deduplication
|
||||
for (const f of details.modifiedFiles) {
|
||||
fileOps.edited.add(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: walk from newest to oldest, adding messages until token budget
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
const entry = entries[i];
|
||||
const message = getMessageFromEntry(entry);
|
||||
if (!message) continue;
|
||||
|
||||
// Extract file ops from assistant messages (tool calls)
|
||||
extractFileOpsFromMessage(message, fileOps);
|
||||
|
||||
const tokens = estimateTokens(message);
|
||||
|
||||
// Check budget before adding
|
||||
if (tokenBudget > 0 && totalTokens + tokens > tokenBudget) {
|
||||
// If this is a summary entry, try to fit it anyway as it's important context
|
||||
if (entry.type === "compaction" || entry.type === "branch_summary") {
|
||||
if (totalTokens < tokenBudget * 0.9) {
|
||||
messages.unshift(message);
|
||||
totalTokens += tokens;
|
||||
}
|
||||
}
|
||||
// Stop - we've hit the budget
|
||||
break;
|
||||
}
|
||||
|
||||
messages.unshift(message);
|
||||
totalTokens += tokens;
|
||||
}
|
||||
|
||||
return { messages, fileOps, totalTokens };
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Summary Generation
|
||||
// ============================================================================
|
||||
|
||||
const BRANCH_SUMMARY_PREAMBLE = `The user explored a different conversation branch before returning here.
|
||||
Summary of that exploration:
|
||||
|
||||
`;
|
||||
|
||||
const BRANCH_SUMMARY_PROMPT = `Create a structured summary of this conversation branch for context when returning later.
|
||||
|
||||
Use this EXACT format:
|
||||
|
||||
## Goal
|
||||
[What was the user trying to accomplish in this branch?]
|
||||
|
||||
## Constraints & Preferences
|
||||
- [Any constraints, preferences, or requirements mentioned]
|
||||
- [Or "(none)" if none were mentioned]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
- [x] [Completed tasks/changes]
|
||||
|
||||
### In Progress
|
||||
- [ ] [Work that was started but not finished]
|
||||
|
||||
### Blocked
|
||||
- [Issues preventing progress, if any]
|
||||
|
||||
## Key Decisions
|
||||
- **[Decision]**: [Brief rationale]
|
||||
|
||||
## Next Steps
|
||||
1. [What should happen next to continue this work]
|
||||
|
||||
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
|
||||
|
||||
/**
|
||||
* Generate a summary of abandoned branch entries.
|
||||
*
|
||||
* @param entries - Session entries to summarize (chronological order)
|
||||
* @param options - Generation options
|
||||
*/
|
||||
export async function generateBranchSummary(
|
||||
entries: SessionEntry[],
|
||||
options: GenerateBranchSummaryOptions,
|
||||
): Promise<BranchSummaryResult> {
|
||||
const {
|
||||
model,
|
||||
apiKey,
|
||||
signal,
|
||||
customInstructions,
|
||||
replaceInstructions,
|
||||
reserveTokens = 16384,
|
||||
} = options;
|
||||
|
||||
// Token budget = context window minus reserved space for prompt + response
|
||||
const contextWindow = model.contextWindow || 128000;
|
||||
const tokenBudget = contextWindow - reserveTokens;
|
||||
|
||||
const { messages, fileOps } = prepareBranchEntries(entries, tokenBudget);
|
||||
|
||||
if (messages.length === 0) {
|
||||
return { summary: "No content to summarize" };
|
||||
}
|
||||
|
||||
// Transform to LLM-compatible messages, then serialize to text
|
||||
// Serialization prevents the model from treating it as a conversation to continue
|
||||
const llmMessages = convertToLlm(messages);
|
||||
const conversationText = serializeConversation(llmMessages);
|
||||
|
||||
// Build prompt
|
||||
let instructions: string;
|
||||
if (replaceInstructions && customInstructions) {
|
||||
instructions = customInstructions;
|
||||
} else if (customInstructions) {
|
||||
instructions = `${BRANCH_SUMMARY_PROMPT}\n\nAdditional focus: ${customInstructions}`;
|
||||
} else {
|
||||
instructions = BRANCH_SUMMARY_PROMPT;
|
||||
}
|
||||
const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${instructions}`;
|
||||
|
||||
const summarizationMessages = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: promptText }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
// Call LLM for summarization
|
||||
const response = await completeSimple(
|
||||
model,
|
||||
{
|
||||
systemPrompt: SUMMARIZATION_SYSTEM_PROMPT,
|
||||
messages: summarizationMessages,
|
||||
},
|
||||
{ apiKey, signal, maxTokens: 2048 },
|
||||
);
|
||||
|
||||
// Check if aborted or errored
|
||||
if (response.stopReason === "aborted") {
|
||||
return { aborted: true };
|
||||
}
|
||||
if (response.stopReason === "error") {
|
||||
return { error: response.errorMessage || "Summarization failed" };
|
||||
}
|
||||
|
||||
let summary = response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
|
||||
// Prepend preamble to provide context about the branch summary
|
||||
summary = BRANCH_SUMMARY_PREAMBLE + summary;
|
||||
|
||||
// Compute file lists and append to summary
|
||||
const { readFiles, modifiedFiles } = computeFileLists(fileOps);
|
||||
summary += formatFileOperations(readFiles, modifiedFiles);
|
||||
|
||||
return {
|
||||
summary: summary || "No summary generated",
|
||||
readFiles,
|
||||
modifiedFiles,
|
||||
};
|
||||
}
|
||||
899
packages/coding-agent/src/core/compaction/compaction.ts
Normal file
899
packages/coding-agent/src/core/compaction/compaction.ts
Normal file
|
|
@ -0,0 +1,899 @@
|
|||
/**
|
||||
* Context compaction for long sessions.
|
||||
*
|
||||
* Pure functions for compaction logic. The session manager handles I/O,
|
||||
* and after compaction the session is reloaded.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai";
|
||||
import { completeSimple } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
convertToLlm,
|
||||
createBranchSummaryMessage,
|
||||
createCompactionSummaryMessage,
|
||||
createCustomMessage,
|
||||
} from "../messages.js";
|
||||
import type { CompactionEntry, SessionEntry } from "../session-manager.js";
|
||||
import {
|
||||
computeFileLists,
|
||||
createFileOps,
|
||||
extractFileOpsFromMessage,
|
||||
type FileOperations,
|
||||
formatFileOperations,
|
||||
SUMMARIZATION_SYSTEM_PROMPT,
|
||||
serializeConversation,
|
||||
} from "./utils.js";
|
||||
|
||||
// ============================================================================
|
||||
// File Operation Tracking
|
||||
// ============================================================================
|
||||
|
||||
/** Details stored in CompactionEntry.details for file tracking */
|
||||
export interface CompactionDetails {
|
||||
readFiles: string[];
|
||||
modifiedFiles: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract file operations from messages and previous compaction entries.
|
||||
*/
|
||||
function extractFileOperations(
|
||||
messages: AgentMessage[],
|
||||
entries: SessionEntry[],
|
||||
prevCompactionIndex: number,
|
||||
): FileOperations {
|
||||
const fileOps = createFileOps();
|
||||
|
||||
// Collect from previous compaction's details (if pi-generated)
|
||||
if (prevCompactionIndex >= 0) {
|
||||
const prevCompaction = entries[prevCompactionIndex] as CompactionEntry;
|
||||
if (!prevCompaction.fromHook && prevCompaction.details) {
|
||||
// fromHook field kept for session file compatibility
|
||||
const details = prevCompaction.details as CompactionDetails;
|
||||
if (Array.isArray(details.readFiles)) {
|
||||
for (const f of details.readFiles) fileOps.read.add(f);
|
||||
}
|
||||
if (Array.isArray(details.modifiedFiles)) {
|
||||
for (const f of details.modifiedFiles) fileOps.edited.add(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract from tool calls in messages
|
||||
for (const msg of messages) {
|
||||
extractFileOpsFromMessage(msg, fileOps);
|
||||
}
|
||||
|
||||
return fileOps;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Extraction
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Extract AgentMessage from an entry if it produces one.
|
||||
* Returns undefined for entries that don't contribute to LLM context.
|
||||
*/
|
||||
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
|
||||
if (entry.type === "message") {
|
||||
return entry.message;
|
||||
}
|
||||
if (entry.type === "custom_message") {
|
||||
return createCustomMessage(
|
||||
entry.customType,
|
||||
entry.content,
|
||||
entry.display,
|
||||
entry.details,
|
||||
entry.timestamp,
|
||||
);
|
||||
}
|
||||
if (entry.type === "branch_summary") {
|
||||
return createBranchSummaryMessage(
|
||||
entry.summary,
|
||||
entry.fromId,
|
||||
entry.timestamp,
|
||||
);
|
||||
}
|
||||
if (entry.type === "compaction") {
|
||||
return createCompactionSummaryMessage(
|
||||
entry.summary,
|
||||
entry.tokensBefore,
|
||||
entry.timestamp,
|
||||
);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Result from compact() - SessionManager adds uuid/parentUuid when saving */
|
||||
export interface CompactionResult<T = unknown> {
|
||||
summary: string;
|
||||
firstKeptEntryId: string;
|
||||
tokensBefore: number;
|
||||
/** Extension-specific data (e.g., ArtifactIndex, version markers for structured compaction) */
|
||||
details?: T;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface CompactionSettings {
|
||||
enabled: boolean;
|
||||
reserveTokens: number;
|
||||
keepRecentTokens: number;
|
||||
}
|
||||
|
||||
export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = {
|
||||
enabled: true,
|
||||
reserveTokens: 16384,
|
||||
keepRecentTokens: 20000,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Token calculation
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Calculate total context tokens from usage.
|
||||
* Uses the native totalTokens field when available, falls back to computing from components.
|
||||
*/
|
||||
export function calculateContextTokens(usage: Usage): number {
|
||||
return (
|
||||
usage.totalTokens ||
|
||||
usage.input + usage.output + usage.cacheRead + usage.cacheWrite
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get usage from an assistant message if available.
|
||||
* Skips aborted and error messages as they don't have valid usage data.
|
||||
*/
|
||||
function getAssistantUsage(msg: AgentMessage): Usage | undefined {
|
||||
if (msg.role === "assistant" && "usage" in msg) {
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
if (
|
||||
assistantMsg.stopReason !== "aborted" &&
|
||||
assistantMsg.stopReason !== "error" &&
|
||||
assistantMsg.usage
|
||||
) {
|
||||
return assistantMsg.usage;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the last non-aborted assistant message usage from session entries.
|
||||
*/
|
||||
export function getLastAssistantUsage(
|
||||
entries: SessionEntry[],
|
||||
): Usage | undefined {
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "message") {
|
||||
const usage = getAssistantUsage(entry.message);
|
||||
if (usage) return usage;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export interface ContextUsageEstimate {
|
||||
tokens: number;
|
||||
usageTokens: number;
|
||||
trailingTokens: number;
|
||||
lastUsageIndex: number | null;
|
||||
}
|
||||
|
||||
function getLastAssistantUsageInfo(
|
||||
messages: AgentMessage[],
|
||||
): { usage: Usage; index: number } | undefined {
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
const usage = getAssistantUsage(messages[i]);
|
||||
if (usage) return { usage, index: i };
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate context tokens from messages, using the last assistant usage when available.
|
||||
* If there are messages after the last usage, estimate their tokens with estimateTokens.
|
||||
*/
|
||||
export function estimateContextTokens(
|
||||
messages: AgentMessage[],
|
||||
): ContextUsageEstimate {
|
||||
const usageInfo = getLastAssistantUsageInfo(messages);
|
||||
|
||||
if (!usageInfo) {
|
||||
let estimated = 0;
|
||||
for (const message of messages) {
|
||||
estimated += estimateTokens(message);
|
||||
}
|
||||
return {
|
||||
tokens: estimated,
|
||||
usageTokens: 0,
|
||||
trailingTokens: estimated,
|
||||
lastUsageIndex: null,
|
||||
};
|
||||
}
|
||||
|
||||
const usageTokens = calculateContextTokens(usageInfo.usage);
|
||||
let trailingTokens = 0;
|
||||
for (let i = usageInfo.index + 1; i < messages.length; i++) {
|
||||
trailingTokens += estimateTokens(messages[i]);
|
||||
}
|
||||
|
||||
return {
|
||||
tokens: usageTokens + trailingTokens,
|
||||
usageTokens,
|
||||
trailingTokens,
|
||||
lastUsageIndex: usageInfo.index,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if compaction should trigger based on context usage.
|
||||
*/
|
||||
export function shouldCompact(
|
||||
contextTokens: number,
|
||||
contextWindow: number,
|
||||
settings: CompactionSettings,
|
||||
): boolean {
|
||||
if (!settings.enabled) return false;
|
||||
return contextTokens > contextWindow - settings.reserveTokens;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Cut point detection
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Estimate token count for a message using chars/4 heuristic.
|
||||
* This is conservative (overestimates tokens).
|
||||
*/
|
||||
export function estimateTokens(message: AgentMessage): number {
|
||||
let chars = 0;
|
||||
|
||||
switch (message.role) {
|
||||
case "user": {
|
||||
const content = (
|
||||
message as { content: string | Array<{ type: string; text?: string }> }
|
||||
).content;
|
||||
if (typeof content === "string") {
|
||||
chars = content.length;
|
||||
} else if (Array.isArray(content)) {
|
||||
for (const block of content) {
|
||||
if (block.type === "text" && block.text) {
|
||||
chars += block.text.length;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "assistant": {
|
||||
const assistant = message as AssistantMessage;
|
||||
for (const block of assistant.content) {
|
||||
if (block.type === "text") {
|
||||
chars += block.text.length;
|
||||
} else if (block.type === "thinking") {
|
||||
chars += block.thinking.length;
|
||||
} else if (block.type === "toolCall") {
|
||||
chars += block.name.length + JSON.stringify(block.arguments).length;
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "custom":
|
||||
case "toolResult": {
|
||||
if (typeof message.content === "string") {
|
||||
chars = message.content.length;
|
||||
} else {
|
||||
for (const block of message.content) {
|
||||
if (block.type === "text" && block.text) {
|
||||
chars += block.text.length;
|
||||
}
|
||||
if (block.type === "image") {
|
||||
chars += 4800; // Estimate images as 4000 chars, or 1200 tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "bashExecution": {
|
||||
chars = message.command.length + message.output.length;
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "branchSummary":
|
||||
case "compactionSummary": {
|
||||
chars = message.summary.length;
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find valid cut points: indices of user, assistant, custom, or bashExecution messages.
|
||||
* Never cut at tool results (they must follow their tool call).
|
||||
* When we cut at an assistant message with tool calls, its tool results follow it
|
||||
* and will be kept.
|
||||
* BashExecutionMessage is treated like a user message (user-initiated context).
|
||||
*/
|
||||
function findValidCutPoints(
|
||||
entries: SessionEntry[],
|
||||
startIndex: number,
|
||||
endIndex: number,
|
||||
): number[] {
|
||||
const cutPoints: number[] = [];
|
||||
for (let i = startIndex; i < endIndex; i++) {
|
||||
const entry = entries[i];
|
||||
switch (entry.type) {
|
||||
case "message": {
|
||||
const role = entry.message.role;
|
||||
switch (role) {
|
||||
case "bashExecution":
|
||||
case "custom":
|
||||
case "branchSummary":
|
||||
case "compactionSummary":
|
||||
case "user":
|
||||
case "assistant":
|
||||
cutPoints.push(i);
|
||||
break;
|
||||
case "toolResult":
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "thinking_level_change":
|
||||
case "model_change":
|
||||
case "compaction":
|
||||
case "branch_summary":
|
||||
case "custom":
|
||||
case "custom_message":
|
||||
case "label":
|
||||
}
|
||||
// branch_summary and custom_message are user-role messages, valid cut points
|
||||
if (entry.type === "branch_summary" || entry.type === "custom_message") {
|
||||
cutPoints.push(i);
|
||||
}
|
||||
}
|
||||
return cutPoints;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the user message (or bashExecution) that starts the turn containing the given entry index.
|
||||
* Returns -1 if no turn start found before the index.
|
||||
* BashExecutionMessage is treated like a user message for turn boundaries.
|
||||
*/
|
||||
export function findTurnStartIndex(
|
||||
entries: SessionEntry[],
|
||||
entryIndex: number,
|
||||
startIndex: number,
|
||||
): number {
|
||||
for (let i = entryIndex; i >= startIndex; i--) {
|
||||
const entry = entries[i];
|
||||
// branch_summary and custom_message are user-role messages, can start a turn
|
||||
if (entry.type === "branch_summary" || entry.type === "custom_message") {
|
||||
return i;
|
||||
}
|
||||
if (entry.type === "message") {
|
||||
const role = entry.message.role;
|
||||
if (role === "user" || role === "bashExecution") {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
export interface CutPointResult {
|
||||
/** Index of first entry to keep */
|
||||
firstKeptEntryIndex: number;
|
||||
/** Index of user message that starts the turn being split, or -1 if not splitting */
|
||||
turnStartIndex: number;
|
||||
/** Whether this cut splits a turn (cut point is not a user message) */
|
||||
isSplitTurn: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the cut point in session entries that keeps approximately `keepRecentTokens`.
|
||||
*
|
||||
* Algorithm: Walk backwards from newest, accumulating estimated message sizes.
|
||||
* Stop when we've accumulated >= keepRecentTokens. Cut at that point.
|
||||
*
|
||||
* Can cut at user OR assistant messages (never tool results). When cutting at an
|
||||
* assistant message with tool calls, its tool results come after and will be kept.
|
||||
*
|
||||
* Returns CutPointResult with:
|
||||
* - firstKeptEntryIndex: the entry index to start keeping from
|
||||
* - turnStartIndex: if cutting mid-turn, the user message that started that turn
|
||||
* - isSplitTurn: whether we're cutting in the middle of a turn
|
||||
*
|
||||
* Only considers entries between `startIndex` and `endIndex` (exclusive).
|
||||
*/
|
||||
export function findCutPoint(
|
||||
entries: SessionEntry[],
|
||||
startIndex: number,
|
||||
endIndex: number,
|
||||
keepRecentTokens: number,
|
||||
): CutPointResult {
|
||||
const cutPoints = findValidCutPoints(entries, startIndex, endIndex);
|
||||
|
||||
if (cutPoints.length === 0) {
|
||||
return {
|
||||
firstKeptEntryIndex: startIndex,
|
||||
turnStartIndex: -1,
|
||||
isSplitTurn: false,
|
||||
};
|
||||
}
|
||||
|
||||
// Walk backwards from newest, accumulating estimated message sizes
|
||||
let accumulatedTokens = 0;
|
||||
let cutIndex = cutPoints[0]; // Default: keep from first message (not header)
|
||||
|
||||
for (let i = endIndex - 1; i >= startIndex; i--) {
|
||||
const entry = entries[i];
|
||||
if (entry.type !== "message") continue;
|
||||
|
||||
// Estimate this message's size
|
||||
const messageTokens = estimateTokens(entry.message);
|
||||
accumulatedTokens += messageTokens;
|
||||
|
||||
// Check if we've exceeded the budget
|
||||
if (accumulatedTokens >= keepRecentTokens) {
|
||||
// Find the closest valid cut point at or after this entry
|
||||
for (let c = 0; c < cutPoints.length; c++) {
|
||||
if (cutPoints[c] >= i) {
|
||||
cutIndex = cutPoints[c];
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Scan backwards from cutIndex to include any non-message entries (bash, settings, etc.)
|
||||
while (cutIndex > startIndex) {
|
||||
const prevEntry = entries[cutIndex - 1];
|
||||
// Stop at session header or compaction boundaries
|
||||
if (prevEntry.type === "compaction") {
|
||||
break;
|
||||
}
|
||||
if (prevEntry.type === "message") {
|
||||
// Stop if we hit any message
|
||||
break;
|
||||
}
|
||||
// Include this non-message entry (bash, settings change, etc.)
|
||||
cutIndex--;
|
||||
}
|
||||
|
||||
// Determine if this is a split turn
|
||||
const cutEntry = entries[cutIndex];
|
||||
const isUserMessage =
|
||||
cutEntry.type === "message" && cutEntry.message.role === "user";
|
||||
const turnStartIndex = isUserMessage
|
||||
? -1
|
||||
: findTurnStartIndex(entries, cutIndex, startIndex);
|
||||
|
||||
return {
|
||||
firstKeptEntryIndex: cutIndex,
|
||||
turnStartIndex,
|
||||
isSplitTurn: !isUserMessage && turnStartIndex !== -1,
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Summarization
|
||||
// ============================================================================
|
||||
|
||||
const SUMMARIZATION_PROMPT = `The messages above are a conversation to summarize. Create a structured context checkpoint summary that another LLM will use to continue the work.
|
||||
|
||||
Use this EXACT format:
|
||||
|
||||
## Goal
|
||||
[What is the user trying to accomplish? Can be multiple items if the session covers different tasks.]
|
||||
|
||||
## Constraints & Preferences
|
||||
- [Any constraints, preferences, or requirements mentioned by user]
|
||||
- [Or "(none)" if none were mentioned]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
- [x] [Completed tasks/changes]
|
||||
|
||||
### In Progress
|
||||
- [ ] [Current work]
|
||||
|
||||
### Blocked
|
||||
- [Issues preventing progress, if any]
|
||||
|
||||
## Key Decisions
|
||||
- **[Decision]**: [Brief rationale]
|
||||
|
||||
## Next Steps
|
||||
1. [Ordered list of what should happen next]
|
||||
|
||||
## Critical Context
|
||||
- [Any data, examples, or references needed to continue]
|
||||
- [Or "(none)" if not applicable]
|
||||
|
||||
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
|
||||
|
||||
const UPDATE_SUMMARIZATION_PROMPT = `The messages above are NEW conversation messages to incorporate into the existing summary provided in <previous-summary> tags.
|
||||
|
||||
Update the existing structured summary with new information. RULES:
|
||||
- PRESERVE all existing information from the previous summary
|
||||
- ADD new progress, decisions, and context from the new messages
|
||||
- UPDATE the Progress section: move items from "In Progress" to "Done" when completed
|
||||
- UPDATE "Next Steps" based on what was accomplished
|
||||
- PRESERVE exact file paths, function names, and error messages
|
||||
- If something is no longer relevant, you may remove it
|
||||
|
||||
Use this EXACT format:
|
||||
|
||||
## Goal
|
||||
[Preserve existing goals, add new ones if the task expanded]
|
||||
|
||||
## Constraints & Preferences
|
||||
- [Preserve existing, add new ones discovered]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
- [x] [Include previously done items AND newly completed items]
|
||||
|
||||
### In Progress
|
||||
- [ ] [Current work - update based on progress]
|
||||
|
||||
### Blocked
|
||||
- [Current blockers - remove if resolved]
|
||||
|
||||
## Key Decisions
|
||||
- **[Decision]**: [Brief rationale] (preserve all previous, add new)
|
||||
|
||||
## Next Steps
|
||||
1. [Update based on current state]
|
||||
|
||||
## Critical Context
|
||||
- [Preserve important context, add new if needed]
|
||||
|
||||
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
|
||||
|
||||
/**
|
||||
* Generate a summary of the conversation using the LLM.
|
||||
* If previousSummary is provided, uses the update prompt to merge.
|
||||
*/
|
||||
export async function generateSummary(
|
||||
currentMessages: AgentMessage[],
|
||||
model: Model<any>,
|
||||
reserveTokens: number,
|
||||
apiKey: string,
|
||||
signal?: AbortSignal,
|
||||
customInstructions?: string,
|
||||
previousSummary?: string,
|
||||
): Promise<string> {
|
||||
const maxTokens = Math.floor(0.8 * reserveTokens);
|
||||
|
||||
// Use update prompt if we have a previous summary, otherwise initial prompt
|
||||
let basePrompt = previousSummary
|
||||
? UPDATE_SUMMARIZATION_PROMPT
|
||||
: SUMMARIZATION_PROMPT;
|
||||
if (customInstructions) {
|
||||
basePrompt = `${basePrompt}\n\nAdditional focus: ${customInstructions}`;
|
||||
}
|
||||
|
||||
// Serialize conversation to text so model doesn't try to continue it
|
||||
// Convert to LLM messages first (handles custom types like bashExecution, custom, etc.)
|
||||
const llmMessages = convertToLlm(currentMessages);
|
||||
const conversationText = serializeConversation(llmMessages);
|
||||
|
||||
// Build the prompt with conversation wrapped in tags
|
||||
let promptText = `<conversation>\n${conversationText}\n</conversation>\n\n`;
|
||||
if (previousSummary) {
|
||||
promptText += `<previous-summary>\n${previousSummary}\n</previous-summary>\n\n`;
|
||||
}
|
||||
promptText += basePrompt;
|
||||
|
||||
const summarizationMessages = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: promptText }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
const completionOptions = model.reasoning
|
||||
? { maxTokens, signal, apiKey, reasoning: "high" as const }
|
||||
: { maxTokens, signal, apiKey };
|
||||
|
||||
const response = await completeSimple(
|
||||
model,
|
||||
{
|
||||
systemPrompt: SUMMARIZATION_SYSTEM_PROMPT,
|
||||
messages: summarizationMessages,
|
||||
},
|
||||
completionOptions,
|
||||
);
|
||||
|
||||
if (response.stopReason === "error") {
|
||||
throw new Error(
|
||||
`Summarization failed: ${response.errorMessage || "Unknown error"}`,
|
||||
);
|
||||
}
|
||||
|
||||
const textContent = response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
|
||||
return textContent;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Compaction Preparation (for extensions)
|
||||
// ============================================================================
|
||||
|
||||
export interface CompactionPreparation {
|
||||
/** UUID of first entry to keep */
|
||||
firstKeptEntryId: string;
|
||||
/** Messages that will be summarized and discarded */
|
||||
messagesToSummarize: AgentMessage[];
|
||||
/** Messages that will be turned into turn prefix summary (if splitting) */
|
||||
turnPrefixMessages: AgentMessage[];
|
||||
/** Whether this is a split turn (cut point in middle of turn) */
|
||||
isSplitTurn: boolean;
|
||||
tokensBefore: number;
|
||||
/** Summary from previous compaction, for iterative update */
|
||||
previousSummary?: string;
|
||||
/** File operations extracted from messagesToSummarize */
|
||||
fileOps: FileOperations;
|
||||
/** Compaction settions from settings.jsonl */
|
||||
settings: CompactionSettings;
|
||||
}
|
||||
|
||||
export function prepareCompaction(
|
||||
pathEntries: SessionEntry[],
|
||||
settings: CompactionSettings,
|
||||
): CompactionPreparation | undefined {
|
||||
if (
|
||||
pathEntries.length > 0 &&
|
||||
pathEntries[pathEntries.length - 1].type === "compaction"
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
let prevCompactionIndex = -1;
|
||||
for (let i = pathEntries.length - 1; i >= 0; i--) {
|
||||
if (pathEntries[i].type === "compaction") {
|
||||
prevCompactionIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const boundaryStart = prevCompactionIndex + 1;
|
||||
const boundaryEnd = pathEntries.length;
|
||||
|
||||
const usageStart = prevCompactionIndex >= 0 ? prevCompactionIndex : 0;
|
||||
const usageMessages: AgentMessage[] = [];
|
||||
for (let i = usageStart; i < boundaryEnd; i++) {
|
||||
const msg = getMessageFromEntry(pathEntries[i]);
|
||||
if (msg) usageMessages.push(msg);
|
||||
}
|
||||
const tokensBefore = estimateContextTokens(usageMessages).tokens;
|
||||
|
||||
const cutPoint = findCutPoint(
|
||||
pathEntries,
|
||||
boundaryStart,
|
||||
boundaryEnd,
|
||||
settings.keepRecentTokens,
|
||||
);
|
||||
|
||||
// Get UUID of first kept entry
|
||||
const firstKeptEntry = pathEntries[cutPoint.firstKeptEntryIndex];
|
||||
if (!firstKeptEntry?.id) {
|
||||
return undefined; // Session needs migration
|
||||
}
|
||||
const firstKeptEntryId = firstKeptEntry.id;
|
||||
|
||||
const historyEnd = cutPoint.isSplitTurn
|
||||
? cutPoint.turnStartIndex
|
||||
: cutPoint.firstKeptEntryIndex;
|
||||
|
||||
// Messages to summarize (will be discarded after summary)
|
||||
const messagesToSummarize: AgentMessage[] = [];
|
||||
for (let i = boundaryStart; i < historyEnd; i++) {
|
||||
const msg = getMessageFromEntry(pathEntries[i]);
|
||||
if (msg) messagesToSummarize.push(msg);
|
||||
}
|
||||
|
||||
// Messages for turn prefix summary (if splitting a turn)
|
||||
const turnPrefixMessages: AgentMessage[] = [];
|
||||
if (cutPoint.isSplitTurn) {
|
||||
for (
|
||||
let i = cutPoint.turnStartIndex;
|
||||
i < cutPoint.firstKeptEntryIndex;
|
||||
i++
|
||||
) {
|
||||
const msg = getMessageFromEntry(pathEntries[i]);
|
||||
if (msg) turnPrefixMessages.push(msg);
|
||||
}
|
||||
}
|
||||
|
||||
// Get previous summary for iterative update
|
||||
let previousSummary: string | undefined;
|
||||
if (prevCompactionIndex >= 0) {
|
||||
const prevCompaction = pathEntries[prevCompactionIndex] as CompactionEntry;
|
||||
previousSummary = prevCompaction.summary;
|
||||
}
|
||||
|
||||
// Extract file operations from messages and previous compaction
|
||||
const fileOps = extractFileOperations(
|
||||
messagesToSummarize,
|
||||
pathEntries,
|
||||
prevCompactionIndex,
|
||||
);
|
||||
|
||||
// Also extract file ops from turn prefix if splitting
|
||||
if (cutPoint.isSplitTurn) {
|
||||
for (const msg of turnPrefixMessages) {
|
||||
extractFileOpsFromMessage(msg, fileOps);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
firstKeptEntryId,
|
||||
messagesToSummarize,
|
||||
turnPrefixMessages,
|
||||
isSplitTurn: cutPoint.isSplitTurn,
|
||||
tokensBefore,
|
||||
previousSummary,
|
||||
fileOps,
|
||||
settings,
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main compaction function
|
||||
// ============================================================================
|
||||
|
||||
const TURN_PREFIX_SUMMARIZATION_PROMPT = `This is the PREFIX of a turn that was too large to keep. The SUFFIX (recent work) is retained.
|
||||
|
||||
Summarize the prefix to provide context for the retained suffix:
|
||||
|
||||
## Original Request
|
||||
[What did the user ask for in this turn?]
|
||||
|
||||
## Early Progress
|
||||
- [Key decisions and work done in the prefix]
|
||||
|
||||
## Context for Suffix
|
||||
- [Information needed to understand the retained recent work]
|
||||
|
||||
Be concise. Focus on what's needed to understand the kept suffix.`;
|
||||
|
||||
/**
|
||||
* Generate summaries for compaction using prepared data.
|
||||
* Returns CompactionResult - SessionManager adds uuid/parentUuid when saving.
|
||||
*
|
||||
* @param preparation - Pre-calculated preparation from prepareCompaction()
|
||||
* @param customInstructions - Optional custom focus for the summary
|
||||
*/
|
||||
export async function compact(
|
||||
preparation: CompactionPreparation,
|
||||
model: Model<any>,
|
||||
apiKey: string,
|
||||
customInstructions?: string,
|
||||
signal?: AbortSignal,
|
||||
): Promise<CompactionResult> {
|
||||
const {
|
||||
firstKeptEntryId,
|
||||
messagesToSummarize,
|
||||
turnPrefixMessages,
|
||||
isSplitTurn,
|
||||
tokensBefore,
|
||||
previousSummary,
|
||||
fileOps,
|
||||
settings,
|
||||
} = preparation;
|
||||
|
||||
// Generate summaries (can be parallel if both needed) and merge into one
|
||||
let summary: string;
|
||||
|
||||
if (isSplitTurn && turnPrefixMessages.length > 0) {
|
||||
// Generate both summaries in parallel
|
||||
const [historyResult, turnPrefixResult] = await Promise.all([
|
||||
messagesToSummarize.length > 0
|
||||
? generateSummary(
|
||||
messagesToSummarize,
|
||||
model,
|
||||
settings.reserveTokens,
|
||||
apiKey,
|
||||
signal,
|
||||
customInstructions,
|
||||
previousSummary,
|
||||
)
|
||||
: Promise.resolve("No prior history."),
|
||||
generateTurnPrefixSummary(
|
||||
turnPrefixMessages,
|
||||
model,
|
||||
settings.reserveTokens,
|
||||
apiKey,
|
||||
signal,
|
||||
),
|
||||
]);
|
||||
// Merge into single summary
|
||||
summary = `${historyResult}\n\n---\n\n**Turn Context (split turn):**\n\n${turnPrefixResult}`;
|
||||
} else {
|
||||
// Just generate history summary
|
||||
summary = await generateSummary(
|
||||
messagesToSummarize,
|
||||
model,
|
||||
settings.reserveTokens,
|
||||
apiKey,
|
||||
signal,
|
||||
customInstructions,
|
||||
previousSummary,
|
||||
);
|
||||
}
|
||||
|
||||
// Compute file lists and append to summary
|
||||
const { readFiles, modifiedFiles } = computeFileLists(fileOps);
|
||||
summary += formatFileOperations(readFiles, modifiedFiles);
|
||||
|
||||
if (!firstKeptEntryId) {
|
||||
throw new Error(
|
||||
"First kept entry has no UUID - session may need migration",
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
summary,
|
||||
firstKeptEntryId,
|
||||
tokensBefore,
|
||||
details: { readFiles, modifiedFiles } as CompactionDetails,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a summary for a turn prefix (when splitting a turn).
|
||||
*/
|
||||
async function generateTurnPrefixSummary(
|
||||
messages: AgentMessage[],
|
||||
model: Model<any>,
|
||||
reserveTokens: number,
|
||||
apiKey: string,
|
||||
signal?: AbortSignal,
|
||||
): Promise<string> {
|
||||
const maxTokens = Math.floor(0.5 * reserveTokens); // Smaller budget for turn prefix
|
||||
const llmMessages = convertToLlm(messages);
|
||||
const conversationText = serializeConversation(llmMessages);
|
||||
const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${TURN_PREFIX_SUMMARIZATION_PROMPT}`;
|
||||
const summarizationMessages = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: promptText }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
const response = await completeSimple(
|
||||
model,
|
||||
{
|
||||
systemPrompt: SUMMARIZATION_SYSTEM_PROMPT,
|
||||
messages: summarizationMessages,
|
||||
},
|
||||
{ maxTokens, signal, apiKey },
|
||||
);
|
||||
|
||||
if (response.stopReason === "error") {
|
||||
throw new Error(
|
||||
`Turn prefix summarization failed: ${response.errorMessage || "Unknown error"}`,
|
||||
);
|
||||
}
|
||||
|
||||
return response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
}
|
||||
7
packages/coding-agent/src/core/compaction/index.ts
Normal file
7
packages/coding-agent/src/core/compaction/index.ts
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
/**
|
||||
* Compaction and summarization utilities.
|
||||
*/
|
||||
|
||||
export * from "./branch-summarization.js";
|
||||
export * from "./compaction.js";
|
||||
export * from "./utils.js";
|
||||
167
packages/coding-agent/src/core/compaction/utils.ts
Normal file
167
packages/coding-agent/src/core/compaction/utils.ts
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
/**
|
||||
* Shared utilities for compaction and branch summarization.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { Message } from "@mariozechner/pi-ai";
|
||||
|
||||
// ============================================================================
|
||||
// File Operation Tracking
|
||||
// ============================================================================
|
||||
|
||||
export interface FileOperations {
|
||||
read: Set<string>;
|
||||
written: Set<string>;
|
||||
edited: Set<string>;
|
||||
}
|
||||
|
||||
export function createFileOps(): FileOperations {
|
||||
return {
|
||||
read: new Set(),
|
||||
written: new Set(),
|
||||
edited: new Set(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract file operations from tool calls in an assistant message.
|
||||
*/
|
||||
export function extractFileOpsFromMessage(
|
||||
message: AgentMessage,
|
||||
fileOps: FileOperations,
|
||||
): void {
|
||||
if (message.role !== "assistant") return;
|
||||
if (!("content" in message) || !Array.isArray(message.content)) return;
|
||||
|
||||
for (const block of message.content) {
|
||||
if (typeof block !== "object" || block === null) continue;
|
||||
if (!("type" in block) || block.type !== "toolCall") continue;
|
||||
if (!("arguments" in block) || !("name" in block)) continue;
|
||||
|
||||
const args = block.arguments as Record<string, unknown> | undefined;
|
||||
if (!args) continue;
|
||||
|
||||
const path = typeof args.path === "string" ? args.path : undefined;
|
||||
if (!path) continue;
|
||||
|
||||
switch (block.name) {
|
||||
case "read":
|
||||
fileOps.read.add(path);
|
||||
break;
|
||||
case "write":
|
||||
fileOps.written.add(path);
|
||||
break;
|
||||
case "edit":
|
||||
fileOps.edited.add(path);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute final file lists from file operations.
|
||||
* Returns readFiles (files only read, not modified) and modifiedFiles.
|
||||
*/
|
||||
export function computeFileLists(fileOps: FileOperations): {
|
||||
readFiles: string[];
|
||||
modifiedFiles: string[];
|
||||
} {
|
||||
const modified = new Set([...fileOps.edited, ...fileOps.written]);
|
||||
const readOnly = [...fileOps.read].filter((f) => !modified.has(f)).sort();
|
||||
const modifiedFiles = [...modified].sort();
|
||||
return { readFiles: readOnly, modifiedFiles };
|
||||
}
|
||||
|
||||
/**
|
||||
* Format file operations as XML tags for summary.
|
||||
*/
|
||||
export function formatFileOperations(
|
||||
readFiles: string[],
|
||||
modifiedFiles: string[],
|
||||
): string {
|
||||
const sections: string[] = [];
|
||||
if (readFiles.length > 0) {
|
||||
sections.push(`<read-files>\n${readFiles.join("\n")}\n</read-files>`);
|
||||
}
|
||||
if (modifiedFiles.length > 0) {
|
||||
sections.push(
|
||||
`<modified-files>\n${modifiedFiles.join("\n")}\n</modified-files>`,
|
||||
);
|
||||
}
|
||||
if (sections.length === 0) return "";
|
||||
return `\n\n${sections.join("\n\n")}`;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Serialization
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Serialize LLM messages to text for summarization.
|
||||
* This prevents the model from treating it as a conversation to continue.
|
||||
* Call convertToLlm() first to handle custom message types.
|
||||
*/
|
||||
export function serializeConversation(messages: Message[]): string {
|
||||
const parts: string[] = [];
|
||||
|
||||
for (const msg of messages) {
|
||||
if (msg.role === "user") {
|
||||
const content =
|
||||
typeof msg.content === "string"
|
||||
? msg.content
|
||||
: msg.content
|
||||
.filter(
|
||||
(c): c is { type: "text"; text: string } => c.type === "text",
|
||||
)
|
||||
.map((c) => c.text)
|
||||
.join("");
|
||||
if (content) parts.push(`[User]: ${content}`);
|
||||
} else if (msg.role === "assistant") {
|
||||
const textParts: string[] = [];
|
||||
const thinkingParts: string[] = [];
|
||||
const toolCalls: string[] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
textParts.push(block.text);
|
||||
} else if (block.type === "thinking") {
|
||||
thinkingParts.push(block.thinking);
|
||||
} else if (block.type === "toolCall") {
|
||||
const args = block.arguments as Record<string, unknown>;
|
||||
const argsStr = Object.entries(args)
|
||||
.map(([k, v]) => `${k}=${JSON.stringify(v)}`)
|
||||
.join(", ");
|
||||
toolCalls.push(`${block.name}(${argsStr})`);
|
||||
}
|
||||
}
|
||||
|
||||
if (thinkingParts.length > 0) {
|
||||
parts.push(`[Assistant thinking]: ${thinkingParts.join("\n")}`);
|
||||
}
|
||||
if (textParts.length > 0) {
|
||||
parts.push(`[Assistant]: ${textParts.join("\n")}`);
|
||||
}
|
||||
if (toolCalls.length > 0) {
|
||||
parts.push(`[Assistant tool calls]: ${toolCalls.join("; ")}`);
|
||||
}
|
||||
} else if (msg.role === "toolResult") {
|
||||
const content = msg.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("");
|
||||
if (content) {
|
||||
parts.push(`[Tool result]: ${content}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parts.join("\n\n");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Summarization System Prompt
|
||||
// ============================================================================
|
||||
|
||||
export const SUMMARIZATION_SYSTEM_PROMPT = `You are a context summarization assistant. Your task is to read a conversation between a user and an AI coding assistant, then produce a structured summary following the exact format specified.
|
||||
|
||||
Do NOT continue the conversation. Do NOT respond to any questions in the conversation. ONLY output the structured summary.`;
|
||||
3
packages/coding-agent/src/core/defaults.ts
Normal file
3
packages/coding-agent/src/core/defaults.ts
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
|
||||
|
||||
export const DEFAULT_THINKING_LEVEL: ThinkingLevel = "medium";
|
||||
15
packages/coding-agent/src/core/diagnostics.ts
Normal file
15
packages/coding-agent/src/core/diagnostics.ts
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
export interface ResourceCollision {
|
||||
resourceType: "extension" | "skill" | "prompt" | "theme";
|
||||
name: string; // skill name, command/tool/flag name, prompt name, theme name
|
||||
winnerPath: string;
|
||||
loserPath: string;
|
||||
winnerSource?: string; // e.g., "npm:foo", "git:...", "local"
|
||||
loserSource?: string;
|
||||
}
|
||||
|
||||
export interface ResourceDiagnostic {
|
||||
type: "warning" | "error" | "collision";
|
||||
message: string;
|
||||
path?: string;
|
||||
collision?: ResourceCollision;
|
||||
}
|
||||
33
packages/coding-agent/src/core/event-bus.ts
Normal file
33
packages/coding-agent/src/core/event-bus.ts
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import { EventEmitter } from "node:events";
|
||||
|
||||
export interface EventBus {
|
||||
emit(channel: string, data: unknown): void;
|
||||
on(channel: string, handler: (data: unknown) => void): () => void;
|
||||
}
|
||||
|
||||
export interface EventBusController extends EventBus {
|
||||
clear(): void;
|
||||
}
|
||||
|
||||
export function createEventBus(): EventBusController {
|
||||
const emitter = new EventEmitter();
|
||||
return {
|
||||
emit: (channel, data) => {
|
||||
emitter.emit(channel, data);
|
||||
},
|
||||
on: (channel, handler) => {
|
||||
const safeHandler = async (data: unknown) => {
|
||||
try {
|
||||
await handler(data);
|
||||
} catch (err) {
|
||||
console.error(`Event handler error (${channel}):`, err);
|
||||
}
|
||||
};
|
||||
emitter.on(channel, safeHandler);
|
||||
return () => emitter.off(channel, safeHandler);
|
||||
},
|
||||
clear: () => {
|
||||
emitter.removeAllListeners();
|
||||
},
|
||||
};
|
||||
}
|
||||
104
packages/coding-agent/src/core/exec.ts
Normal file
104
packages/coding-agent/src/core/exec.ts
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Shared command execution utilities for extensions and custom tools.
|
||||
*/
|
||||
|
||||
import { spawn } from "node:child_process";
|
||||
|
||||
/**
|
||||
* Options for executing shell commands.
|
||||
*/
|
||||
export interface ExecOptions {
|
||||
/** AbortSignal to cancel the command */
|
||||
signal?: AbortSignal;
|
||||
/** Timeout in milliseconds */
|
||||
timeout?: number;
|
||||
/** Working directory */
|
||||
cwd?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of executing a shell command.
|
||||
*/
|
||||
export interface ExecResult {
|
||||
stdout: string;
|
||||
stderr: string;
|
||||
code: number;
|
||||
killed: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a shell command and return stdout/stderr/code.
|
||||
* Supports timeout and abort signal.
|
||||
*/
|
||||
export async function execCommand(
|
||||
command: string,
|
||||
args: string[],
|
||||
cwd: string,
|
||||
options?: ExecOptions,
|
||||
): Promise<ExecResult> {
|
||||
return new Promise((resolve) => {
|
||||
const proc = spawn(command, args, {
|
||||
cwd,
|
||||
shell: false,
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
});
|
||||
|
||||
let stdout = "";
|
||||
let stderr = "";
|
||||
let killed = false;
|
||||
let timeoutId: NodeJS.Timeout | undefined;
|
||||
|
||||
const killProcess = () => {
|
||||
if (!killed) {
|
||||
killed = true;
|
||||
proc.kill("SIGTERM");
|
||||
// Force kill after 5 seconds if SIGTERM doesn't work
|
||||
setTimeout(() => {
|
||||
if (!proc.killed) {
|
||||
proc.kill("SIGKILL");
|
||||
}
|
||||
}, 5000);
|
||||
}
|
||||
};
|
||||
|
||||
// Handle abort signal
|
||||
if (options?.signal) {
|
||||
if (options.signal.aborted) {
|
||||
killProcess();
|
||||
} else {
|
||||
options.signal.addEventListener("abort", killProcess, { once: true });
|
||||
}
|
||||
}
|
||||
|
||||
// Handle timeout
|
||||
if (options?.timeout && options.timeout > 0) {
|
||||
timeoutId = setTimeout(() => {
|
||||
killProcess();
|
||||
}, options.timeout);
|
||||
}
|
||||
|
||||
proc.stdout?.on("data", (data) => {
|
||||
stdout += data.toString();
|
||||
});
|
||||
|
||||
proc.stderr?.on("data", (data) => {
|
||||
stderr += data.toString();
|
||||
});
|
||||
|
||||
proc.on("close", (code) => {
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", killProcess);
|
||||
}
|
||||
resolve({ stdout, stderr, code: code ?? 0, killed });
|
||||
});
|
||||
|
||||
proc.on("error", (_err) => {
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", killProcess);
|
||||
}
|
||||
resolve({ stdout, stderr, code: 1, killed });
|
||||
});
|
||||
});
|
||||
}
|
||||
271
packages/coding-agent/src/core/export-html/ansi-to-html.ts
Normal file
271
packages/coding-agent/src/core/export-html/ansi-to-html.ts
Normal file
|
|
@ -0,0 +1,271 @@
|
|||
/**
|
||||
* ANSI escape code to HTML converter.
|
||||
*
|
||||
* Converts terminal ANSI color/style codes to HTML with inline styles.
|
||||
* Supports:
|
||||
* - Standard foreground colors (30-37) and bright variants (90-97)
|
||||
* - Standard background colors (40-47) and bright variants (100-107)
|
||||
* - 256-color palette (38;5;N and 48;5;N)
|
||||
* - RGB true color (38;2;R;G;B and 48;2;R;G;B)
|
||||
* - Text styles: bold (1), dim (2), italic (3), underline (4)
|
||||
* - Reset (0)
|
||||
*/
|
||||
|
||||
// Standard ANSI color palette (0-15)
|
||||
const ANSI_COLORS = [
|
||||
"#000000", // 0: black
|
||||
"#800000", // 1: red
|
||||
"#008000", // 2: green
|
||||
"#808000", // 3: yellow
|
||||
"#000080", // 4: blue
|
||||
"#800080", // 5: magenta
|
||||
"#008080", // 6: cyan
|
||||
"#c0c0c0", // 7: white
|
||||
"#808080", // 8: bright black
|
||||
"#ff0000", // 9: bright red
|
||||
"#00ff00", // 10: bright green
|
||||
"#ffff00", // 11: bright yellow
|
||||
"#0000ff", // 12: bright blue
|
||||
"#ff00ff", // 13: bright magenta
|
||||
"#00ffff", // 14: bright cyan
|
||||
"#ffffff", // 15: bright white
|
||||
];
|
||||
|
||||
/**
|
||||
* Convert 256-color index to hex.
|
||||
*/
|
||||
function color256ToHex(index: number): string {
|
||||
// Standard colors (0-15)
|
||||
if (index < 16) {
|
||||
return ANSI_COLORS[index];
|
||||
}
|
||||
|
||||
// Color cube (16-231): 6x6x6 = 216 colors
|
||||
if (index < 232) {
|
||||
const cubeIndex = index - 16;
|
||||
const r = Math.floor(cubeIndex / 36);
|
||||
const g = Math.floor((cubeIndex % 36) / 6);
|
||||
const b = cubeIndex % 6;
|
||||
const toComponent = (n: number) => (n === 0 ? 0 : 55 + n * 40);
|
||||
const toHex = (n: number) => toComponent(n).toString(16).padStart(2, "0");
|
||||
return `#${toHex(r)}${toHex(g)}${toHex(b)}`;
|
||||
}
|
||||
|
||||
// Grayscale (232-255): 24 shades
|
||||
const gray = 8 + (index - 232) * 10;
|
||||
const grayHex = gray.toString(16).padStart(2, "0");
|
||||
return `#${grayHex}${grayHex}${grayHex}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Escape HTML special characters.
|
||||
*/
|
||||
function escapeHtml(text: string): string {
|
||||
return text
|
||||
.replace(/&/g, "&")
|
||||
.replace(/</g, "<")
|
||||
.replace(/>/g, ">")
|
||||
.replace(/"/g, """)
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
|
||||
interface TextStyle {
|
||||
fg: string | null;
|
||||
bg: string | null;
|
||||
bold: boolean;
|
||||
dim: boolean;
|
||||
italic: boolean;
|
||||
underline: boolean;
|
||||
}
|
||||
|
||||
function createEmptyStyle(): TextStyle {
|
||||
return {
|
||||
fg: null,
|
||||
bg: null,
|
||||
bold: false,
|
||||
dim: false,
|
||||
italic: false,
|
||||
underline: false,
|
||||
};
|
||||
}
|
||||
|
||||
function styleToInlineCSS(style: TextStyle): string {
|
||||
const parts: string[] = [];
|
||||
if (style.fg) parts.push(`color:${style.fg}`);
|
||||
if (style.bg) parts.push(`background-color:${style.bg}`);
|
||||
if (style.bold) parts.push("font-weight:bold");
|
||||
if (style.dim) parts.push("opacity:0.6");
|
||||
if (style.italic) parts.push("font-style:italic");
|
||||
if (style.underline) parts.push("text-decoration:underline");
|
||||
return parts.join(";");
|
||||
}
|
||||
|
||||
function hasStyle(style: TextStyle): boolean {
|
||||
return (
|
||||
style.fg !== null ||
|
||||
style.bg !== null ||
|
||||
style.bold ||
|
||||
style.dim ||
|
||||
style.italic ||
|
||||
style.underline
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse ANSI SGR (Select Graphic Rendition) codes and update style.
|
||||
*/
|
||||
function applySgrCode(params: number[], style: TextStyle): void {
|
||||
let i = 0;
|
||||
while (i < params.length) {
|
||||
const code = params[i];
|
||||
|
||||
if (code === 0) {
|
||||
// Reset all
|
||||
style.fg = null;
|
||||
style.bg = null;
|
||||
style.bold = false;
|
||||
style.dim = false;
|
||||
style.italic = false;
|
||||
style.underline = false;
|
||||
} else if (code === 1) {
|
||||
style.bold = true;
|
||||
} else if (code === 2) {
|
||||
style.dim = true;
|
||||
} else if (code === 3) {
|
||||
style.italic = true;
|
||||
} else if (code === 4) {
|
||||
style.underline = true;
|
||||
} else if (code === 22) {
|
||||
// Reset bold/dim
|
||||
style.bold = false;
|
||||
style.dim = false;
|
||||
} else if (code === 23) {
|
||||
style.italic = false;
|
||||
} else if (code === 24) {
|
||||
style.underline = false;
|
||||
} else if (code >= 30 && code <= 37) {
|
||||
// Standard foreground colors
|
||||
style.fg = ANSI_COLORS[code - 30];
|
||||
} else if (code === 38) {
|
||||
// Extended foreground color
|
||||
if (params[i + 1] === 5 && params.length > i + 2) {
|
||||
// 256-color: 38;5;N
|
||||
style.fg = color256ToHex(params[i + 2]);
|
||||
i += 2;
|
||||
} else if (params[i + 1] === 2 && params.length > i + 4) {
|
||||
// RGB: 38;2;R;G;B
|
||||
const r = params[i + 2];
|
||||
const g = params[i + 3];
|
||||
const b = params[i + 4];
|
||||
style.fg = `rgb(${r},${g},${b})`;
|
||||
i += 4;
|
||||
}
|
||||
} else if (code === 39) {
|
||||
// Default foreground
|
||||
style.fg = null;
|
||||
} else if (code >= 40 && code <= 47) {
|
||||
// Standard background colors
|
||||
style.bg = ANSI_COLORS[code - 40];
|
||||
} else if (code === 48) {
|
||||
// Extended background color
|
||||
if (params[i + 1] === 5 && params.length > i + 2) {
|
||||
// 256-color: 48;5;N
|
||||
style.bg = color256ToHex(params[i + 2]);
|
||||
i += 2;
|
||||
} else if (params[i + 1] === 2 && params.length > i + 4) {
|
||||
// RGB: 48;2;R;G;B
|
||||
const r = params[i + 2];
|
||||
const g = params[i + 3];
|
||||
const b = params[i + 4];
|
||||
style.bg = `rgb(${r},${g},${b})`;
|
||||
i += 4;
|
||||
}
|
||||
} else if (code === 49) {
|
||||
// Default background
|
||||
style.bg = null;
|
||||
} else if (code >= 90 && code <= 97) {
|
||||
// Bright foreground colors
|
||||
style.fg = ANSI_COLORS[code - 90 + 8];
|
||||
} else if (code >= 100 && code <= 107) {
|
||||
// Bright background colors
|
||||
style.bg = ANSI_COLORS[code - 100 + 8];
|
||||
}
|
||||
// Ignore unrecognized codes
|
||||
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
// Match ANSI escape sequences: ESC[ followed by params and ending with 'm'
|
||||
const ANSI_REGEX = /\x1b\[([\d;]*)m/g;
|
||||
|
||||
/**
|
||||
* Convert ANSI-escaped text to HTML with inline styles.
|
||||
*/
|
||||
export function ansiToHtml(text: string): string {
|
||||
const style = createEmptyStyle();
|
||||
let result = "";
|
||||
let lastIndex = 0;
|
||||
let inSpan = false;
|
||||
|
||||
// Reset regex state
|
||||
ANSI_REGEX.lastIndex = 0;
|
||||
|
||||
let match = ANSI_REGEX.exec(text);
|
||||
while (match !== null) {
|
||||
// Add text before this escape sequence
|
||||
const beforeText = text.slice(lastIndex, match.index);
|
||||
if (beforeText) {
|
||||
result += escapeHtml(beforeText);
|
||||
}
|
||||
|
||||
// Parse SGR parameters
|
||||
const paramStr = match[1];
|
||||
const params = paramStr
|
||||
? paramStr.split(";").map((p) => parseInt(p, 10) || 0)
|
||||
: [0];
|
||||
|
||||
// Close existing span if we have one
|
||||
if (inSpan) {
|
||||
result += "</span>";
|
||||
inSpan = false;
|
||||
}
|
||||
|
||||
// Apply the codes
|
||||
applySgrCode(params, style);
|
||||
|
||||
// Open new span if we have any styling
|
||||
if (hasStyle(style)) {
|
||||
result += `<span style="${styleToInlineCSS(style)}">`;
|
||||
inSpan = true;
|
||||
}
|
||||
|
||||
lastIndex = match.index + match[0].length;
|
||||
match = ANSI_REGEX.exec(text);
|
||||
}
|
||||
|
||||
// Add remaining text
|
||||
const remainingText = text.slice(lastIndex);
|
||||
if (remainingText) {
|
||||
result += escapeHtml(remainingText);
|
||||
}
|
||||
|
||||
// Close any open span
|
||||
if (inSpan) {
|
||||
result += "</span>";
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert array of ANSI-escaped lines to HTML.
|
||||
* Each line is wrapped in a div element.
|
||||
*/
|
||||
export function ansiLinesToHtml(lines: string[]): string {
|
||||
return lines
|
||||
.map(
|
||||
(line) => `<div class="ansi-line">${ansiToHtml(line) || " "}</div>`,
|
||||
)
|
||||
.join("\n");
|
||||
}
|
||||
353
packages/coding-agent/src/core/export-html/index.ts
Normal file
353
packages/coding-agent/src/core/export-html/index.ts
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
import type { AgentState } from "@mariozechner/pi-agent-core";
|
||||
import { existsSync, readFileSync, writeFileSync } from "fs";
|
||||
import { basename, join } from "path";
|
||||
import { APP_NAME, getExportTemplateDir } from "../../config.js";
|
||||
import {
|
||||
getResolvedThemeColors,
|
||||
getThemeExportColors,
|
||||
} from "../../modes/interactive/theme/theme.js";
|
||||
import type { ToolInfo } from "../extensions/types.js";
|
||||
import type { SessionEntry } from "../session-manager.js";
|
||||
import { SessionManager } from "../session-manager.js";
|
||||
|
||||
/**
|
||||
* Interface for rendering custom tools to HTML.
|
||||
* Used by agent-session to pre-render extension tool output.
|
||||
*/
|
||||
export interface ToolHtmlRenderer {
|
||||
/** Render a tool call to HTML. Returns undefined if tool has no custom renderer. */
|
||||
renderCall(toolName: string, args: unknown): string | undefined;
|
||||
/** Render a tool result to HTML. Returns undefined if tool has no custom renderer. */
|
||||
renderResult(
|
||||
toolName: string,
|
||||
result: Array<{
|
||||
type: string;
|
||||
text?: string;
|
||||
data?: string;
|
||||
mimeType?: string;
|
||||
}>,
|
||||
details: unknown,
|
||||
isError: boolean,
|
||||
): string | undefined;
|
||||
}
|
||||
|
||||
/** Pre-rendered HTML for a custom tool call and result */
|
||||
interface RenderedToolHtml {
|
||||
callHtml?: string;
|
||||
resultHtml?: string;
|
||||
}
|
||||
|
||||
export interface ExportOptions {
|
||||
outputPath?: string;
|
||||
themeName?: string;
|
||||
/** Optional tool renderer for custom tools */
|
||||
toolRenderer?: ToolHtmlRenderer;
|
||||
}
|
||||
|
||||
/** Parse a color string to RGB values. Supports hex (#RRGGBB) and rgb(r,g,b) formats. */
|
||||
function parseColor(
|
||||
color: string,
|
||||
): { r: number; g: number; b: number } | undefined {
|
||||
const hexMatch = color.match(
|
||||
/^#([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})$/,
|
||||
);
|
||||
if (hexMatch) {
|
||||
return {
|
||||
r: Number.parseInt(hexMatch[1], 16),
|
||||
g: Number.parseInt(hexMatch[2], 16),
|
||||
b: Number.parseInt(hexMatch[3], 16),
|
||||
};
|
||||
}
|
||||
const rgbMatch = color.match(
|
||||
/^rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)$/,
|
||||
);
|
||||
if (rgbMatch) {
|
||||
return {
|
||||
r: Number.parseInt(rgbMatch[1], 10),
|
||||
g: Number.parseInt(rgbMatch[2], 10),
|
||||
b: Number.parseInt(rgbMatch[3], 10),
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Calculate relative luminance of a color (0-1, higher = lighter). */
|
||||
function getLuminance(r: number, g: number, b: number): number {
|
||||
const toLinear = (c: number) => {
|
||||
const s = c / 255;
|
||||
return s <= 0.03928 ? s / 12.92 : ((s + 0.055) / 1.055) ** 2.4;
|
||||
};
|
||||
return 0.2126 * toLinear(r) + 0.7152 * toLinear(g) + 0.0722 * toLinear(b);
|
||||
}
|
||||
|
||||
/** Adjust color brightness. Factor > 1 lightens, < 1 darkens. */
|
||||
function adjustBrightness(color: string, factor: number): string {
|
||||
const parsed = parseColor(color);
|
||||
if (!parsed) return color;
|
||||
const adjust = (c: number) =>
|
||||
Math.min(255, Math.max(0, Math.round(c * factor)));
|
||||
return `rgb(${adjust(parsed.r)}, ${adjust(parsed.g)}, ${adjust(parsed.b)})`;
|
||||
}
|
||||
|
||||
/** Derive export background colors from a base color (e.g., userMessageBg). */
|
||||
function deriveExportColors(baseColor: string): {
|
||||
pageBg: string;
|
||||
cardBg: string;
|
||||
infoBg: string;
|
||||
} {
|
||||
const parsed = parseColor(baseColor);
|
||||
if (!parsed) {
|
||||
return {
|
||||
pageBg: "rgb(24, 24, 30)",
|
||||
cardBg: "rgb(30, 30, 36)",
|
||||
infoBg: "rgb(60, 55, 40)",
|
||||
};
|
||||
}
|
||||
|
||||
const luminance = getLuminance(parsed.r, parsed.g, parsed.b);
|
||||
const isLight = luminance > 0.5;
|
||||
|
||||
if (isLight) {
|
||||
return {
|
||||
pageBg: adjustBrightness(baseColor, 0.96),
|
||||
cardBg: baseColor,
|
||||
infoBg: `rgb(${Math.min(255, parsed.r + 10)}, ${Math.min(255, parsed.g + 5)}, ${Math.max(0, parsed.b - 20)})`,
|
||||
};
|
||||
}
|
||||
return {
|
||||
pageBg: adjustBrightness(baseColor, 0.7),
|
||||
cardBg: adjustBrightness(baseColor, 0.85),
|
||||
infoBg: `rgb(${Math.min(255, parsed.r + 20)}, ${Math.min(255, parsed.g + 15)}, ${parsed.b})`,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate CSS custom property declarations from theme colors.
|
||||
*/
|
||||
function generateThemeVars(themeName?: string): string {
|
||||
const colors = getResolvedThemeColors(themeName);
|
||||
const lines: string[] = [];
|
||||
for (const [key, value] of Object.entries(colors)) {
|
||||
lines.push(`--${key}: ${value};`);
|
||||
}
|
||||
|
||||
// Use explicit theme export colors if available, otherwise derive from userMessageBg
|
||||
const themeExport = getThemeExportColors(themeName);
|
||||
const userMessageBg = colors.userMessageBg || "#343541";
|
||||
const derivedColors = deriveExportColors(userMessageBg);
|
||||
|
||||
lines.push(`--exportPageBg: ${themeExport.pageBg ?? derivedColors.pageBg};`);
|
||||
lines.push(`--exportCardBg: ${themeExport.cardBg ?? derivedColors.cardBg};`);
|
||||
lines.push(`--exportInfoBg: ${themeExport.infoBg ?? derivedColors.infoBg};`);
|
||||
|
||||
return lines.join("\n ");
|
||||
}
|
||||
|
||||
interface SessionData {
|
||||
header: ReturnType<SessionManager["getHeader"]>;
|
||||
entries: ReturnType<SessionManager["getEntries"]>;
|
||||
leafId: string | null;
|
||||
systemPrompt?: string;
|
||||
tools?: ToolInfo[];
|
||||
/** Pre-rendered HTML for custom tool calls/results, keyed by tool call ID */
|
||||
renderedTools?: Record<string, RenderedToolHtml>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Core HTML generation logic shared by both export functions.
|
||||
*/
|
||||
function generateHtml(sessionData: SessionData, themeName?: string): string {
|
||||
const templateDir = getExportTemplateDir();
|
||||
const template = readFileSync(join(templateDir, "template.html"), "utf-8");
|
||||
const templateCss = readFileSync(join(templateDir, "template.css"), "utf-8");
|
||||
const templateJs = readFileSync(join(templateDir, "template.js"), "utf-8");
|
||||
const markedJs = readFileSync(
|
||||
join(templateDir, "vendor", "marked.min.js"),
|
||||
"utf-8",
|
||||
);
|
||||
const hljsJs = readFileSync(
|
||||
join(templateDir, "vendor", "highlight.min.js"),
|
||||
"utf-8",
|
||||
);
|
||||
|
||||
const themeVars = generateThemeVars(themeName);
|
||||
const colors = getResolvedThemeColors(themeName);
|
||||
const exportColors = deriveExportColors(colors.userMessageBg || "#343541");
|
||||
const bodyBg = exportColors.pageBg;
|
||||
const containerBg = exportColors.cardBg;
|
||||
const infoBg = exportColors.infoBg;
|
||||
|
||||
// Base64 encode session data to avoid escaping issues
|
||||
const sessionDataBase64 = Buffer.from(JSON.stringify(sessionData)).toString(
|
||||
"base64",
|
||||
);
|
||||
|
||||
// Build the CSS with theme variables injected
|
||||
const css = templateCss
|
||||
.replace("{{THEME_VARS}}", themeVars)
|
||||
.replace("{{BODY_BG}}", bodyBg)
|
||||
.replace("{{CONTAINER_BG}}", containerBg)
|
||||
.replace("{{INFO_BG}}", infoBg);
|
||||
|
||||
return template
|
||||
.replace("{{CSS}}", css)
|
||||
.replace("{{JS}}", templateJs)
|
||||
.replace("{{SESSION_DATA}}", sessionDataBase64)
|
||||
.replace("{{MARKED_JS}}", markedJs)
|
||||
.replace("{{HIGHLIGHT_JS}}", hljsJs);
|
||||
}
|
||||
|
||||
/** Built-in tool names that have custom rendering in template.js */
|
||||
const BUILTIN_TOOLS = new Set([
|
||||
"bash",
|
||||
"read",
|
||||
"write",
|
||||
"edit",
|
||||
"ls",
|
||||
"find",
|
||||
"grep",
|
||||
]);
|
||||
|
||||
/**
|
||||
* Pre-render custom tools to HTML using their TUI renderers.
|
||||
*/
|
||||
function preRenderCustomTools(
|
||||
entries: SessionEntry[],
|
||||
toolRenderer: ToolHtmlRenderer,
|
||||
): Record<string, RenderedToolHtml> {
|
||||
const renderedTools: Record<string, RenderedToolHtml> = {};
|
||||
|
||||
for (const entry of entries) {
|
||||
if (entry.type !== "message") continue;
|
||||
const msg = entry.message;
|
||||
|
||||
// Find tool calls in assistant messages
|
||||
if (msg.role === "assistant" && Array.isArray(msg.content)) {
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "toolCall" && !BUILTIN_TOOLS.has(block.name)) {
|
||||
const callHtml = toolRenderer.renderCall(block.name, block.arguments);
|
||||
if (callHtml) {
|
||||
renderedTools[block.id] = { callHtml };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find tool results
|
||||
if (msg.role === "toolResult" && msg.toolCallId) {
|
||||
const toolName = msg.toolName || "";
|
||||
// Only render if we have a pre-rendered call OR it's not a built-in tool
|
||||
const existing = renderedTools[msg.toolCallId];
|
||||
if (existing || !BUILTIN_TOOLS.has(toolName)) {
|
||||
const resultHtml = toolRenderer.renderResult(
|
||||
toolName,
|
||||
msg.content,
|
||||
msg.details,
|
||||
msg.isError || false,
|
||||
);
|
||||
if (resultHtml) {
|
||||
renderedTools[msg.toolCallId] = {
|
||||
...existing,
|
||||
resultHtml,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return renderedTools;
|
||||
}
|
||||
|
||||
/**
|
||||
* Export session to HTML using SessionManager and AgentState.
|
||||
* Used by TUI's /export command.
|
||||
*/
|
||||
export async function exportSessionToHtml(
|
||||
sm: SessionManager,
|
||||
state?: AgentState,
|
||||
options?: ExportOptions | string,
|
||||
): Promise<string> {
|
||||
const opts: ExportOptions =
|
||||
typeof options === "string" ? { outputPath: options } : options || {};
|
||||
|
||||
const sessionFile = sm.getSessionFile();
|
||||
if (!sessionFile) {
|
||||
throw new Error("Cannot export in-memory session to HTML");
|
||||
}
|
||||
if (!existsSync(sessionFile)) {
|
||||
throw new Error("Nothing to export yet - start a conversation first");
|
||||
}
|
||||
|
||||
const entries = sm.getEntries();
|
||||
|
||||
// Pre-render custom tools if a tool renderer is provided
|
||||
let renderedTools: Record<string, RenderedToolHtml> | undefined;
|
||||
if (opts.toolRenderer) {
|
||||
renderedTools = preRenderCustomTools(entries, opts.toolRenderer);
|
||||
// Only include if we actually rendered something
|
||||
if (Object.keys(renderedTools).length === 0) {
|
||||
renderedTools = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
const sessionData: SessionData = {
|
||||
header: sm.getHeader(),
|
||||
entries,
|
||||
leafId: sm.getLeafId(),
|
||||
systemPrompt: state?.systemPrompt,
|
||||
tools: state?.tools?.map((t) => ({
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.parameters,
|
||||
})),
|
||||
renderedTools,
|
||||
};
|
||||
|
||||
const html = generateHtml(sessionData, opts.themeName);
|
||||
|
||||
let outputPath = opts.outputPath;
|
||||
if (!outputPath) {
|
||||
const sessionBasename = basename(sessionFile, ".jsonl");
|
||||
outputPath = `${APP_NAME}-session-${sessionBasename}.html`;
|
||||
}
|
||||
|
||||
writeFileSync(outputPath, html, "utf8");
|
||||
return outputPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Export session file to HTML (standalone, without AgentState).
|
||||
* Used by CLI for exporting arbitrary session files.
|
||||
*/
|
||||
export async function exportFromFile(
|
||||
inputPath: string,
|
||||
options?: ExportOptions | string,
|
||||
): Promise<string> {
|
||||
const opts: ExportOptions =
|
||||
typeof options === "string" ? { outputPath: options } : options || {};
|
||||
|
||||
if (!existsSync(inputPath)) {
|
||||
throw new Error(`File not found: ${inputPath}`);
|
||||
}
|
||||
|
||||
const sm = SessionManager.open(inputPath);
|
||||
|
||||
const sessionData: SessionData = {
|
||||
header: sm.getHeader(),
|
||||
entries: sm.getEntries(),
|
||||
leafId: sm.getLeafId(),
|
||||
systemPrompt: undefined,
|
||||
tools: undefined,
|
||||
};
|
||||
|
||||
const html = generateHtml(sessionData, opts.themeName);
|
||||
|
||||
let outputPath = opts.outputPath;
|
||||
if (!outputPath) {
|
||||
const inputBasename = basename(inputPath, ".jsonl");
|
||||
outputPath = `${APP_NAME}-session-${inputBasename}.html`;
|
||||
}
|
||||
|
||||
writeFileSync(outputPath, html, "utf8");
|
||||
return outputPath;
|
||||
}
|
||||
971
packages/coding-agent/src/core/export-html/template.css
Normal file
971
packages/coding-agent/src/core/export-html/template.css
Normal file
|
|
@ -0,0 +1,971 @@
|
|||
:root {
|
||||
{{THEME_VARS}}
|
||||
--body-bg: {{BODY_BG}};
|
||||
--container-bg: {{CONTAINER_BG}};
|
||||
--info-bg: {{INFO_BG}};
|
||||
}
|
||||
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
|
||||
:root {
|
||||
--line-height: 18px; /* 12px font * 1.5 */
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: ui-monospace, 'Cascadia Code', 'Source Code Pro', Menlo, Consolas, 'DejaVu Sans Mono', monospace;
|
||||
font-size: 12px;
|
||||
line-height: var(--line-height);
|
||||
color: var(--text);
|
||||
background: var(--body-bg);
|
||||
}
|
||||
|
||||
#app {
|
||||
display: flex;
|
||||
min-height: 100vh;
|
||||
}
|
||||
|
||||
/* Sidebar */
|
||||
#sidebar {
|
||||
width: 400px;
|
||||
background: var(--container-bg);
|
||||
flex-shrink: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
position: sticky;
|
||||
top: 0;
|
||||
height: 100vh;
|
||||
border-right: 1px solid var(--dim);
|
||||
}
|
||||
|
||||
.sidebar-header {
|
||||
padding: 8px 12px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.sidebar-controls {
|
||||
padding: 8px 8px 4px 8px;
|
||||
}
|
||||
|
||||
.sidebar-search {
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
padding: 4px 8px;
|
||||
font-size: 11px;
|
||||
font-family: inherit;
|
||||
background: var(--body-bg);
|
||||
color: var(--text);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.sidebar-filters {
|
||||
display: flex;
|
||||
padding: 4px 8px 8px 8px;
|
||||
gap: 4px;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.sidebar-search:focus {
|
||||
outline: none;
|
||||
border-color: var(--accent);
|
||||
}
|
||||
|
||||
.sidebar-search::placeholder {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.filter-btn {
|
||||
padding: 3px 8px;
|
||||
font-size: 10px;
|
||||
font-family: inherit;
|
||||
background: transparent;
|
||||
color: var(--muted);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 3px;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.filter-btn:hover {
|
||||
color: var(--text);
|
||||
border-color: var(--text);
|
||||
}
|
||||
|
||||
.filter-btn.active {
|
||||
background: var(--accent);
|
||||
color: var(--body-bg);
|
||||
border-color: var(--accent);
|
||||
}
|
||||
|
||||
.sidebar-close {
|
||||
display: none;
|
||||
padding: 3px 8px;
|
||||
font-size: 12px;
|
||||
font-family: inherit;
|
||||
background: transparent;
|
||||
color: var(--muted);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 3px;
|
||||
cursor: pointer;
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.sidebar-close:hover {
|
||||
color: var(--text);
|
||||
border-color: var(--text);
|
||||
}
|
||||
|
||||
.tree-container {
|
||||
flex: 1;
|
||||
overflow: auto;
|
||||
padding: 4px 0;
|
||||
}
|
||||
|
||||
.tree-node {
|
||||
padding: 0 8px;
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
font-size: 11px;
|
||||
line-height: 13px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.tree-node:hover {
|
||||
background: var(--selectedBg);
|
||||
}
|
||||
|
||||
.tree-node.active {
|
||||
background: var(--selectedBg);
|
||||
}
|
||||
|
||||
.tree-node.active .tree-content {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.tree-node.in-path {
|
||||
background: color-mix(in srgb, var(--accent) 10%, transparent);
|
||||
}
|
||||
|
||||
.tree-node:not(.in-path) {
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.tree-node:not(.in-path):hover {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.tree-prefix {
|
||||
color: var(--muted);
|
||||
flex-shrink: 0;
|
||||
font-family: monospace;
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.tree-marker {
|
||||
color: var(--accent);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.tree-content {
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.tree-role-user {
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
.tree-role-assistant {
|
||||
color: var(--success);
|
||||
}
|
||||
|
||||
.tree-role-tool {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.tree-muted {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.tree-error {
|
||||
color: var(--error);
|
||||
}
|
||||
|
||||
.tree-compaction {
|
||||
color: var(--borderAccent);
|
||||
}
|
||||
|
||||
.tree-branch-summary {
|
||||
color: var(--warning);
|
||||
}
|
||||
|
||||
.tree-custom-message {
|
||||
color: var(--customMessageLabel);
|
||||
}
|
||||
|
||||
.tree-status {
|
||||
padding: 4px 12px;
|
||||
font-size: 10px;
|
||||
color: var(--muted);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* Main content */
|
||||
#content {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: var(--line-height) calc(var(--line-height) * 2);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
#content > * {
|
||||
width: 100%;
|
||||
max-width: 800px;
|
||||
}
|
||||
|
||||
/* Help bar */
|
||||
.help-bar {
|
||||
font-size: 11px;
|
||||
color: var(--warning);
|
||||
margin-bottom: var(--line-height);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.download-json-btn {
|
||||
font-size: 10px;
|
||||
padding: 2px 8px;
|
||||
background: var(--container-bg);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 3px;
|
||||
color: var(--text);
|
||||
cursor: pointer;
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
.download-json-btn:hover {
|
||||
background: var(--hover);
|
||||
border-color: var(--borderAccent);
|
||||
}
|
||||
|
||||
/* Header */
|
||||
.header {
|
||||
background: var(--container-bg);
|
||||
border-radius: 4px;
|
||||
padding: var(--line-height);
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.header h1 {
|
||||
font-size: 12px;
|
||||
font-weight: bold;
|
||||
color: var(--borderAccent);
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.header-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0;
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.info-item {
|
||||
color: var(--dim);
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
}
|
||||
|
||||
.info-label {
|
||||
font-weight: 600;
|
||||
margin-right: 8px;
|
||||
min-width: 100px;
|
||||
}
|
||||
|
||||
.info-value {
|
||||
color: var(--text);
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
/* Messages */
|
||||
#messages {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--line-height);
|
||||
}
|
||||
|
||||
.message-timestamp {
|
||||
font-size: 10px;
|
||||
color: var(--dim);
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.user-message {
|
||||
background: var(--userMessageBg);
|
||||
color: var(--userMessageText);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.assistant-message {
|
||||
padding: 0;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
/* Copy link button - appears on hover */
|
||||
.copy-link-btn {
|
||||
position: absolute;
|
||||
top: 8px;
|
||||
right: 8px;
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
padding: 6px;
|
||||
background: var(--container-bg);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 4px;
|
||||
color: var(--muted);
|
||||
cursor: pointer;
|
||||
opacity: 0;
|
||||
transition: opacity 0.15s, background 0.15s, color 0.15s;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
z-index: 10;
|
||||
}
|
||||
|
||||
.user-message:hover .copy-link-btn,
|
||||
.assistant-message:hover .copy-link-btn {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.copy-link-btn:hover {
|
||||
background: var(--accent);
|
||||
color: var(--body-bg);
|
||||
border-color: var(--accent);
|
||||
}
|
||||
|
||||
.copy-link-btn.copied {
|
||||
background: var(--success, #22c55e);
|
||||
color: white;
|
||||
border-color: var(--success, #22c55e);
|
||||
}
|
||||
|
||||
/* Highlight effect for deep-linked messages */
|
||||
.user-message.highlight,
|
||||
.assistant-message.highlight {
|
||||
animation: highlight-pulse 2s ease-out;
|
||||
}
|
||||
|
||||
@keyframes highlight-pulse {
|
||||
0% {
|
||||
box-shadow: 0 0 0 3px var(--accent);
|
||||
}
|
||||
100% {
|
||||
box-shadow: 0 0 0 0 transparent;
|
||||
}
|
||||
}
|
||||
|
||||
.assistant-message > .message-timestamp {
|
||||
padding-left: var(--line-height);
|
||||
}
|
||||
|
||||
.assistant-text {
|
||||
padding: var(--line-height);
|
||||
padding-bottom: 0;
|
||||
}
|
||||
|
||||
.message-timestamp + .assistant-text,
|
||||
.message-timestamp + .thinking-block {
|
||||
padding-top: 0;
|
||||
}
|
||||
|
||||
.thinking-block + .assistant-text {
|
||||
padding-top: 0;
|
||||
}
|
||||
|
||||
.thinking-text {
|
||||
padding: var(--line-height);
|
||||
color: var(--thinkingText);
|
||||
font-style: italic;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.message-timestamp + .thinking-block .thinking-text,
|
||||
.message-timestamp + .thinking-block .thinking-collapsed {
|
||||
padding-top: 0;
|
||||
}
|
||||
|
||||
.thinking-collapsed {
|
||||
display: none;
|
||||
padding: var(--line-height);
|
||||
color: var(--thinkingText);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* Tool execution */
|
||||
.tool-execution {
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.tool-execution + .tool-execution {
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.assistant-text + .tool-execution {
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.tool-execution.pending { background: var(--toolPendingBg); }
|
||||
.tool-execution.success { background: var(--toolSuccessBg); }
|
||||
.tool-execution.error { background: var(--toolErrorBg); }
|
||||
|
||||
.tool-header, .tool-name {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.tool-path {
|
||||
color: var(--accent);
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.line-numbers {
|
||||
color: var(--warning);
|
||||
}
|
||||
|
||||
.line-count {
|
||||
color: var(--dim);
|
||||
}
|
||||
|
||||
.tool-command {
|
||||
font-weight: bold;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
overflow-wrap: break-word;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.tool-output {
|
||||
margin-top: var(--line-height);
|
||||
color: var(--toolOutput);
|
||||
word-wrap: break-word;
|
||||
overflow-wrap: break-word;
|
||||
word-break: break-word;
|
||||
font-family: inherit;
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.tool-output > div,
|
||||
.output-preview,
|
||||
.output-full {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
line-height: var(--line-height);
|
||||
}
|
||||
|
||||
.tool-output pre {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: inherit;
|
||||
color: inherit;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
overflow-wrap: break-word;
|
||||
}
|
||||
|
||||
.tool-output code {
|
||||
padding: 0;
|
||||
background: none;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.tool-output.expandable {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.tool-output.expandable:hover {
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.tool-output.expandable .output-full {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.tool-output.expandable.expanded .output-preview {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.tool-output.expandable.expanded .output-full {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.ansi-line {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.tool-images {
|
||||
}
|
||||
|
||||
.tool-image {
|
||||
max-width: 100%;
|
||||
max-height: 500px;
|
||||
border-radius: 4px;
|
||||
margin: var(--line-height) 0;
|
||||
}
|
||||
|
||||
.expand-hint {
|
||||
color: var(--toolOutput);
|
||||
}
|
||||
|
||||
/* Diff */
|
||||
.tool-diff {
|
||||
font-size: 11px;
|
||||
overflow-x: auto;
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.diff-added { color: var(--toolDiffAdded); }
|
||||
.diff-removed { color: var(--toolDiffRemoved); }
|
||||
.diff-context { color: var(--toolDiffContext); }
|
||||
|
||||
/* Model change */
|
||||
.model-change {
|
||||
padding: 0 var(--line-height);
|
||||
color: var(--dim);
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.model-name {
|
||||
color: var(--borderAccent);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
/* Compaction / Branch Summary - matches customMessage colors from TUI */
|
||||
.compaction {
|
||||
background: var(--customMessageBg);
|
||||
border-radius: 4px;
|
||||
padding: var(--line-height);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.compaction-label {
|
||||
color: var(--customMessageLabel);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.compaction-collapsed {
|
||||
color: var(--customMessageText);
|
||||
}
|
||||
|
||||
.compaction-content {
|
||||
display: none;
|
||||
color: var(--customMessageText);
|
||||
white-space: pre-wrap;
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.compaction.expanded .compaction-collapsed {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.compaction.expanded .compaction-content {
|
||||
display: block;
|
||||
}
|
||||
|
||||
/* System prompt */
|
||||
.system-prompt {
|
||||
background: var(--customMessageBg);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.system-prompt.expandable {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.system-prompt-header {
|
||||
font-weight: bold;
|
||||
color: var(--customMessageLabel);
|
||||
}
|
||||
|
||||
.system-prompt-preview {
|
||||
color: var(--customMessageText);
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
font-size: 11px;
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.system-prompt-expand-hint {
|
||||
color: var(--muted);
|
||||
font-style: italic;
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
.system-prompt-full {
|
||||
display: none;
|
||||
color: var(--customMessageText);
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
font-size: 11px;
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.system-prompt.expanded .system-prompt-preview,
|
||||
.system-prompt.expanded .system-prompt-expand-hint {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.system-prompt.expanded .system-prompt-full {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.system-prompt.provider-prompt {
|
||||
border-left: 3px solid var(--warning);
|
||||
}
|
||||
|
||||
.system-prompt-note {
|
||||
font-size: 10px;
|
||||
font-style: italic;
|
||||
color: var(--muted);
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
/* Tools list */
|
||||
.tools-list {
|
||||
background: var(--customMessageBg);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.tools-header {
|
||||
font-weight: bold;
|
||||
color: var(--customMessageLabel);
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.tool-item {
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.tool-item-name {
|
||||
font-weight: bold;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.tool-item-desc {
|
||||
color: var(--dim);
|
||||
}
|
||||
|
||||
.tool-params-hint {
|
||||
color: var(--muted);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.tool-item:has(.tool-params-hint) {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.tool-params-hint::after {
|
||||
content: '[click to show parameters]';
|
||||
}
|
||||
|
||||
.tool-item.params-expanded .tool-params-hint::after {
|
||||
content: '[hide parameters]';
|
||||
}
|
||||
|
||||
.tool-params-content {
|
||||
display: none;
|
||||
margin-top: 4px;
|
||||
margin-left: 12px;
|
||||
padding-left: 8px;
|
||||
border-left: 1px solid var(--dim);
|
||||
}
|
||||
|
||||
.tool-item.params-expanded .tool-params-content {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.tool-param {
|
||||
margin-bottom: 4px;
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.tool-param-name {
|
||||
font-weight: bold;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.tool-param-type {
|
||||
color: var(--dim);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.tool-param-required {
|
||||
color: var(--warning, #e8a838);
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
.tool-param-optional {
|
||||
color: var(--dim);
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
.tool-param-desc {
|
||||
color: var(--dim);
|
||||
margin-left: 8px;
|
||||
}
|
||||
|
||||
/* Hook/custom messages */
|
||||
.hook-message {
|
||||
background: var(--customMessageBg);
|
||||
color: var(--customMessageText);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.hook-type {
|
||||
color: var(--customMessageLabel);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
/* Branch summary */
|
||||
.branch-summary {
|
||||
background: var(--customMessageBg);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.branch-summary-header {
|
||||
font-weight: bold;
|
||||
color: var(--borderAccent);
|
||||
}
|
||||
|
||||
/* Error */
|
||||
.error-text {
|
||||
color: var(--error);
|
||||
padding: 0 var(--line-height);
|
||||
}
|
||||
.tool-error {
|
||||
color: var(--error);
|
||||
}
|
||||
|
||||
/* Images */
|
||||
.message-images {
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.message-image {
|
||||
max-width: 100%;
|
||||
max-height: 400px;
|
||||
border-radius: 4px;
|
||||
margin: var(--line-height) 0;
|
||||
}
|
||||
|
||||
/* Markdown content */
|
||||
.markdown-content h1,
|
||||
.markdown-content h2,
|
||||
.markdown-content h3,
|
||||
.markdown-content h4,
|
||||
.markdown-content h5,
|
||||
.markdown-content h6 {
|
||||
color: var(--mdHeading);
|
||||
margin: var(--line-height) 0 0 0;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.markdown-content h1 { font-size: 1em; }
|
||||
.markdown-content h2 { font-size: 1em; }
|
||||
.markdown-content h3 { font-size: 1em; }
|
||||
.markdown-content h4 { font-size: 1em; }
|
||||
.markdown-content h5 { font-size: 1em; }
|
||||
.markdown-content h6 { font-size: 1em; }
|
||||
.markdown-content p { margin: 0; }
|
||||
.markdown-content p + p { margin-top: var(--line-height); }
|
||||
|
||||
.markdown-content a {
|
||||
color: var(--mdLink);
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.markdown-content code {
|
||||
background: rgba(128, 128, 128, 0.2);
|
||||
color: var(--mdCode);
|
||||
padding: 0 4px;
|
||||
border-radius: 3px;
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
.markdown-content pre {
|
||||
background: transparent;
|
||||
margin: var(--line-height) 0;
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.markdown-content pre code {
|
||||
display: block;
|
||||
background: none;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.markdown-content blockquote {
|
||||
border-left: 3px solid var(--mdQuoteBorder);
|
||||
padding-left: var(--line-height);
|
||||
margin: var(--line-height) 0;
|
||||
color: var(--mdQuote);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.markdown-content ul,
|
||||
.markdown-content ol {
|
||||
margin: var(--line-height) 0;
|
||||
padding-left: calc(var(--line-height) * 2);
|
||||
}
|
||||
|
||||
.markdown-content li { margin: 0; }
|
||||
.markdown-content li::marker { color: var(--mdListBullet); }
|
||||
|
||||
.markdown-content hr {
|
||||
border: none;
|
||||
border-top: 1px solid var(--mdHr);
|
||||
margin: var(--line-height) 0;
|
||||
}
|
||||
|
||||
.markdown-content table {
|
||||
border-collapse: collapse;
|
||||
margin: 0.5em 0;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.markdown-content th,
|
||||
.markdown-content td {
|
||||
border: 1px solid var(--mdCodeBlockBorder);
|
||||
padding: 6px 10px;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.markdown-content th {
|
||||
background: rgba(128, 128, 128, 0.1);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.markdown-content img {
|
||||
max-width: 100%;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
/* Syntax highlighting */
|
||||
.hljs { background: transparent; color: var(--text); }
|
||||
.hljs-comment, .hljs-quote { color: var(--syntaxComment); }
|
||||
.hljs-keyword, .hljs-selector-tag { color: var(--syntaxKeyword); }
|
||||
.hljs-number, .hljs-literal { color: var(--syntaxNumber); }
|
||||
.hljs-string, .hljs-doctag { color: var(--syntaxString); }
|
||||
/* Function names: hljs v11 uses .hljs-title.function_ compound class */
|
||||
.hljs-function, .hljs-title, .hljs-title.function_, .hljs-section, .hljs-name { color: var(--syntaxFunction); }
|
||||
/* Types: hljs v11 uses .hljs-title.class_ for class names */
|
||||
.hljs-type, .hljs-class, .hljs-title.class_, .hljs-built_in { color: var(--syntaxType); }
|
||||
.hljs-attr, .hljs-variable, .hljs-variable.language_, .hljs-params, .hljs-property { color: var(--syntaxVariable); }
|
||||
.hljs-meta, .hljs-meta .hljs-keyword, .hljs-meta .hljs-string { color: var(--syntaxKeyword); }
|
||||
.hljs-operator { color: var(--syntaxOperator); }
|
||||
.hljs-punctuation { color: var(--syntaxPunctuation); }
|
||||
.hljs-subst { color: var(--text); }
|
||||
|
||||
/* Footer */
|
||||
.footer {
|
||||
margin-top: 48px;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
color: var(--dim);
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
/* Mobile */
|
||||
#hamburger {
|
||||
display: none;
|
||||
position: fixed;
|
||||
top: 10px;
|
||||
left: 10px;
|
||||
z-index: 100;
|
||||
padding: 3px 8px;
|
||||
font-size: 12px;
|
||||
font-family: inherit;
|
||||
background: transparent;
|
||||
color: var(--muted);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 3px;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
#hamburger:hover {
|
||||
color: var(--text);
|
||||
border-color: var(--text);
|
||||
}
|
||||
|
||||
|
||||
|
||||
#sidebar-overlay {
|
||||
display: none;
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
z-index: 98;
|
||||
}
|
||||
|
||||
@media (max-width: 900px) {
|
||||
#sidebar {
|
||||
position: fixed;
|
||||
left: -400px;
|
||||
width: 400px;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
height: 100vh;
|
||||
z-index: 99;
|
||||
transition: left 0.3s;
|
||||
}
|
||||
|
||||
#sidebar.open {
|
||||
left: 0;
|
||||
}
|
||||
|
||||
#sidebar-overlay.open {
|
||||
display: block;
|
||||
}
|
||||
|
||||
#hamburger {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.sidebar-close {
|
||||
display: block;
|
||||
}
|
||||
|
||||
#content {
|
||||
padding: var(--line-height) 16px;
|
||||
}
|
||||
|
||||
#content > * {
|
||||
max-width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 500px) {
|
||||
#sidebar {
|
||||
width: 100vw;
|
||||
left: -100vw;
|
||||
}
|
||||
}
|
||||
|
||||
@media print {
|
||||
#sidebar, #sidebar-toggle { display: none !important; }
|
||||
body { background: white; color: black; }
|
||||
#content { max-width: none; }
|
||||
}
|
||||
54
packages/coding-agent/src/core/export-html/template.html
Normal file
54
packages/coding-agent/src/core/export-html/template.html
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Session Export</title>
|
||||
<style>
|
||||
{{CSS}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<button id="hamburger" title="Open sidebar"><svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="none"><circle cx="6" cy="6" r="2.5"/><circle cx="6" cy="18" r="2.5"/><circle cx="18" cy="12" r="2.5"/><rect x="5" y="6" width="2" height="12"/><path d="M6 12h10c1 0 2 0 2-2V8"/></svg></button>
|
||||
<div id="sidebar-overlay"></div>
|
||||
<div id="app">
|
||||
<aside id="sidebar">
|
||||
<div class="sidebar-header">
|
||||
<div class="sidebar-controls">
|
||||
<input type="text" class="sidebar-search" id="tree-search" placeholder="Search...">
|
||||
</div>
|
||||
<div class="sidebar-filters">
|
||||
<button class="filter-btn active" data-filter="default" title="Hide settings entries">Default</button>
|
||||
<button class="filter-btn" data-filter="no-tools" title="Default minus tool results">No-tools</button>
|
||||
<button class="filter-btn" data-filter="user-only" title="Only user messages">User</button>
|
||||
<button class="filter-btn" data-filter="labeled-only" title="Only labeled entries">Labeled</button>
|
||||
<button class="filter-btn" data-filter="all" title="Show everything">All</button>
|
||||
<button class="sidebar-close" id="sidebar-close" title="Close">✕</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="tree-container" id="tree-container"></div>
|
||||
<div class="tree-status" id="tree-status"></div>
|
||||
</aside>
|
||||
<main id="content">
|
||||
<div id="header-container"></div>
|
||||
<div id="messages"></div>
|
||||
</main>
|
||||
<div id="image-modal" class="image-modal">
|
||||
<img id="modal-image" src="" alt="">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script id="session-data" type="application/json">{{SESSION_DATA}}</script>
|
||||
|
||||
<!-- Vendored libraries -->
|
||||
<script>{{MARKED_JS}}</script>
|
||||
|
||||
<!-- highlight.js -->
|
||||
<script>{{HIGHLIGHT_JS}}</script>
|
||||
|
||||
<!-- Main application code -->
|
||||
<script>
|
||||
{{JS}}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
1831
packages/coding-agent/src/core/export-html/template.js
Normal file
1831
packages/coding-agent/src/core/export-html/template.js
Normal file
File diff suppressed because it is too large
Load diff
112
packages/coding-agent/src/core/export-html/tool-renderer.ts
Normal file
112
packages/coding-agent/src/core/export-html/tool-renderer.ts
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
/**
|
||||
* Tool HTML renderer for custom tools in HTML export.
|
||||
*
|
||||
* Renders custom tool calls and results to HTML by invoking their TUI renderers
|
||||
* and converting the ANSI output to HTML.
|
||||
*/
|
||||
|
||||
import type { ImageContent, TextContent } from "@mariozechner/pi-ai";
|
||||
import type { Theme } from "../../modes/interactive/theme/theme.js";
|
||||
import type { ToolDefinition } from "../extensions/types.js";
|
||||
import { ansiLinesToHtml } from "./ansi-to-html.js";
|
||||
|
||||
export interface ToolHtmlRendererDeps {
|
||||
/** Function to look up tool definition by name */
|
||||
getToolDefinition: (name: string) => ToolDefinition | undefined;
|
||||
/** Theme for styling */
|
||||
theme: Theme;
|
||||
/** Terminal width for rendering (default: 100) */
|
||||
width?: number;
|
||||
}
|
||||
|
||||
export interface ToolHtmlRenderer {
|
||||
/** Render a tool call to HTML. Returns undefined if tool has no custom renderer. */
|
||||
renderCall(toolName: string, args: unknown): string | undefined;
|
||||
/** Render a tool result to HTML. Returns undefined if tool has no custom renderer. */
|
||||
renderResult(
|
||||
toolName: string,
|
||||
result: Array<{
|
||||
type: string;
|
||||
text?: string;
|
||||
data?: string;
|
||||
mimeType?: string;
|
||||
}>,
|
||||
details: unknown,
|
||||
isError: boolean,
|
||||
): string | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a tool HTML renderer.
|
||||
*
|
||||
* The renderer looks up tool definitions and invokes their renderCall/renderResult
|
||||
* methods, converting the resulting TUI Component output (ANSI) to HTML.
|
||||
*/
|
||||
export function createToolHtmlRenderer(
|
||||
deps: ToolHtmlRendererDeps,
|
||||
): ToolHtmlRenderer {
|
||||
const { getToolDefinition, theme, width = 100 } = deps;
|
||||
|
||||
return {
|
||||
renderCall(toolName: string, args: unknown): string | undefined {
|
||||
try {
|
||||
const toolDef = getToolDefinition(toolName);
|
||||
if (!toolDef?.renderCall) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const component = toolDef.renderCall(args, theme);
|
||||
if (!component) {
|
||||
return undefined;
|
||||
}
|
||||
const lines = component.render(width);
|
||||
return ansiLinesToHtml(lines);
|
||||
} catch {
|
||||
// On error, return undefined to trigger JSON fallback
|
||||
return undefined;
|
||||
}
|
||||
},
|
||||
|
||||
renderResult(
|
||||
toolName: string,
|
||||
result: Array<{
|
||||
type: string;
|
||||
text?: string;
|
||||
data?: string;
|
||||
mimeType?: string;
|
||||
}>,
|
||||
details: unknown,
|
||||
isError: boolean,
|
||||
): string | undefined {
|
||||
try {
|
||||
const toolDef = getToolDefinition(toolName);
|
||||
if (!toolDef?.renderResult) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Build AgentToolResult from content array
|
||||
// Cast content since session storage uses generic object types
|
||||
const agentToolResult = {
|
||||
content: result as (TextContent | ImageContent)[],
|
||||
details,
|
||||
isError,
|
||||
};
|
||||
|
||||
// Always render expanded, client-side will apply truncation
|
||||
const component = toolDef.renderResult(
|
||||
agentToolResult,
|
||||
{ expanded: true, isPartial: false },
|
||||
theme,
|
||||
);
|
||||
if (!component) {
|
||||
return undefined;
|
||||
}
|
||||
const lines = component.render(width);
|
||||
return ansiLinesToHtml(lines);
|
||||
} catch {
|
||||
// On error, return undefined to trigger JSON fallback
|
||||
return undefined;
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
8426
packages/coding-agent/src/core/export-html/vendor/highlight.min.js
vendored
Normal file
8426
packages/coding-agent/src/core/export-html/vendor/highlight.min.js
vendored
Normal file
File diff suppressed because it is too large
Load diff
1998
packages/coding-agent/src/core/export-html/vendor/marked.min.js
vendored
Normal file
1998
packages/coding-agent/src/core/export-html/vendor/marked.min.js
vendored
Normal file
File diff suppressed because it is too large
Load diff
170
packages/coding-agent/src/core/extensions/index.ts
Normal file
170
packages/coding-agent/src/core/extensions/index.ts
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
/**
|
||||
* Extension system for lifecycle events and custom tools.
|
||||
*/
|
||||
|
||||
export type {
|
||||
SlashCommandInfo,
|
||||
SlashCommandLocation,
|
||||
SlashCommandSource,
|
||||
} from "../slash-commands.js";
|
||||
export {
|
||||
createExtensionRuntime,
|
||||
discoverAndLoadExtensions,
|
||||
loadExtensionFromFactory,
|
||||
loadExtensions,
|
||||
} from "./loader.js";
|
||||
export type {
|
||||
ExtensionErrorListener,
|
||||
ForkHandler,
|
||||
NavigateTreeHandler,
|
||||
NewSessionHandler,
|
||||
ShutdownHandler,
|
||||
SwitchSessionHandler,
|
||||
} from "./runner.js";
|
||||
export { ExtensionRunner } from "./runner.js";
|
||||
export type {
|
||||
AgentEndEvent,
|
||||
AgentStartEvent,
|
||||
// Re-exports
|
||||
AgentToolResult,
|
||||
AgentToolUpdateCallback,
|
||||
// App keybindings (for custom editors)
|
||||
AppAction,
|
||||
AppendEntryHandler,
|
||||
// Events - Tool (ToolCallEvent types)
|
||||
BashToolCallEvent,
|
||||
BashToolResultEvent,
|
||||
BeforeAgentStartEvent,
|
||||
BeforeAgentStartEventResult,
|
||||
// Context
|
||||
CompactOptions,
|
||||
// Events - Agent
|
||||
ContextEvent,
|
||||
// Event Results
|
||||
ContextEventResult,
|
||||
ContextUsage,
|
||||
CustomToolCallEvent,
|
||||
CustomToolResultEvent,
|
||||
EditToolCallEvent,
|
||||
EditToolResultEvent,
|
||||
ExecOptions,
|
||||
ExecResult,
|
||||
Extension,
|
||||
ExtensionActions,
|
||||
// API
|
||||
ExtensionAPI,
|
||||
ExtensionCommandContext,
|
||||
ExtensionCommandContextActions,
|
||||
ExtensionContext,
|
||||
ExtensionContextActions,
|
||||
// Errors
|
||||
ExtensionError,
|
||||
ExtensionEvent,
|
||||
ExtensionFactory,
|
||||
ExtensionFlag,
|
||||
ExtensionHandler,
|
||||
// Runtime
|
||||
ExtensionRuntime,
|
||||
ExtensionShortcut,
|
||||
ExtensionUIContext,
|
||||
ExtensionUIDialogOptions,
|
||||
ExtensionWidgetOptions,
|
||||
FindToolCallEvent,
|
||||
FindToolResultEvent,
|
||||
GetActiveToolsHandler,
|
||||
GetAllToolsHandler,
|
||||
GetCommandsHandler,
|
||||
GetThinkingLevelHandler,
|
||||
GrepToolCallEvent,
|
||||
GrepToolResultEvent,
|
||||
// Events - Input
|
||||
InputEvent,
|
||||
InputEventResult,
|
||||
InputSource,
|
||||
KeybindingsManager,
|
||||
LoadExtensionsResult,
|
||||
LsToolCallEvent,
|
||||
LsToolResultEvent,
|
||||
// Events - Message
|
||||
MessageEndEvent,
|
||||
// Message Rendering
|
||||
MessageRenderer,
|
||||
MessageRenderOptions,
|
||||
MessageStartEvent,
|
||||
MessageUpdateEvent,
|
||||
ModelSelectEvent,
|
||||
ModelSelectSource,
|
||||
// Provider Registration
|
||||
ProviderConfig,
|
||||
ProviderModelConfig,
|
||||
ReadToolCallEvent,
|
||||
ReadToolResultEvent,
|
||||
// Commands
|
||||
RegisteredCommand,
|
||||
RegisteredTool,
|
||||
// Events - Resources
|
||||
ResourcesDiscoverEvent,
|
||||
ResourcesDiscoverResult,
|
||||
SendMessageHandler,
|
||||
SendUserMessageHandler,
|
||||
SessionBeforeCompactEvent,
|
||||
SessionBeforeCompactResult,
|
||||
SessionBeforeForkEvent,
|
||||
SessionBeforeForkResult,
|
||||
SessionBeforeSwitchEvent,
|
||||
SessionBeforeSwitchResult,
|
||||
SessionBeforeTreeEvent,
|
||||
SessionBeforeTreeResult,
|
||||
SessionCompactEvent,
|
||||
SessionEvent,
|
||||
SessionForkEvent,
|
||||
SessionShutdownEvent,
|
||||
// Events - Session
|
||||
SessionStartEvent,
|
||||
SessionSwitchEvent,
|
||||
SessionTreeEvent,
|
||||
SetActiveToolsHandler,
|
||||
SetLabelHandler,
|
||||
SetModelHandler,
|
||||
SetThinkingLevelHandler,
|
||||
TerminalInputHandler,
|
||||
// Events - Tool
|
||||
ToolCallEvent,
|
||||
ToolCallEventResult,
|
||||
// Tools
|
||||
ToolDefinition,
|
||||
// Events - Tool Execution
|
||||
ToolExecutionEndEvent,
|
||||
ToolExecutionStartEvent,
|
||||
ToolExecutionUpdateEvent,
|
||||
ToolInfo,
|
||||
ToolRenderResultOptions,
|
||||
ToolResultEvent,
|
||||
ToolResultEventResult,
|
||||
TreePreparation,
|
||||
TurnEndEvent,
|
||||
TurnStartEvent,
|
||||
// Events - User Bash
|
||||
UserBashEvent,
|
||||
UserBashEventResult,
|
||||
WidgetPlacement,
|
||||
WriteToolCallEvent,
|
||||
WriteToolResultEvent,
|
||||
} from "./types.js";
|
||||
// Type guards
|
||||
export {
|
||||
isBashToolResult,
|
||||
isEditToolResult,
|
||||
isFindToolResult,
|
||||
isGrepToolResult,
|
||||
isLsToolResult,
|
||||
isReadToolResult,
|
||||
isToolCallEventType,
|
||||
isWriteToolResult,
|
||||
} from "./types.js";
|
||||
export {
|
||||
wrapRegisteredTool,
|
||||
wrapRegisteredTools,
|
||||
wrapToolsWithExtensions,
|
||||
wrapToolWithExtensions,
|
||||
} from "./wrapper.js";
|
||||
607
packages/coding-agent/src/core/extensions/loader.ts
Normal file
607
packages/coding-agent/src/core/extensions/loader.ts
Normal file
|
|
@ -0,0 +1,607 @@
|
|||
/**
|
||||
* Extension loader - loads TypeScript extension modules using jiti.
|
||||
*
|
||||
* Uses @mariozechner/jiti fork with virtualModules support for compiled Bun binaries.
|
||||
*/
|
||||
|
||||
import * as fs from "node:fs";
|
||||
import { createRequire } from "node:module";
|
||||
import * as os from "node:os";
|
||||
import * as path from "node:path";
|
||||
import { fileURLToPath } from "node:url";
|
||||
import { createJiti } from "@mariozechner/jiti";
|
||||
import * as _bundledPiAgentCore from "@mariozechner/pi-agent-core";
|
||||
import * as _bundledPiAi from "@mariozechner/pi-ai";
|
||||
import * as _bundledPiAiOauth from "@mariozechner/pi-ai/oauth";
|
||||
import type { KeyId } from "@mariozechner/pi-tui";
|
||||
import * as _bundledPiTui from "@mariozechner/pi-tui";
|
||||
// Static imports of packages that extensions may use.
|
||||
// These MUST be static so Bun bundles them into the compiled binary.
|
||||
// The virtualModules option then makes them available to extensions.
|
||||
import * as _bundledTypebox from "@sinclair/typebox";
|
||||
import { getAgentDir, isBunBinary } from "../../config.js";
|
||||
// NOTE: This import works because loader.ts exports are NOT re-exported from index.ts,
|
||||
// avoiding a circular dependency. Extensions can import from @mariozechner/pi-coding-agent.
|
||||
import * as _bundledPiCodingAgent from "../../index.js";
|
||||
import { createEventBus, type EventBus } from "../event-bus.js";
|
||||
import type { ExecOptions } from "../exec.js";
|
||||
import { execCommand } from "../exec.js";
|
||||
import type {
|
||||
Extension,
|
||||
ExtensionAPI,
|
||||
ExtensionFactory,
|
||||
ExtensionRuntime,
|
||||
LoadExtensionsResult,
|
||||
MessageRenderer,
|
||||
ProviderConfig,
|
||||
RegisteredCommand,
|
||||
ToolDefinition,
|
||||
} from "./types.js";
|
||||
|
||||
/** Modules available to extensions via virtualModules (for compiled Bun binary) */
|
||||
const VIRTUAL_MODULES: Record<string, unknown> = {
|
||||
"@sinclair/typebox": _bundledTypebox,
|
||||
"@mariozechner/pi-agent-core": _bundledPiAgentCore,
|
||||
"@mariozechner/pi-tui": _bundledPiTui,
|
||||
"@mariozechner/pi-ai": _bundledPiAi,
|
||||
"@mariozechner/pi-ai/oauth": _bundledPiAiOauth,
|
||||
"@mariozechner/pi-coding-agent": _bundledPiCodingAgent,
|
||||
};
|
||||
|
||||
const require = createRequire(import.meta.url);
|
||||
|
||||
/**
|
||||
* Get aliases for jiti (used in Node.js/development mode).
|
||||
* In Bun binary mode, virtualModules is used instead.
|
||||
*/
|
||||
let _aliases: Record<string, string> | null = null;
|
||||
function getAliases(): Record<string, string> {
|
||||
if (_aliases) return _aliases;
|
||||
|
||||
const __dirname = path.dirname(fileURLToPath(import.meta.url));
|
||||
const packageIndex = path.resolve(__dirname, "../..", "index.js");
|
||||
|
||||
const typeboxEntry = require.resolve("@sinclair/typebox");
|
||||
const typeboxRoot = typeboxEntry.replace(
|
||||
/[\\/]build[\\/]cjs[\\/]index\.js$/,
|
||||
"",
|
||||
);
|
||||
|
||||
const packagesRoot = path.resolve(__dirname, "../../../../");
|
||||
const resolveWorkspaceOrImport = (
|
||||
workspaceRelativePath: string,
|
||||
specifier: string,
|
||||
): string => {
|
||||
const workspacePath = path.join(packagesRoot, workspaceRelativePath);
|
||||
if (fs.existsSync(workspacePath)) {
|
||||
return workspacePath;
|
||||
}
|
||||
return fileURLToPath(import.meta.resolve(specifier));
|
||||
};
|
||||
|
||||
_aliases = {
|
||||
"@mariozechner/pi-coding-agent": packageIndex,
|
||||
"@mariozechner/pi-agent-core": resolveWorkspaceOrImport(
|
||||
"agent/dist/index.js",
|
||||
"@mariozechner/pi-agent-core",
|
||||
),
|
||||
"@mariozechner/pi-tui": resolveWorkspaceOrImport(
|
||||
"tui/dist/index.js",
|
||||
"@mariozechner/pi-tui",
|
||||
),
|
||||
"@mariozechner/pi-ai": resolveWorkspaceOrImport(
|
||||
"ai/dist/index.js",
|
||||
"@mariozechner/pi-ai",
|
||||
),
|
||||
"@mariozechner/pi-ai/oauth": resolveWorkspaceOrImport(
|
||||
"ai/dist/oauth.js",
|
||||
"@mariozechner/pi-ai/oauth",
|
||||
),
|
||||
"@sinclair/typebox": typeboxRoot,
|
||||
};
|
||||
|
||||
return _aliases;
|
||||
}
|
||||
|
||||
const UNICODE_SPACES = /[\u00A0\u2000-\u200A\u202F\u205F\u3000]/g;
|
||||
|
||||
function normalizeUnicodeSpaces(str: string): string {
|
||||
return str.replace(UNICODE_SPACES, " ");
|
||||
}
|
||||
|
||||
function expandPath(p: string): string {
|
||||
const normalized = normalizeUnicodeSpaces(p);
|
||||
if (normalized.startsWith("~/")) {
|
||||
return path.join(os.homedir(), normalized.slice(2));
|
||||
}
|
||||
if (normalized.startsWith("~")) {
|
||||
return path.join(os.homedir(), normalized.slice(1));
|
||||
}
|
||||
return normalized;
|
||||
}
|
||||
|
||||
function resolvePath(extPath: string, cwd: string): string {
|
||||
const expanded = expandPath(extPath);
|
||||
if (path.isAbsolute(expanded)) {
|
||||
return expanded;
|
||||
}
|
||||
return path.resolve(cwd, expanded);
|
||||
}
|
||||
|
||||
type HandlerFn = (...args: unknown[]) => Promise<unknown>;
|
||||
|
||||
/**
|
||||
* Create a runtime with throwing stubs for action methods.
|
||||
* Runner.bindCore() replaces these with real implementations.
|
||||
*/
|
||||
export function createExtensionRuntime(): ExtensionRuntime {
|
||||
const notInitialized = () => {
|
||||
throw new Error(
|
||||
"Extension runtime not initialized. Action methods cannot be called during extension loading.",
|
||||
);
|
||||
};
|
||||
|
||||
const runtime: ExtensionRuntime = {
|
||||
sendMessage: notInitialized,
|
||||
sendUserMessage: notInitialized,
|
||||
appendEntry: notInitialized,
|
||||
setSessionName: notInitialized,
|
||||
getSessionName: notInitialized,
|
||||
setLabel: notInitialized,
|
||||
getActiveTools: notInitialized,
|
||||
getAllTools: notInitialized,
|
||||
setActiveTools: notInitialized,
|
||||
// registerTool() is valid during extension load; refresh is only needed post-bind.
|
||||
refreshTools: () => {},
|
||||
getCommands: notInitialized,
|
||||
setModel: () =>
|
||||
Promise.reject(new Error("Extension runtime not initialized")),
|
||||
getThinkingLevel: notInitialized,
|
||||
setThinkingLevel: notInitialized,
|
||||
flagValues: new Map(),
|
||||
pendingProviderRegistrations: [],
|
||||
// Pre-bind: queue registrations so bindCore() can flush them once the
|
||||
// model registry is available. bindCore() replaces both with direct calls.
|
||||
registerProvider: (name, config) => {
|
||||
runtime.pendingProviderRegistrations.push({ name, config });
|
||||
},
|
||||
unregisterProvider: (name) => {
|
||||
runtime.pendingProviderRegistrations =
|
||||
runtime.pendingProviderRegistrations.filter((r) => r.name !== name);
|
||||
},
|
||||
};
|
||||
|
||||
return runtime;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the ExtensionAPI for an extension.
|
||||
* Registration methods write to the extension object.
|
||||
* Action methods delegate to the shared runtime.
|
||||
*/
|
||||
function createExtensionAPI(
|
||||
extension: Extension,
|
||||
runtime: ExtensionRuntime,
|
||||
cwd: string,
|
||||
eventBus: EventBus,
|
||||
): ExtensionAPI {
|
||||
const api = {
|
||||
// Registration methods - write to extension
|
||||
on(event: string, handler: HandlerFn): void {
|
||||
const list = extension.handlers.get(event) ?? [];
|
||||
list.push(handler);
|
||||
extension.handlers.set(event, list);
|
||||
},
|
||||
|
||||
registerTool(tool: ToolDefinition): void {
|
||||
extension.tools.set(tool.name, {
|
||||
definition: tool,
|
||||
extensionPath: extension.path,
|
||||
});
|
||||
runtime.refreshTools();
|
||||
},
|
||||
|
||||
registerCommand(
|
||||
name: string,
|
||||
options: Omit<RegisteredCommand, "name">,
|
||||
): void {
|
||||
extension.commands.set(name, { name, ...options });
|
||||
},
|
||||
|
||||
registerShortcut(
|
||||
shortcut: KeyId,
|
||||
options: {
|
||||
description?: string;
|
||||
handler: (
|
||||
ctx: import("./types.js").ExtensionContext,
|
||||
) => Promise<void> | void;
|
||||
},
|
||||
): void {
|
||||
extension.shortcuts.set(shortcut, {
|
||||
shortcut,
|
||||
extensionPath: extension.path,
|
||||
...options,
|
||||
});
|
||||
},
|
||||
|
||||
registerFlag(
|
||||
name: string,
|
||||
options: {
|
||||
description?: string;
|
||||
type: "boolean" | "string";
|
||||
default?: boolean | string;
|
||||
},
|
||||
): void {
|
||||
extension.flags.set(name, {
|
||||
name,
|
||||
extensionPath: extension.path,
|
||||
...options,
|
||||
});
|
||||
if (options.default !== undefined && !runtime.flagValues.has(name)) {
|
||||
runtime.flagValues.set(name, options.default);
|
||||
}
|
||||
},
|
||||
|
||||
registerMessageRenderer<T>(
|
||||
customType: string,
|
||||
renderer: MessageRenderer<T>,
|
||||
): void {
|
||||
extension.messageRenderers.set(customType, renderer as MessageRenderer);
|
||||
},
|
||||
|
||||
// Flag access - checks extension registered it, reads from runtime
|
||||
getFlag(name: string): boolean | string | undefined {
|
||||
if (!extension.flags.has(name)) return undefined;
|
||||
return runtime.flagValues.get(name);
|
||||
},
|
||||
|
||||
// Action methods - delegate to shared runtime
|
||||
sendMessage(message, options): void {
|
||||
runtime.sendMessage(message, options);
|
||||
},
|
||||
|
||||
sendUserMessage(content, options): void {
|
||||
runtime.sendUserMessage(content, options);
|
||||
},
|
||||
|
||||
appendEntry(customType: string, data?: unknown): void {
|
||||
runtime.appendEntry(customType, data);
|
||||
},
|
||||
|
||||
setSessionName(name: string): void {
|
||||
runtime.setSessionName(name);
|
||||
},
|
||||
|
||||
getSessionName(): string | undefined {
|
||||
return runtime.getSessionName();
|
||||
},
|
||||
|
||||
setLabel(entryId: string, label: string | undefined): void {
|
||||
runtime.setLabel(entryId, label);
|
||||
},
|
||||
|
||||
exec(command: string, args: string[], options?: ExecOptions) {
|
||||
return execCommand(command, args, options?.cwd ?? cwd, options);
|
||||
},
|
||||
|
||||
getActiveTools(): string[] {
|
||||
return runtime.getActiveTools();
|
||||
},
|
||||
|
||||
getAllTools() {
|
||||
return runtime.getAllTools();
|
||||
},
|
||||
|
||||
setActiveTools(toolNames: string[]): void {
|
||||
runtime.setActiveTools(toolNames);
|
||||
},
|
||||
|
||||
getCommands() {
|
||||
return runtime.getCommands();
|
||||
},
|
||||
|
||||
setModel(model) {
|
||||
return runtime.setModel(model);
|
||||
},
|
||||
|
||||
getThinkingLevel() {
|
||||
return runtime.getThinkingLevel();
|
||||
},
|
||||
|
||||
setThinkingLevel(level) {
|
||||
runtime.setThinkingLevel(level);
|
||||
},
|
||||
|
||||
registerProvider(name: string, config: ProviderConfig) {
|
||||
runtime.registerProvider(name, config);
|
||||
},
|
||||
|
||||
unregisterProvider(name: string) {
|
||||
runtime.unregisterProvider(name);
|
||||
},
|
||||
|
||||
events: eventBus,
|
||||
} as ExtensionAPI;
|
||||
|
||||
return api;
|
||||
}
|
||||
|
||||
async function loadExtensionModule(extensionPath: string) {
|
||||
const jiti = createJiti(import.meta.url, {
|
||||
moduleCache: false,
|
||||
// In Bun binary: use virtualModules for bundled packages (no filesystem resolution)
|
||||
// Also disable tryNative so jiti handles ALL imports (not just the entry point)
|
||||
// In Node.js/dev: use aliases to resolve to node_modules paths
|
||||
...(isBunBinary
|
||||
? { virtualModules: VIRTUAL_MODULES, tryNative: false }
|
||||
: { alias: getAliases() }),
|
||||
});
|
||||
|
||||
const module = await jiti.import(extensionPath, { default: true });
|
||||
const factory = module as ExtensionFactory;
|
||||
return typeof factory !== "function" ? undefined : factory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an Extension object with empty collections.
|
||||
*/
|
||||
function createExtension(
|
||||
extensionPath: string,
|
||||
resolvedPath: string,
|
||||
): Extension {
|
||||
return {
|
||||
path: extensionPath,
|
||||
resolvedPath,
|
||||
handlers: new Map(),
|
||||
tools: new Map(),
|
||||
messageRenderers: new Map(),
|
||||
commands: new Map(),
|
||||
flags: new Map(),
|
||||
shortcuts: new Map(),
|
||||
};
|
||||
}
|
||||
|
||||
async function loadExtension(
|
||||
extensionPath: string,
|
||||
cwd: string,
|
||||
eventBus: EventBus,
|
||||
runtime: ExtensionRuntime,
|
||||
): Promise<{ extension: Extension | null; error: string | null }> {
|
||||
const resolvedPath = resolvePath(extensionPath, cwd);
|
||||
|
||||
try {
|
||||
const factory = await loadExtensionModule(resolvedPath);
|
||||
if (!factory) {
|
||||
return {
|
||||
extension: null,
|
||||
error: `Extension does not export a valid factory function: ${extensionPath}`,
|
||||
};
|
||||
}
|
||||
|
||||
const extension = createExtension(extensionPath, resolvedPath);
|
||||
const api = createExtensionAPI(extension, runtime, cwd, eventBus);
|
||||
await factory(api);
|
||||
|
||||
return { extension, error: null };
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
return { extension: null, error: `Failed to load extension: ${message}` };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an Extension from an inline factory function.
|
||||
*/
|
||||
export async function loadExtensionFromFactory(
|
||||
factory: ExtensionFactory,
|
||||
cwd: string,
|
||||
eventBus: EventBus,
|
||||
runtime: ExtensionRuntime,
|
||||
extensionPath = "<inline>",
|
||||
): Promise<Extension> {
|
||||
const extension = createExtension(extensionPath, extensionPath);
|
||||
const api = createExtensionAPI(extension, runtime, cwd, eventBus);
|
||||
await factory(api);
|
||||
return extension;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load extensions from paths.
|
||||
*/
|
||||
export async function loadExtensions(
|
||||
paths: string[],
|
||||
cwd: string,
|
||||
eventBus?: EventBus,
|
||||
): Promise<LoadExtensionsResult> {
|
||||
const extensions: Extension[] = [];
|
||||
const errors: Array<{ path: string; error: string }> = [];
|
||||
const resolvedEventBus = eventBus ?? createEventBus();
|
||||
const runtime = createExtensionRuntime();
|
||||
|
||||
for (const extPath of paths) {
|
||||
const { extension, error } = await loadExtension(
|
||||
extPath,
|
||||
cwd,
|
||||
resolvedEventBus,
|
||||
runtime,
|
||||
);
|
||||
|
||||
if (error) {
|
||||
errors.push({ path: extPath, error });
|
||||
continue;
|
||||
}
|
||||
|
||||
if (extension) {
|
||||
extensions.push(extension);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
extensions,
|
||||
errors,
|
||||
runtime,
|
||||
};
|
||||
}
|
||||
|
||||
interface PiManifest {
|
||||
extensions?: string[];
|
||||
themes?: string[];
|
||||
skills?: string[];
|
||||
prompts?: string[];
|
||||
}
|
||||
|
||||
function readPiManifest(packageJsonPath: string): PiManifest | null {
|
||||
try {
|
||||
const content = fs.readFileSync(packageJsonPath, "utf-8");
|
||||
const pkg = JSON.parse(content);
|
||||
if (pkg.pi && typeof pkg.pi === "object") {
|
||||
return pkg.pi as PiManifest;
|
||||
}
|
||||
return null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function isExtensionFile(name: string): boolean {
|
||||
return name.endsWith(".ts") || name.endsWith(".js");
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve extension entry points from a directory.
|
||||
*
|
||||
* Checks for:
|
||||
* 1. package.json with "pi.extensions" field -> returns declared paths
|
||||
* 2. index.ts or index.js -> returns the index file
|
||||
*
|
||||
* Returns resolved paths or null if no entry points found.
|
||||
*/
|
||||
function resolveExtensionEntries(dir: string): string[] | null {
|
||||
// Check for package.json with "pi" field first
|
||||
const packageJsonPath = path.join(dir, "package.json");
|
||||
if (fs.existsSync(packageJsonPath)) {
|
||||
const manifest = readPiManifest(packageJsonPath);
|
||||
if (manifest?.extensions?.length) {
|
||||
const entries: string[] = [];
|
||||
for (const extPath of manifest.extensions) {
|
||||
const resolvedExtPath = path.resolve(dir, extPath);
|
||||
if (fs.existsSync(resolvedExtPath)) {
|
||||
entries.push(resolvedExtPath);
|
||||
}
|
||||
}
|
||||
if (entries.length > 0) {
|
||||
return entries;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for index.ts or index.js
|
||||
const indexTs = path.join(dir, "index.ts");
|
||||
const indexJs = path.join(dir, "index.js");
|
||||
if (fs.existsSync(indexTs)) {
|
||||
return [indexTs];
|
||||
}
|
||||
if (fs.existsSync(indexJs)) {
|
||||
return [indexJs];
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover extensions in a directory.
|
||||
*
|
||||
* Discovery rules:
|
||||
* 1. Direct files: `extensions/*.ts` or `*.js` → load
|
||||
* 2. Subdirectory with index: `extensions/* /index.ts` or `index.js` → load
|
||||
* 3. Subdirectory with package.json: `extensions/* /package.json` with "pi" field → load what it declares
|
||||
*
|
||||
* No recursion beyond one level. Complex packages must use package.json manifest.
|
||||
*/
|
||||
function discoverExtensionsInDir(dir: string): string[] {
|
||||
if (!fs.existsSync(dir)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const discovered: string[] = [];
|
||||
|
||||
try {
|
||||
const entries = fs.readdirSync(dir, { withFileTypes: true });
|
||||
|
||||
for (const entry of entries) {
|
||||
const entryPath = path.join(dir, entry.name);
|
||||
|
||||
// 1. Direct files: *.ts or *.js
|
||||
if (
|
||||
(entry.isFile() || entry.isSymbolicLink()) &&
|
||||
isExtensionFile(entry.name)
|
||||
) {
|
||||
discovered.push(entryPath);
|
||||
continue;
|
||||
}
|
||||
|
||||
// 2 & 3. Subdirectories
|
||||
if (entry.isDirectory() || entry.isSymbolicLink()) {
|
||||
const entries = resolveExtensionEntries(entryPath);
|
||||
if (entries) {
|
||||
discovered.push(...entries);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
|
||||
return discovered;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover and load extensions from standard locations.
|
||||
*/
|
||||
export async function discoverAndLoadExtensions(
|
||||
configuredPaths: string[],
|
||||
cwd: string,
|
||||
agentDir: string = getAgentDir(),
|
||||
eventBus?: EventBus,
|
||||
): Promise<LoadExtensionsResult> {
|
||||
const allPaths: string[] = [];
|
||||
const seen = new Set<string>();
|
||||
|
||||
const addPaths = (paths: string[]) => {
|
||||
for (const p of paths) {
|
||||
const resolved = path.resolve(p);
|
||||
if (!seen.has(resolved)) {
|
||||
seen.add(resolved);
|
||||
allPaths.push(p);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// 1. Project-local extensions: cwd/.pi/extensions/
|
||||
const localExtDir = path.join(cwd, ".pi", "extensions");
|
||||
addPaths(discoverExtensionsInDir(localExtDir));
|
||||
|
||||
// 2. Global extensions: agentDir/extensions/
|
||||
const globalExtDir = path.join(agentDir, "extensions");
|
||||
addPaths(discoverExtensionsInDir(globalExtDir));
|
||||
|
||||
// 3. Explicitly configured paths
|
||||
for (const p of configuredPaths) {
|
||||
const resolved = resolvePath(p, cwd);
|
||||
if (fs.existsSync(resolved) && fs.statSync(resolved).isDirectory()) {
|
||||
// Check for package.json with pi manifest or index.ts
|
||||
const entries = resolveExtensionEntries(resolved);
|
||||
if (entries) {
|
||||
addPaths(entries);
|
||||
continue;
|
||||
}
|
||||
// No explicit entries - discover individual files in directory
|
||||
addPaths(discoverExtensionsInDir(resolved));
|
||||
continue;
|
||||
}
|
||||
|
||||
addPaths([resolved]);
|
||||
}
|
||||
|
||||
return loadExtensions(allPaths, cwd, eventBus);
|
||||
}
|
||||
950
packages/coding-agent/src/core/extensions/runner.ts
Normal file
950
packages/coding-agent/src/core/extensions/runner.ts
Normal file
|
|
@ -0,0 +1,950 @@
|
|||
/**
|
||||
* Extension runner - executes extensions and manages their lifecycle.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { ImageContent, Model } from "@mariozechner/pi-ai";
|
||||
import type { KeyId } from "@mariozechner/pi-tui";
|
||||
import { type Theme, theme } from "../../modes/interactive/theme/theme.js";
|
||||
import type { ResourceDiagnostic } from "../diagnostics.js";
|
||||
import type { KeyAction, KeybindingsConfig } from "../keybindings.js";
|
||||
import type { ModelRegistry } from "../model-registry.js";
|
||||
import type { SessionManager } from "../session-manager.js";
|
||||
import type {
|
||||
BeforeAgentStartEvent,
|
||||
BeforeAgentStartEventResult,
|
||||
CompactOptions,
|
||||
ContextEvent,
|
||||
ContextEventResult,
|
||||
ContextUsage,
|
||||
Extension,
|
||||
ExtensionActions,
|
||||
ExtensionCommandContext,
|
||||
ExtensionCommandContextActions,
|
||||
ExtensionContext,
|
||||
ExtensionContextActions,
|
||||
ExtensionError,
|
||||
ExtensionEvent,
|
||||
ExtensionFlag,
|
||||
ExtensionRuntime,
|
||||
ExtensionShortcut,
|
||||
ExtensionUIContext,
|
||||
InputEvent,
|
||||
InputEventResult,
|
||||
InputSource,
|
||||
MessageRenderer,
|
||||
RegisteredCommand,
|
||||
RegisteredTool,
|
||||
ResourcesDiscoverEvent,
|
||||
ResourcesDiscoverResult,
|
||||
SessionBeforeCompactResult,
|
||||
SessionBeforeForkResult,
|
||||
SessionBeforeSwitchResult,
|
||||
SessionBeforeTreeResult,
|
||||
ToolCallEvent,
|
||||
ToolCallEventResult,
|
||||
ToolResultEvent,
|
||||
ToolResultEventResult,
|
||||
UserBashEvent,
|
||||
UserBashEventResult,
|
||||
} from "./types.js";
|
||||
|
||||
// Keybindings for these actions cannot be overridden by extensions
|
||||
const RESERVED_ACTIONS_FOR_EXTENSION_CONFLICTS: ReadonlyArray<KeyAction> = [
|
||||
"interrupt",
|
||||
"clear",
|
||||
"exit",
|
||||
"suspend",
|
||||
"cycleThinkingLevel",
|
||||
"cycleModelForward",
|
||||
"cycleModelBackward",
|
||||
"selectModel",
|
||||
"expandTools",
|
||||
"toggleThinking",
|
||||
"externalEditor",
|
||||
"followUp",
|
||||
"submit",
|
||||
"selectConfirm",
|
||||
"selectCancel",
|
||||
"copy",
|
||||
"deleteToLineEnd",
|
||||
];
|
||||
|
||||
type BuiltInKeyBindings = Partial<
|
||||
Record<KeyId, { action: KeyAction; restrictOverride: boolean }>
|
||||
>;
|
||||
|
||||
const buildBuiltinKeybindings = (
|
||||
effectiveKeybindings: Required<KeybindingsConfig>,
|
||||
): BuiltInKeyBindings => {
|
||||
const builtinKeybindings = {} as BuiltInKeyBindings;
|
||||
for (const [action, keys] of Object.entries(effectiveKeybindings)) {
|
||||
const keyAction = action as KeyAction;
|
||||
const keyList = Array.isArray(keys) ? keys : [keys];
|
||||
const restrictOverride =
|
||||
RESERVED_ACTIONS_FOR_EXTENSION_CONFLICTS.includes(keyAction);
|
||||
for (const key of keyList) {
|
||||
const normalizedKey = key.toLowerCase() as KeyId;
|
||||
builtinKeybindings[normalizedKey] = {
|
||||
action: keyAction,
|
||||
restrictOverride: restrictOverride,
|
||||
};
|
||||
}
|
||||
}
|
||||
return builtinKeybindings;
|
||||
};
|
||||
|
||||
/** Combined result from all before_agent_start handlers */
|
||||
interface BeforeAgentStartCombinedResult {
|
||||
messages?: NonNullable<BeforeAgentStartEventResult["message"]>[];
|
||||
systemPrompt?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Events handled by the generic emit() method.
|
||||
* Events with dedicated emitXxx() methods are excluded for stronger type safety.
|
||||
*/
|
||||
type RunnerEmitEvent = Exclude<
|
||||
ExtensionEvent,
|
||||
| ToolCallEvent
|
||||
| ToolResultEvent
|
||||
| UserBashEvent
|
||||
| ContextEvent
|
||||
| BeforeAgentStartEvent
|
||||
| ResourcesDiscoverEvent
|
||||
| InputEvent
|
||||
>;
|
||||
|
||||
type SessionBeforeEvent = Extract<
|
||||
RunnerEmitEvent,
|
||||
{
|
||||
type:
|
||||
| "session_before_switch"
|
||||
| "session_before_fork"
|
||||
| "session_before_compact"
|
||||
| "session_before_tree";
|
||||
}
|
||||
>;
|
||||
|
||||
type SessionBeforeEventResult =
|
||||
| SessionBeforeSwitchResult
|
||||
| SessionBeforeForkResult
|
||||
| SessionBeforeCompactResult
|
||||
| SessionBeforeTreeResult;
|
||||
|
||||
type RunnerEmitResult<TEvent extends RunnerEmitEvent> = TEvent extends {
|
||||
type: "session_before_switch";
|
||||
}
|
||||
? SessionBeforeSwitchResult | undefined
|
||||
: TEvent extends { type: "session_before_fork" }
|
||||
? SessionBeforeForkResult | undefined
|
||||
: TEvent extends { type: "session_before_compact" }
|
||||
? SessionBeforeCompactResult | undefined
|
||||
: TEvent extends { type: "session_before_tree" }
|
||||
? SessionBeforeTreeResult | undefined
|
||||
: undefined;
|
||||
|
||||
export type ExtensionErrorListener = (error: ExtensionError) => void;
|
||||
|
||||
export type NewSessionHandler = (options?: {
|
||||
parentSession?: string;
|
||||
setup?: (sessionManager: SessionManager) => Promise<void>;
|
||||
}) => Promise<{ cancelled: boolean }>;
|
||||
|
||||
export type ForkHandler = (entryId: string) => Promise<{ cancelled: boolean }>;
|
||||
|
||||
export type NavigateTreeHandler = (
|
||||
targetId: string,
|
||||
options?: {
|
||||
summarize?: boolean;
|
||||
customInstructions?: string;
|
||||
replaceInstructions?: boolean;
|
||||
label?: string;
|
||||
},
|
||||
) => Promise<{ cancelled: boolean }>;
|
||||
|
||||
export type SwitchSessionHandler = (
|
||||
sessionPath: string,
|
||||
) => Promise<{ cancelled: boolean }>;
|
||||
|
||||
export type ReloadHandler = () => Promise<void>;
|
||||
|
||||
export type ShutdownHandler = () => void;
|
||||
|
||||
/**
|
||||
* Helper function to emit session_shutdown event to extensions.
|
||||
* Returns true if the event was emitted, false if there were no handlers.
|
||||
*/
|
||||
export async function emitSessionShutdownEvent(
|
||||
extensionRunner: ExtensionRunner | undefined,
|
||||
): Promise<boolean> {
|
||||
if (extensionRunner?.hasHandlers("session_shutdown")) {
|
||||
await extensionRunner.emit({
|
||||
type: "session_shutdown",
|
||||
});
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const noOpUIContext: ExtensionUIContext = {
|
||||
select: async () => undefined,
|
||||
confirm: async () => false,
|
||||
input: async () => undefined,
|
||||
notify: () => {},
|
||||
onTerminalInput: () => () => {},
|
||||
setStatus: () => {},
|
||||
setWorkingMessage: () => {},
|
||||
setWidget: () => {},
|
||||
setFooter: () => {},
|
||||
setHeader: () => {},
|
||||
setTitle: () => {},
|
||||
custom: async () => undefined as never,
|
||||
pasteToEditor: () => {},
|
||||
setEditorText: () => {},
|
||||
getEditorText: () => "",
|
||||
editor: async () => undefined,
|
||||
setEditorComponent: () => {},
|
||||
get theme() {
|
||||
return theme;
|
||||
},
|
||||
getAllThemes: () => [],
|
||||
getTheme: () => undefined,
|
||||
setTheme: (_theme: string | Theme) => ({
|
||||
success: false,
|
||||
error: "UI not available",
|
||||
}),
|
||||
getToolsExpanded: () => false,
|
||||
setToolsExpanded: () => {},
|
||||
};
|
||||
|
||||
export class ExtensionRunner {
|
||||
private extensions: Extension[];
|
||||
private runtime: ExtensionRuntime;
|
||||
private uiContext: ExtensionUIContext;
|
||||
private cwd: string;
|
||||
private sessionManager: SessionManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private errorListeners: Set<ExtensionErrorListener> = new Set();
|
||||
private getModel: () => Model<any> | undefined = () => undefined;
|
||||
private isIdleFn: () => boolean = () => true;
|
||||
private waitForIdleFn: () => Promise<void> = async () => {};
|
||||
private abortFn: () => void = () => {};
|
||||
private hasPendingMessagesFn: () => boolean = () => false;
|
||||
private getContextUsageFn: () => ContextUsage | undefined = () => undefined;
|
||||
private compactFn: (options?: CompactOptions) => void = () => {};
|
||||
private getSystemPromptFn: () => string = () => "";
|
||||
private newSessionHandler: NewSessionHandler = async () => ({
|
||||
cancelled: false,
|
||||
});
|
||||
private forkHandler: ForkHandler = async () => ({ cancelled: false });
|
||||
private navigateTreeHandler: NavigateTreeHandler = async () => ({
|
||||
cancelled: false,
|
||||
});
|
||||
private switchSessionHandler: SwitchSessionHandler = async () => ({
|
||||
cancelled: false,
|
||||
});
|
||||
private reloadHandler: ReloadHandler = async () => {};
|
||||
private shutdownHandler: ShutdownHandler = () => {};
|
||||
private shortcutDiagnostics: ResourceDiagnostic[] = [];
|
||||
private commandDiagnostics: ResourceDiagnostic[] = [];
|
||||
|
||||
constructor(
|
||||
extensions: Extension[],
|
||||
runtime: ExtensionRuntime,
|
||||
cwd: string,
|
||||
sessionManager: SessionManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
) {
|
||||
this.extensions = extensions;
|
||||
this.runtime = runtime;
|
||||
this.uiContext = noOpUIContext;
|
||||
this.cwd = cwd;
|
||||
this.sessionManager = sessionManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
}
|
||||
|
||||
bindCore(
|
||||
actions: ExtensionActions,
|
||||
contextActions: ExtensionContextActions,
|
||||
): void {
|
||||
// Copy actions into the shared runtime (all extension APIs reference this)
|
||||
this.runtime.sendMessage = actions.sendMessage;
|
||||
this.runtime.sendUserMessage = actions.sendUserMessage;
|
||||
this.runtime.appendEntry = actions.appendEntry;
|
||||
this.runtime.setSessionName = actions.setSessionName;
|
||||
this.runtime.getSessionName = actions.getSessionName;
|
||||
this.runtime.setLabel = actions.setLabel;
|
||||
this.runtime.getActiveTools = actions.getActiveTools;
|
||||
this.runtime.getAllTools = actions.getAllTools;
|
||||
this.runtime.setActiveTools = actions.setActiveTools;
|
||||
this.runtime.refreshTools = actions.refreshTools;
|
||||
this.runtime.getCommands = actions.getCommands;
|
||||
this.runtime.setModel = actions.setModel;
|
||||
this.runtime.getThinkingLevel = actions.getThinkingLevel;
|
||||
this.runtime.setThinkingLevel = actions.setThinkingLevel;
|
||||
|
||||
// Context actions (required)
|
||||
this.getModel = contextActions.getModel;
|
||||
this.isIdleFn = contextActions.isIdle;
|
||||
this.abortFn = contextActions.abort;
|
||||
this.hasPendingMessagesFn = contextActions.hasPendingMessages;
|
||||
this.shutdownHandler = contextActions.shutdown;
|
||||
this.getContextUsageFn = contextActions.getContextUsage;
|
||||
this.compactFn = contextActions.compact;
|
||||
this.getSystemPromptFn = contextActions.getSystemPrompt;
|
||||
|
||||
// Flush provider registrations queued during extension loading
|
||||
for (const { name, config } of this.runtime.pendingProviderRegistrations) {
|
||||
this.modelRegistry.registerProvider(name, config);
|
||||
}
|
||||
this.runtime.pendingProviderRegistrations = [];
|
||||
|
||||
// From this point on, provider registration/unregistration takes effect immediately
|
||||
// without requiring a /reload.
|
||||
this.runtime.registerProvider = (name, config) =>
|
||||
this.modelRegistry.registerProvider(name, config);
|
||||
this.runtime.unregisterProvider = (name) =>
|
||||
this.modelRegistry.unregisterProvider(name);
|
||||
}
|
||||
|
||||
bindCommandContext(actions?: ExtensionCommandContextActions): void {
|
||||
if (actions) {
|
||||
this.waitForIdleFn = actions.waitForIdle;
|
||||
this.newSessionHandler = actions.newSession;
|
||||
this.forkHandler = actions.fork;
|
||||
this.navigateTreeHandler = actions.navigateTree;
|
||||
this.switchSessionHandler = actions.switchSession;
|
||||
this.reloadHandler = actions.reload;
|
||||
return;
|
||||
}
|
||||
|
||||
this.waitForIdleFn = async () => {};
|
||||
this.newSessionHandler = async () => ({ cancelled: false });
|
||||
this.forkHandler = async () => ({ cancelled: false });
|
||||
this.navigateTreeHandler = async () => ({ cancelled: false });
|
||||
this.switchSessionHandler = async () => ({ cancelled: false });
|
||||
this.reloadHandler = async () => {};
|
||||
}
|
||||
|
||||
setUIContext(uiContext?: ExtensionUIContext): void {
|
||||
this.uiContext = uiContext ?? noOpUIContext;
|
||||
}
|
||||
|
||||
getUIContext(): ExtensionUIContext {
|
||||
return this.uiContext;
|
||||
}
|
||||
|
||||
hasUI(): boolean {
|
||||
return this.uiContext !== noOpUIContext;
|
||||
}
|
||||
|
||||
getExtensionPaths(): string[] {
|
||||
return this.extensions.map((e) => e.path);
|
||||
}
|
||||
|
||||
/** Get all registered tools from all extensions (first registration per name wins). */
|
||||
getAllRegisteredTools(): RegisteredTool[] {
|
||||
const toolsByName = new Map<string, RegisteredTool>();
|
||||
for (const ext of this.extensions) {
|
||||
for (const tool of ext.tools.values()) {
|
||||
if (!toolsByName.has(tool.definition.name)) {
|
||||
toolsByName.set(tool.definition.name, tool);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Array.from(toolsByName.values());
|
||||
}
|
||||
|
||||
/** Get a tool definition by name. Returns undefined if not found. */
|
||||
getToolDefinition(
|
||||
toolName: string,
|
||||
): RegisteredTool["definition"] | undefined {
|
||||
for (const ext of this.extensions) {
|
||||
const tool = ext.tools.get(toolName);
|
||||
if (tool) {
|
||||
return tool.definition;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
getFlags(): Map<string, ExtensionFlag> {
|
||||
const allFlags = new Map<string, ExtensionFlag>();
|
||||
for (const ext of this.extensions) {
|
||||
for (const [name, flag] of ext.flags) {
|
||||
if (!allFlags.has(name)) {
|
||||
allFlags.set(name, flag);
|
||||
}
|
||||
}
|
||||
}
|
||||
return allFlags;
|
||||
}
|
||||
|
||||
setFlagValue(name: string, value: boolean | string): void {
|
||||
this.runtime.flagValues.set(name, value);
|
||||
}
|
||||
|
||||
getFlagValues(): Map<string, boolean | string> {
|
||||
return new Map(this.runtime.flagValues);
|
||||
}
|
||||
|
||||
getShortcuts(
|
||||
effectiveKeybindings: Required<KeybindingsConfig>,
|
||||
): Map<KeyId, ExtensionShortcut> {
|
||||
this.shortcutDiagnostics = [];
|
||||
const builtinKeybindings = buildBuiltinKeybindings(effectiveKeybindings);
|
||||
const extensionShortcuts = new Map<KeyId, ExtensionShortcut>();
|
||||
|
||||
const addDiagnostic = (message: string, extensionPath: string) => {
|
||||
this.shortcutDiagnostics.push({
|
||||
type: "warning",
|
||||
message,
|
||||
path: extensionPath,
|
||||
});
|
||||
if (!this.hasUI()) {
|
||||
console.warn(message);
|
||||
}
|
||||
};
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
for (const [key, shortcut] of ext.shortcuts) {
|
||||
const normalizedKey = key.toLowerCase() as KeyId;
|
||||
|
||||
const builtInKeybinding = builtinKeybindings[normalizedKey];
|
||||
if (builtInKeybinding?.restrictOverride === true) {
|
||||
addDiagnostic(
|
||||
`Extension shortcut '${key}' from ${shortcut.extensionPath} conflicts with built-in shortcut. Skipping.`,
|
||||
shortcut.extensionPath,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (builtInKeybinding?.restrictOverride === false) {
|
||||
addDiagnostic(
|
||||
`Extension shortcut conflict: '${key}' is built-in shortcut for ${builtInKeybinding.action} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
|
||||
shortcut.extensionPath,
|
||||
);
|
||||
}
|
||||
|
||||
const existingExtensionShortcut = extensionShortcuts.get(normalizedKey);
|
||||
if (existingExtensionShortcut) {
|
||||
addDiagnostic(
|
||||
`Extension shortcut conflict: '${key}' registered by both ${existingExtensionShortcut.extensionPath} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
|
||||
shortcut.extensionPath,
|
||||
);
|
||||
}
|
||||
extensionShortcuts.set(normalizedKey, shortcut);
|
||||
}
|
||||
}
|
||||
return extensionShortcuts;
|
||||
}
|
||||
|
||||
getShortcutDiagnostics(): ResourceDiagnostic[] {
|
||||
return this.shortcutDiagnostics;
|
||||
}
|
||||
|
||||
onError(listener: ExtensionErrorListener): () => void {
|
||||
this.errorListeners.add(listener);
|
||||
return () => this.errorListeners.delete(listener);
|
||||
}
|
||||
|
||||
emitError(error: ExtensionError): void {
|
||||
for (const listener of this.errorListeners) {
|
||||
listener(error);
|
||||
}
|
||||
}
|
||||
|
||||
hasHandlers(eventType: string): boolean {
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get(eventType);
|
||||
if (handlers && handlers.length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
getMessageRenderer(customType: string): MessageRenderer | undefined {
|
||||
for (const ext of this.extensions) {
|
||||
const renderer = ext.messageRenderers.get(customType);
|
||||
if (renderer) {
|
||||
return renderer;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
getRegisteredCommands(reserved?: Set<string>): RegisteredCommand[] {
|
||||
this.commandDiagnostics = [];
|
||||
|
||||
const commands: RegisteredCommand[] = [];
|
||||
const commandOwners = new Map<string, string>();
|
||||
for (const ext of this.extensions) {
|
||||
for (const command of ext.commands.values()) {
|
||||
if (reserved?.has(command.name)) {
|
||||
const message = `Extension command '${command.name}' from ${ext.path} conflicts with built-in commands. Skipping.`;
|
||||
this.commandDiagnostics.push({
|
||||
type: "warning",
|
||||
message,
|
||||
path: ext.path,
|
||||
});
|
||||
if (!this.hasUI()) {
|
||||
console.warn(message);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const existingOwner = commandOwners.get(command.name);
|
||||
if (existingOwner) {
|
||||
const message = `Extension command '${command.name}' from ${ext.path} conflicts with ${existingOwner}. Skipping.`;
|
||||
this.commandDiagnostics.push({
|
||||
type: "warning",
|
||||
message,
|
||||
path: ext.path,
|
||||
});
|
||||
if (!this.hasUI()) {
|
||||
console.warn(message);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
commandOwners.set(command.name, ext.path);
|
||||
commands.push(command);
|
||||
}
|
||||
}
|
||||
return commands;
|
||||
}
|
||||
|
||||
getCommandDiagnostics(): ResourceDiagnostic[] {
|
||||
return this.commandDiagnostics;
|
||||
}
|
||||
|
||||
getRegisteredCommandsWithPaths(): Array<{
|
||||
command: RegisteredCommand;
|
||||
extensionPath: string;
|
||||
}> {
|
||||
const result: Array<{ command: RegisteredCommand; extensionPath: string }> =
|
||||
[];
|
||||
for (const ext of this.extensions) {
|
||||
for (const command of ext.commands.values()) {
|
||||
result.push({ command, extensionPath: ext.path });
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
getCommand(name: string): RegisteredCommand | undefined {
|
||||
for (const ext of this.extensions) {
|
||||
const command = ext.commands.get(name);
|
||||
if (command) {
|
||||
return command;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Request a graceful shutdown. Called by extension tools and event handlers.
|
||||
* The actual shutdown behavior is provided by the mode via bindExtensions().
|
||||
*/
|
||||
shutdown(): void {
|
||||
this.shutdownHandler();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an ExtensionContext for use in event handlers and tool execution.
|
||||
* Context values are resolved at call time, so changes via bindCore/bindUI are reflected.
|
||||
*/
|
||||
createContext(): ExtensionContext {
|
||||
const getModel = this.getModel;
|
||||
return {
|
||||
ui: this.uiContext,
|
||||
hasUI: this.hasUI(),
|
||||
cwd: this.cwd,
|
||||
sessionManager: this.sessionManager,
|
||||
modelRegistry: this.modelRegistry,
|
||||
get model() {
|
||||
return getModel();
|
||||
},
|
||||
isIdle: () => this.isIdleFn(),
|
||||
abort: () => this.abortFn(),
|
||||
hasPendingMessages: () => this.hasPendingMessagesFn(),
|
||||
shutdown: () => this.shutdownHandler(),
|
||||
getContextUsage: () => this.getContextUsageFn(),
|
||||
compact: (options) => this.compactFn(options),
|
||||
getSystemPrompt: () => this.getSystemPromptFn(),
|
||||
};
|
||||
}
|
||||
|
||||
createCommandContext(): ExtensionCommandContext {
|
||||
return {
|
||||
...this.createContext(),
|
||||
waitForIdle: () => this.waitForIdleFn(),
|
||||
newSession: (options) => this.newSessionHandler(options),
|
||||
fork: (entryId) => this.forkHandler(entryId),
|
||||
navigateTree: (targetId, options) =>
|
||||
this.navigateTreeHandler(targetId, options),
|
||||
switchSession: (sessionPath) => this.switchSessionHandler(sessionPath),
|
||||
reload: () => this.reloadHandler(),
|
||||
};
|
||||
}
|
||||
|
||||
private isSessionBeforeEvent(
|
||||
event: RunnerEmitEvent,
|
||||
): event is SessionBeforeEvent {
|
||||
return (
|
||||
event.type === "session_before_switch" ||
|
||||
event.type === "session_before_fork" ||
|
||||
event.type === "session_before_compact" ||
|
||||
event.type === "session_before_tree"
|
||||
);
|
||||
}
|
||||
|
||||
async emit<TEvent extends RunnerEmitEvent>(
|
||||
event: TEvent,
|
||||
): Promise<RunnerEmitResult<TEvent>> {
|
||||
const ctx = this.createContext();
|
||||
let result: SessionBeforeEventResult | undefined;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get(event.type);
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const handlerResult = await handler(event, ctx);
|
||||
|
||||
if (this.isSessionBeforeEvent(event) && handlerResult) {
|
||||
result = handlerResult as SessionBeforeEventResult;
|
||||
if (result.cancel) {
|
||||
return result as RunnerEmitResult<TEvent>;
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: event.type,
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result as RunnerEmitResult<TEvent>;
|
||||
}
|
||||
|
||||
async emitToolResult(
|
||||
event: ToolResultEvent,
|
||||
): Promise<ToolResultEventResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
const currentEvent: ToolResultEvent = { ...event };
|
||||
let modified = false;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("tool_result");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const handlerResult = (await handler(currentEvent, ctx)) as
|
||||
| ToolResultEventResult
|
||||
| undefined;
|
||||
if (!handlerResult) continue;
|
||||
|
||||
if (handlerResult.content !== undefined) {
|
||||
currentEvent.content = handlerResult.content;
|
||||
modified = true;
|
||||
}
|
||||
if (handlerResult.details !== undefined) {
|
||||
currentEvent.details = handlerResult.details;
|
||||
modified = true;
|
||||
}
|
||||
if (handlerResult.isError !== undefined) {
|
||||
currentEvent.isError = handlerResult.isError;
|
||||
modified = true;
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "tool_result",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!modified) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
content: currentEvent.content,
|
||||
details: currentEvent.details,
|
||||
isError: currentEvent.isError,
|
||||
};
|
||||
}
|
||||
|
||||
async emitToolCall(
|
||||
event: ToolCallEvent,
|
||||
): Promise<ToolCallEventResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
let result: ToolCallEventResult | undefined;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("tool_call");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
const handlerResult = await handler(event, ctx);
|
||||
|
||||
if (handlerResult) {
|
||||
result = handlerResult as ToolCallEventResult;
|
||||
if (result.block) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
async emitUserBash(
|
||||
event: UserBashEvent,
|
||||
): Promise<UserBashEventResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("user_bash");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const handlerResult = await handler(event, ctx);
|
||||
if (handlerResult) {
|
||||
return handlerResult as UserBashEventResult;
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "user_bash",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async emitContext(messages: AgentMessage[]): Promise<AgentMessage[]> {
|
||||
const ctx = this.createContext();
|
||||
let currentMessages = structuredClone(messages);
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("context");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const event: ContextEvent = {
|
||||
type: "context",
|
||||
messages: currentMessages,
|
||||
};
|
||||
const handlerResult = await handler(event, ctx);
|
||||
|
||||
if (handlerResult && (handlerResult as ContextEventResult).messages) {
|
||||
currentMessages = (handlerResult as ContextEventResult).messages!;
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "context",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentMessages;
|
||||
}
|
||||
|
||||
async emitBeforeAgentStart(
|
||||
prompt: string,
|
||||
images: ImageContent[] | undefined,
|
||||
systemPrompt: string,
|
||||
): Promise<BeforeAgentStartCombinedResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
const messages: NonNullable<BeforeAgentStartEventResult["message"]>[] = [];
|
||||
let currentSystemPrompt = systemPrompt;
|
||||
let systemPromptModified = false;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("before_agent_start");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const event: BeforeAgentStartEvent = {
|
||||
type: "before_agent_start",
|
||||
prompt,
|
||||
images,
|
||||
systemPrompt: currentSystemPrompt,
|
||||
};
|
||||
const handlerResult = await handler(event, ctx);
|
||||
|
||||
if (handlerResult) {
|
||||
const result = handlerResult as BeforeAgentStartEventResult;
|
||||
if (result.message) {
|
||||
messages.push(result.message);
|
||||
}
|
||||
if (result.systemPrompt !== undefined) {
|
||||
currentSystemPrompt = result.systemPrompt;
|
||||
systemPromptModified = true;
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "before_agent_start",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (messages.length > 0 || systemPromptModified) {
|
||||
return {
|
||||
messages: messages.length > 0 ? messages : undefined,
|
||||
systemPrompt: systemPromptModified ? currentSystemPrompt : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async emitResourcesDiscover(
|
||||
cwd: string,
|
||||
reason: ResourcesDiscoverEvent["reason"],
|
||||
): Promise<{
|
||||
skillPaths: Array<{ path: string; extensionPath: string }>;
|
||||
promptPaths: Array<{ path: string; extensionPath: string }>;
|
||||
themePaths: Array<{ path: string; extensionPath: string }>;
|
||||
}> {
|
||||
const ctx = this.createContext();
|
||||
const skillPaths: Array<{ path: string; extensionPath: string }> = [];
|
||||
const promptPaths: Array<{ path: string; extensionPath: string }> = [];
|
||||
const themePaths: Array<{ path: string; extensionPath: string }> = [];
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("resources_discover");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const event: ResourcesDiscoverEvent = {
|
||||
type: "resources_discover",
|
||||
cwd,
|
||||
reason,
|
||||
};
|
||||
const handlerResult = await handler(event, ctx);
|
||||
const result = handlerResult as ResourcesDiscoverResult | undefined;
|
||||
|
||||
if (result?.skillPaths?.length) {
|
||||
skillPaths.push(
|
||||
...result.skillPaths.map((path) => ({
|
||||
path,
|
||||
extensionPath: ext.path,
|
||||
})),
|
||||
);
|
||||
}
|
||||
if (result?.promptPaths?.length) {
|
||||
promptPaths.push(
|
||||
...result.promptPaths.map((path) => ({
|
||||
path,
|
||||
extensionPath: ext.path,
|
||||
})),
|
||||
);
|
||||
}
|
||||
if (result?.themePaths?.length) {
|
||||
themePaths.push(
|
||||
...result.themePaths.map((path) => ({
|
||||
path,
|
||||
extensionPath: ext.path,
|
||||
})),
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "resources_discover",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { skillPaths, promptPaths, themePaths };
|
||||
}
|
||||
|
||||
/** Emit input event. Transforms chain, "handled" short-circuits. */
|
||||
async emitInput(
|
||||
text: string,
|
||||
images: ImageContent[] | undefined,
|
||||
source: InputSource,
|
||||
): Promise<InputEventResult> {
|
||||
const ctx = this.createContext();
|
||||
let currentText = text;
|
||||
let currentImages = images;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
for (const handler of ext.handlers.get("input") ?? []) {
|
||||
try {
|
||||
const event: InputEvent = {
|
||||
type: "input",
|
||||
text: currentText,
|
||||
images: currentImages,
|
||||
source,
|
||||
};
|
||||
const result = (await handler(event, ctx)) as
|
||||
| InputEventResult
|
||||
| undefined;
|
||||
if (result?.action === "handled") return result;
|
||||
if (result?.action === "transform") {
|
||||
currentText = result.text;
|
||||
currentImages = result.images ?? currentImages;
|
||||
}
|
||||
} catch (err) {
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "input",
|
||||
error: err instanceof Error ? err.message : String(err),
|
||||
stack: err instanceof Error ? err.stack : undefined,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
return currentText !== text || currentImages !== images
|
||||
? { action: "transform", text: currentText, images: currentImages }
|
||||
: { action: "continue" };
|
||||
}
|
||||
}
|
||||
1575
packages/coding-agent/src/core/extensions/types.ts
Normal file
1575
packages/coding-agent/src/core/extensions/types.ts
Normal file
File diff suppressed because it is too large
Load diff
147
packages/coding-agent/src/core/extensions/wrapper.ts
Normal file
147
packages/coding-agent/src/core/extensions/wrapper.ts
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
/**
|
||||
* Tool wrappers for extensions.
|
||||
*/
|
||||
|
||||
import type {
|
||||
AgentTool,
|
||||
AgentToolUpdateCallback,
|
||||
} from "@mariozechner/pi-agent-core";
|
||||
import type { ExtensionRunner } from "./runner.js";
|
||||
import type { RegisteredTool, ToolCallEventResult } from "./types.js";
|
||||
|
||||
/**
|
||||
* Wrap a RegisteredTool into an AgentTool.
|
||||
* Uses the runner's createContext() for consistent context across tools and event handlers.
|
||||
*/
|
||||
export function wrapRegisteredTool(
|
||||
registeredTool: RegisteredTool,
|
||||
runner: ExtensionRunner,
|
||||
): AgentTool {
|
||||
const { definition } = registeredTool;
|
||||
return {
|
||||
name: definition.name,
|
||||
label: definition.label,
|
||||
description: definition.description,
|
||||
parameters: definition.parameters,
|
||||
execute: (toolCallId, params, signal, onUpdate) =>
|
||||
definition.execute(
|
||||
toolCallId,
|
||||
params,
|
||||
signal,
|
||||
onUpdate,
|
||||
runner.createContext(),
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap all registered tools into AgentTools.
|
||||
* Uses the runner's createContext() for consistent context across tools and event handlers.
|
||||
*/
|
||||
export function wrapRegisteredTools(
|
||||
registeredTools: RegisteredTool[],
|
||||
runner: ExtensionRunner,
|
||||
): AgentTool[] {
|
||||
return registeredTools.map((rt) => wrapRegisteredTool(rt, runner));
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap a tool with extension callbacks for interception.
|
||||
* - Emits tool_call event before execution (can block)
|
||||
* - Emits tool_result event after execution (can modify result)
|
||||
*/
|
||||
export function wrapToolWithExtensions<T>(
|
||||
tool: AgentTool<any, T>,
|
||||
runner: ExtensionRunner,
|
||||
): AgentTool<any, T> {
|
||||
return {
|
||||
...tool,
|
||||
execute: async (
|
||||
toolCallId: string,
|
||||
params: Record<string, unknown>,
|
||||
signal?: AbortSignal,
|
||||
onUpdate?: AgentToolUpdateCallback<T>,
|
||||
) => {
|
||||
// Emit tool_call event - extensions can block execution
|
||||
if (runner.hasHandlers("tool_call")) {
|
||||
try {
|
||||
const callResult = (await runner.emitToolCall({
|
||||
type: "tool_call",
|
||||
toolName: tool.name,
|
||||
toolCallId,
|
||||
input: params,
|
||||
})) as ToolCallEventResult | undefined;
|
||||
|
||||
if (callResult?.block) {
|
||||
const reason =
|
||||
callResult.reason || "Tool execution was blocked by an extension";
|
||||
throw new Error(reason);
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof Error) {
|
||||
throw err;
|
||||
}
|
||||
throw new Error(
|
||||
`Extension failed, blocking execution: ${String(err)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the actual tool
|
||||
try {
|
||||
const result = await tool.execute(toolCallId, params, signal, onUpdate);
|
||||
|
||||
// Emit tool_result event - extensions can modify the result
|
||||
if (runner.hasHandlers("tool_result")) {
|
||||
const resultResult = await runner.emitToolResult({
|
||||
type: "tool_result",
|
||||
toolName: tool.name,
|
||||
toolCallId,
|
||||
input: params,
|
||||
content: result.content,
|
||||
details: result.details,
|
||||
isError: false,
|
||||
});
|
||||
|
||||
if (resultResult) {
|
||||
return {
|
||||
content: resultResult.content ?? result.content,
|
||||
details: (resultResult.details ?? result.details) as T,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (err) {
|
||||
// Emit tool_result event for errors
|
||||
if (runner.hasHandlers("tool_result")) {
|
||||
await runner.emitToolResult({
|
||||
type: "tool_result",
|
||||
toolName: tool.name,
|
||||
toolCallId,
|
||||
input: params,
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: err instanceof Error ? err.message : String(err),
|
||||
},
|
||||
],
|
||||
details: undefined,
|
||||
isError: true,
|
||||
});
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap all tools with extension callbacks.
|
||||
*/
|
||||
export function wrapToolsWithExtensions<T>(
|
||||
tools: AgentTool<any, T>[],
|
||||
runner: ExtensionRunner,
|
||||
): AgentTool<any, T>[] {
|
||||
return tools.map((tool) => wrapToolWithExtensions(tool, runner));
|
||||
}
|
||||
149
packages/coding-agent/src/core/footer-data-provider.ts
Normal file
149
packages/coding-agent/src/core/footer-data-provider.ts
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
import { existsSync, type FSWatcher, readFileSync, statSync, watch } from "fs";
|
||||
import { dirname, join, resolve } from "path";
|
||||
|
||||
/**
|
||||
* Find the git HEAD path by walking up from cwd.
|
||||
* Handles both regular git repos (.git is a directory) and worktrees (.git is a file).
|
||||
*/
|
||||
function findGitHeadPath(): string | null {
|
||||
let dir = process.cwd();
|
||||
while (true) {
|
||||
const gitPath = join(dir, ".git");
|
||||
if (existsSync(gitPath)) {
|
||||
try {
|
||||
const stat = statSync(gitPath);
|
||||
if (stat.isFile()) {
|
||||
const content = readFileSync(gitPath, "utf8").trim();
|
||||
if (content.startsWith("gitdir: ")) {
|
||||
const gitDir = content.slice(8);
|
||||
const headPath = resolve(dir, gitDir, "HEAD");
|
||||
if (existsSync(headPath)) return headPath;
|
||||
}
|
||||
} else if (stat.isDirectory()) {
|
||||
const headPath = join(gitPath, "HEAD");
|
||||
if (existsSync(headPath)) return headPath;
|
||||
}
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
const parent = dirname(dir);
|
||||
if (parent === dir) return null;
|
||||
dir = parent;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Provides git branch and extension statuses - data not otherwise accessible to extensions.
|
||||
* Token stats, model info available via ctx.sessionManager and ctx.model.
|
||||
*/
|
||||
export class FooterDataProvider {
|
||||
private extensionStatuses = new Map<string, string>();
|
||||
private cachedBranch: string | null | undefined = undefined;
|
||||
private gitWatcher: FSWatcher | null = null;
|
||||
private branchChangeCallbacks = new Set<() => void>();
|
||||
private availableProviderCount = 0;
|
||||
|
||||
constructor() {
|
||||
this.setupGitWatcher();
|
||||
}
|
||||
|
||||
/** Current git branch, null if not in repo, "detached" if detached HEAD */
|
||||
getGitBranch(): string | null {
|
||||
if (this.cachedBranch !== undefined) return this.cachedBranch;
|
||||
|
||||
try {
|
||||
const gitHeadPath = findGitHeadPath();
|
||||
if (!gitHeadPath) {
|
||||
this.cachedBranch = null;
|
||||
return null;
|
||||
}
|
||||
const content = readFileSync(gitHeadPath, "utf8").trim();
|
||||
this.cachedBranch = content.startsWith("ref: refs/heads/")
|
||||
? content.slice(16)
|
||||
: "detached";
|
||||
} catch {
|
||||
this.cachedBranch = null;
|
||||
}
|
||||
return this.cachedBranch;
|
||||
}
|
||||
|
||||
/** Extension status texts set via ctx.ui.setStatus() */
|
||||
getExtensionStatuses(): ReadonlyMap<string, string> {
|
||||
return this.extensionStatuses;
|
||||
}
|
||||
|
||||
/** Subscribe to git branch changes. Returns unsubscribe function. */
|
||||
onBranchChange(callback: () => void): () => void {
|
||||
this.branchChangeCallbacks.add(callback);
|
||||
return () => this.branchChangeCallbacks.delete(callback);
|
||||
}
|
||||
|
||||
/** Internal: set extension status */
|
||||
setExtensionStatus(key: string, text: string | undefined): void {
|
||||
if (text === undefined) {
|
||||
this.extensionStatuses.delete(key);
|
||||
} else {
|
||||
this.extensionStatuses.set(key, text);
|
||||
}
|
||||
}
|
||||
|
||||
/** Internal: clear extension statuses */
|
||||
clearExtensionStatuses(): void {
|
||||
this.extensionStatuses.clear();
|
||||
}
|
||||
|
||||
/** Number of unique providers with available models (for footer display) */
|
||||
getAvailableProviderCount(): number {
|
||||
return this.availableProviderCount;
|
||||
}
|
||||
|
||||
/** Internal: update available provider count */
|
||||
setAvailableProviderCount(count: number): void {
|
||||
this.availableProviderCount = count;
|
||||
}
|
||||
|
||||
/** Internal: cleanup */
|
||||
dispose(): void {
|
||||
if (this.gitWatcher) {
|
||||
this.gitWatcher.close();
|
||||
this.gitWatcher = null;
|
||||
}
|
||||
this.branchChangeCallbacks.clear();
|
||||
}
|
||||
|
||||
private setupGitWatcher(): void {
|
||||
if (this.gitWatcher) {
|
||||
this.gitWatcher.close();
|
||||
this.gitWatcher = null;
|
||||
}
|
||||
|
||||
const gitHeadPath = findGitHeadPath();
|
||||
if (!gitHeadPath) return;
|
||||
|
||||
// Watch the directory containing HEAD, not HEAD itself.
|
||||
// Git uses atomic writes (write temp, rename over HEAD), which changes the inode.
|
||||
// fs.watch on a file stops working after the inode changes.
|
||||
const gitDir = dirname(gitHeadPath);
|
||||
|
||||
try {
|
||||
this.gitWatcher = watch(gitDir, (_eventType, filename) => {
|
||||
if (filename === "HEAD") {
|
||||
this.cachedBranch = undefined;
|
||||
for (const cb of this.branchChangeCallbacks) cb();
|
||||
}
|
||||
});
|
||||
} catch {
|
||||
// Silently fail if we can't watch
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Read-only view for extensions - excludes setExtensionStatus, setAvailableProviderCount and dispose */
|
||||
export type ReadonlyFooterDataProvider = Pick<
|
||||
FooterDataProvider,
|
||||
| "getGitBranch"
|
||||
| "getExtensionStatuses"
|
||||
| "getAvailableProviderCount"
|
||||
| "onBranchChange"
|
||||
>;
|
||||
1290
packages/coding-agent/src/core/gateway-runtime.ts
Normal file
1290
packages/coding-agent/src/core/gateway-runtime.ts
Normal file
File diff suppressed because it is too large
Load diff
70
packages/coding-agent/src/core/index.ts
Normal file
70
packages/coding-agent/src/core/index.ts
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Core modules shared between all run modes.
|
||||
*/
|
||||
|
||||
export {
|
||||
AgentSession,
|
||||
type AgentSessionConfig,
|
||||
type AgentSessionEvent,
|
||||
type AgentSessionEventListener,
|
||||
type ModelCycleResult,
|
||||
type PromptOptions,
|
||||
type SessionStats,
|
||||
} from "./agent-session.js";
|
||||
export {
|
||||
type BashExecutorOptions,
|
||||
type BashResult,
|
||||
executeBash,
|
||||
executeBashWithOperations,
|
||||
} from "./bash-executor.js";
|
||||
export type { CompactionResult } from "./compaction/index.js";
|
||||
export {
|
||||
createEventBus,
|
||||
type EventBus,
|
||||
type EventBusController,
|
||||
} from "./event-bus.js";
|
||||
|
||||
// Extensions system
|
||||
export {
|
||||
type AgentEndEvent,
|
||||
type AgentStartEvent,
|
||||
type AgentToolResult,
|
||||
type AgentToolUpdateCallback,
|
||||
type BeforeAgentStartEvent,
|
||||
type ContextEvent,
|
||||
discoverAndLoadExtensions,
|
||||
type ExecOptions,
|
||||
type ExecResult,
|
||||
type Extension,
|
||||
type ExtensionAPI,
|
||||
type ExtensionCommandContext,
|
||||
type ExtensionContext,
|
||||
type ExtensionError,
|
||||
type ExtensionEvent,
|
||||
type ExtensionFactory,
|
||||
type ExtensionFlag,
|
||||
type ExtensionHandler,
|
||||
ExtensionRunner,
|
||||
type ExtensionShortcut,
|
||||
type ExtensionUIContext,
|
||||
type LoadExtensionsResult,
|
||||
type MessageRenderer,
|
||||
type RegisteredCommand,
|
||||
type SessionBeforeCompactEvent,
|
||||
type SessionBeforeForkEvent,
|
||||
type SessionBeforeSwitchEvent,
|
||||
type SessionBeforeTreeEvent,
|
||||
type SessionCompactEvent,
|
||||
type SessionForkEvent,
|
||||
type SessionShutdownEvent,
|
||||
type SessionStartEvent,
|
||||
type SessionSwitchEvent,
|
||||
type SessionTreeEvent,
|
||||
type ToolCallEvent,
|
||||
type ToolDefinition,
|
||||
type ToolRenderResultOptions,
|
||||
type ToolResultEvent,
|
||||
type TurnEndEvent,
|
||||
type TurnStartEvent,
|
||||
wrapToolsWithExtensions,
|
||||
} from "./extensions/index.js";
|
||||
211
packages/coding-agent/src/core/keybindings.ts
Normal file
211
packages/coding-agent/src/core/keybindings.ts
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
import {
|
||||
DEFAULT_EDITOR_KEYBINDINGS,
|
||||
type EditorAction,
|
||||
type EditorKeybindingsConfig,
|
||||
EditorKeybindingsManager,
|
||||
type KeyId,
|
||||
matchesKey,
|
||||
setEditorKeybindings,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import { existsSync, readFileSync } from "fs";
|
||||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
|
||||
/**
|
||||
* Application-level actions (coding agent specific).
|
||||
*/
|
||||
export type AppAction =
|
||||
| "interrupt"
|
||||
| "clear"
|
||||
| "exit"
|
||||
| "suspend"
|
||||
| "cycleThinkingLevel"
|
||||
| "cycleModelForward"
|
||||
| "cycleModelBackward"
|
||||
| "selectModel"
|
||||
| "expandTools"
|
||||
| "toggleThinking"
|
||||
| "toggleSessionNamedFilter"
|
||||
| "externalEditor"
|
||||
| "followUp"
|
||||
| "dequeue"
|
||||
| "pasteImage"
|
||||
| "newSession"
|
||||
| "tree"
|
||||
| "fork"
|
||||
| "resume";
|
||||
|
||||
/**
|
||||
* All configurable actions.
|
||||
*/
|
||||
export type KeyAction = AppAction | EditorAction;
|
||||
|
||||
/**
|
||||
* Full keybindings configuration (app + editor actions).
|
||||
*/
|
||||
export type KeybindingsConfig = {
|
||||
[K in KeyAction]?: KeyId | KeyId[];
|
||||
};
|
||||
|
||||
/**
|
||||
* Default application keybindings.
|
||||
*/
|
||||
export const DEFAULT_APP_KEYBINDINGS: Record<AppAction, KeyId | KeyId[]> = {
|
||||
interrupt: "escape",
|
||||
clear: "ctrl+c",
|
||||
exit: "ctrl+d",
|
||||
suspend: "ctrl+z",
|
||||
cycleThinkingLevel: "shift+tab",
|
||||
cycleModelForward: "ctrl+p",
|
||||
cycleModelBackward: "shift+ctrl+p",
|
||||
selectModel: "ctrl+l",
|
||||
expandTools: "ctrl+o",
|
||||
toggleThinking: "ctrl+t",
|
||||
toggleSessionNamedFilter: "ctrl+n",
|
||||
externalEditor: "ctrl+g",
|
||||
followUp: "alt+enter",
|
||||
dequeue: "alt+up",
|
||||
pasteImage: process.platform === "win32" ? "alt+v" : "ctrl+v",
|
||||
newSession: [],
|
||||
tree: [],
|
||||
fork: [],
|
||||
resume: [],
|
||||
};
|
||||
|
||||
/**
|
||||
* All default keybindings (app + editor).
|
||||
*/
|
||||
export const DEFAULT_KEYBINDINGS: Required<KeybindingsConfig> = {
|
||||
...DEFAULT_EDITOR_KEYBINDINGS,
|
||||
...DEFAULT_APP_KEYBINDINGS,
|
||||
};
|
||||
|
||||
// App actions list for type checking
|
||||
const APP_ACTIONS: AppAction[] = [
|
||||
"interrupt",
|
||||
"clear",
|
||||
"exit",
|
||||
"suspend",
|
||||
"cycleThinkingLevel",
|
||||
"cycleModelForward",
|
||||
"cycleModelBackward",
|
||||
"selectModel",
|
||||
"expandTools",
|
||||
"toggleThinking",
|
||||
"toggleSessionNamedFilter",
|
||||
"externalEditor",
|
||||
"followUp",
|
||||
"dequeue",
|
||||
"pasteImage",
|
||||
"newSession",
|
||||
"tree",
|
||||
"fork",
|
||||
"resume",
|
||||
];
|
||||
|
||||
function isAppAction(action: string): action is AppAction {
|
||||
return APP_ACTIONS.includes(action as AppAction);
|
||||
}
|
||||
|
||||
/**
|
||||
* Manages all keybindings (app + editor).
|
||||
*/
|
||||
export class KeybindingsManager {
|
||||
private config: KeybindingsConfig;
|
||||
private appActionToKeys: Map<AppAction, KeyId[]>;
|
||||
|
||||
private constructor(config: KeybindingsConfig) {
|
||||
this.config = config;
|
||||
this.appActionToKeys = new Map();
|
||||
this.buildMaps();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create from config file and set up editor keybindings.
|
||||
*/
|
||||
static create(agentDir: string = getAgentDir()): KeybindingsManager {
|
||||
const configPath = join(agentDir, "keybindings.json");
|
||||
const config = KeybindingsManager.loadFromFile(configPath);
|
||||
const manager = new KeybindingsManager(config);
|
||||
|
||||
// Set up editor keybindings globally
|
||||
// Include both editor actions and expandTools (shared between app and editor)
|
||||
const editorConfig: EditorKeybindingsConfig = {};
|
||||
for (const [action, keys] of Object.entries(config)) {
|
||||
if (!isAppAction(action) || action === "expandTools") {
|
||||
editorConfig[action as EditorAction] = keys;
|
||||
}
|
||||
}
|
||||
setEditorKeybindings(new EditorKeybindingsManager(editorConfig));
|
||||
|
||||
return manager;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create in-memory.
|
||||
*/
|
||||
static inMemory(config: KeybindingsConfig = {}): KeybindingsManager {
|
||||
return new KeybindingsManager(config);
|
||||
}
|
||||
|
||||
private static loadFromFile(path: string): KeybindingsConfig {
|
||||
if (!existsSync(path)) return {};
|
||||
try {
|
||||
return JSON.parse(readFileSync(path, "utf-8"));
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
private buildMaps(): void {
|
||||
this.appActionToKeys.clear();
|
||||
|
||||
// Set defaults for app actions
|
||||
for (const [action, keys] of Object.entries(DEFAULT_APP_KEYBINDINGS)) {
|
||||
const keyArray = Array.isArray(keys) ? keys : [keys];
|
||||
this.appActionToKeys.set(action as AppAction, [...keyArray]);
|
||||
}
|
||||
|
||||
// Override with user config (app actions only)
|
||||
for (const [action, keys] of Object.entries(this.config)) {
|
||||
if (keys === undefined || !isAppAction(action)) continue;
|
||||
const keyArray = Array.isArray(keys) ? keys : [keys];
|
||||
this.appActionToKeys.set(action, keyArray);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if input matches an app action.
|
||||
*/
|
||||
matches(data: string, action: AppAction): boolean {
|
||||
const keys = this.appActionToKeys.get(action);
|
||||
if (!keys) return false;
|
||||
for (const key of keys) {
|
||||
if (matchesKey(data, key)) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get keys bound to an app action.
|
||||
*/
|
||||
getKeys(action: AppAction): KeyId[] {
|
||||
return this.appActionToKeys.get(action) ?? [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the full effective config.
|
||||
*/
|
||||
getEffectiveConfig(): Required<KeybindingsConfig> {
|
||||
const result = { ...DEFAULT_KEYBINDINGS };
|
||||
for (const [action, keys] of Object.entries(this.config)) {
|
||||
if (keys !== undefined) {
|
||||
(result as KeybindingsConfig)[action as KeyAction] = keys;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Re-export for convenience
|
||||
export type { EditorAction, KeyId };
|
||||
217
packages/coding-agent/src/core/messages.ts
Normal file
217
packages/coding-agent/src/core/messages.ts
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
/**
|
||||
* Custom message types and transformers for the coding agent.
|
||||
*
|
||||
* Extends the base AgentMessage type with coding-agent specific message types,
|
||||
* and provides a transformer to convert them to LLM-compatible messages.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { ImageContent, Message, TextContent } from "@mariozechner/pi-ai";
|
||||
|
||||
export const COMPACTION_SUMMARY_PREFIX = `The conversation history before this point was compacted into the following summary:
|
||||
|
||||
<summary>
|
||||
`;
|
||||
|
||||
export const COMPACTION_SUMMARY_SUFFIX = `
|
||||
</summary>`;
|
||||
|
||||
export const BRANCH_SUMMARY_PREFIX = `The following is a summary of a branch that this conversation came back from:
|
||||
|
||||
<summary>
|
||||
`;
|
||||
|
||||
export const BRANCH_SUMMARY_SUFFIX = `</summary>`;
|
||||
|
||||
/**
|
||||
* Message type for bash executions via the ! command.
|
||||
*/
|
||||
export interface BashExecutionMessage {
|
||||
role: "bashExecution";
|
||||
command: string;
|
||||
output: string;
|
||||
exitCode: number | undefined;
|
||||
cancelled: boolean;
|
||||
truncated: boolean;
|
||||
fullOutputPath?: string;
|
||||
timestamp: number;
|
||||
/** If true, this message is excluded from LLM context (!! prefix) */
|
||||
excludeFromContext?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Message type for extension-injected messages via sendMessage().
|
||||
* These are custom messages that extensions can inject into the conversation.
|
||||
*/
|
||||
export interface CustomMessage<T = unknown> {
|
||||
role: "custom";
|
||||
customType: string;
|
||||
content: string | (TextContent | ImageContent)[];
|
||||
display: boolean;
|
||||
details?: T;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
export interface BranchSummaryMessage {
|
||||
role: "branchSummary";
|
||||
summary: string;
|
||||
fromId: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
export interface CompactionSummaryMessage {
|
||||
role: "compactionSummary";
|
||||
summary: string;
|
||||
tokensBefore: number;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
// Extend CustomAgentMessages via declaration merging
|
||||
declare module "@mariozechner/pi-agent-core" {
|
||||
interface CustomAgentMessages {
|
||||
bashExecution: BashExecutionMessage;
|
||||
custom: CustomMessage;
|
||||
branchSummary: BranchSummaryMessage;
|
||||
compactionSummary: CompactionSummaryMessage;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a BashExecutionMessage to user message text for LLM context.
|
||||
*/
|
||||
export function bashExecutionToText(msg: BashExecutionMessage): string {
|
||||
let text = `Ran \`${msg.command}\`\n`;
|
||||
if (msg.output) {
|
||||
text += `\`\`\`\n${msg.output}\n\`\`\``;
|
||||
} else {
|
||||
text += "(no output)";
|
||||
}
|
||||
if (msg.cancelled) {
|
||||
text += "\n\n(command cancelled)";
|
||||
} else if (
|
||||
msg.exitCode !== null &&
|
||||
msg.exitCode !== undefined &&
|
||||
msg.exitCode !== 0
|
||||
) {
|
||||
text += `\n\nCommand exited with code ${msg.exitCode}`;
|
||||
}
|
||||
if (msg.truncated && msg.fullOutputPath) {
|
||||
text += `\n\n[Output truncated. Full output: ${msg.fullOutputPath}]`;
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
export function createBranchSummaryMessage(
|
||||
summary: string,
|
||||
fromId: string,
|
||||
timestamp: string,
|
||||
): BranchSummaryMessage {
|
||||
return {
|
||||
role: "branchSummary",
|
||||
summary,
|
||||
fromId,
|
||||
timestamp: new Date(timestamp).getTime(),
|
||||
};
|
||||
}
|
||||
|
||||
export function createCompactionSummaryMessage(
|
||||
summary: string,
|
||||
tokensBefore: number,
|
||||
timestamp: string,
|
||||
): CompactionSummaryMessage {
|
||||
return {
|
||||
role: "compactionSummary",
|
||||
summary: summary,
|
||||
tokensBefore,
|
||||
timestamp: new Date(timestamp).getTime(),
|
||||
};
|
||||
}
|
||||
|
||||
/** Convert CustomMessageEntry to AgentMessage format */
|
||||
export function createCustomMessage(
|
||||
customType: string,
|
||||
content: string | (TextContent | ImageContent)[],
|
||||
display: boolean,
|
||||
details: unknown | undefined,
|
||||
timestamp: string,
|
||||
): CustomMessage {
|
||||
return {
|
||||
role: "custom",
|
||||
customType,
|
||||
content,
|
||||
display,
|
||||
details,
|
||||
timestamp: new Date(timestamp).getTime(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform AgentMessages (including custom types) to LLM-compatible Messages.
|
||||
*
|
||||
* This is used by:
|
||||
* - Agent's transormToLlm option (for prompt calls and queued messages)
|
||||
* - Compaction's generateSummary (for summarization)
|
||||
* - Custom extensions and tools
|
||||
*/
|
||||
export function convertToLlm(messages: AgentMessage[]): Message[] {
|
||||
return messages
|
||||
.map((m): Message | undefined => {
|
||||
switch (m.role) {
|
||||
case "bashExecution":
|
||||
// Skip messages excluded from context (!! prefix)
|
||||
if (m.excludeFromContext) {
|
||||
return undefined;
|
||||
}
|
||||
return {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: bashExecutionToText(m) }],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
case "custom": {
|
||||
const content =
|
||||
typeof m.content === "string"
|
||||
? [{ type: "text" as const, text: m.content }]
|
||||
: m.content;
|
||||
return {
|
||||
role: "user",
|
||||
content,
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
}
|
||||
case "branchSummary":
|
||||
return {
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: BRANCH_SUMMARY_PREFIX + m.summary + BRANCH_SUMMARY_SUFFIX,
|
||||
},
|
||||
],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
case "compactionSummary":
|
||||
return {
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text:
|
||||
COMPACTION_SUMMARY_PREFIX +
|
||||
m.summary +
|
||||
COMPACTION_SUMMARY_SUFFIX,
|
||||
},
|
||||
],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
case "user":
|
||||
case "assistant":
|
||||
case "toolResult":
|
||||
return m;
|
||||
default:
|
||||
// biome-ignore lint/correctness/noSwitchDeclarations: fine
|
||||
const _exhaustiveCheck: never = m;
|
||||
return undefined;
|
||||
}
|
||||
})
|
||||
.filter((m) => m !== undefined);
|
||||
}
|
||||
822
packages/coding-agent/src/core/model-registry.ts
Normal file
822
packages/coding-agent/src/core/model-registry.ts
Normal file
|
|
@ -0,0 +1,822 @@
|
|||
/**
|
||||
* Model registry - manages built-in and custom models, provides API key resolution.
|
||||
*/
|
||||
|
||||
import {
|
||||
type Api,
|
||||
type AssistantMessageEventStream,
|
||||
type Context,
|
||||
getModels,
|
||||
getProviders,
|
||||
type KnownProvider,
|
||||
type Model,
|
||||
type OAuthProviderInterface,
|
||||
type OpenAICompletionsCompat,
|
||||
type OpenAIResponsesCompat,
|
||||
registerApiProvider,
|
||||
resetApiProviders,
|
||||
type SimpleStreamOptions,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import {
|
||||
registerOAuthProvider,
|
||||
resetOAuthProviders,
|
||||
} from "@mariozechner/pi-ai/oauth";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import AjvModule from "ajv";
|
||||
import { existsSync, readFileSync } from "fs";
|
||||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import type { AuthStorage } from "./auth-storage.js";
|
||||
import {
|
||||
clearConfigValueCache,
|
||||
resolveConfigValue,
|
||||
resolveHeaders,
|
||||
} from "./resolve-config-value.js";
|
||||
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
const ajv = new Ajv();
|
||||
|
||||
// Schema for OpenRouter routing preferences
|
||||
const OpenRouterRoutingSchema = Type.Object({
|
||||
only: Type.Optional(Type.Array(Type.String())),
|
||||
order: Type.Optional(Type.Array(Type.String())),
|
||||
});
|
||||
|
||||
// Schema for Vercel AI Gateway routing preferences
|
||||
const VercelGatewayRoutingSchema = Type.Object({
|
||||
only: Type.Optional(Type.Array(Type.String())),
|
||||
order: Type.Optional(Type.Array(Type.String())),
|
||||
});
|
||||
|
||||
// Schema for OpenAI compatibility settings
|
||||
const OpenAICompletionsCompatSchema = Type.Object({
|
||||
supportsStore: Type.Optional(Type.Boolean()),
|
||||
supportsDeveloperRole: Type.Optional(Type.Boolean()),
|
||||
supportsReasoningEffort: Type.Optional(Type.Boolean()),
|
||||
supportsUsageInStreaming: Type.Optional(Type.Boolean()),
|
||||
maxTokensField: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("max_completion_tokens"),
|
||||
Type.Literal("max_tokens"),
|
||||
]),
|
||||
),
|
||||
requiresToolResultName: Type.Optional(Type.Boolean()),
|
||||
requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
|
||||
requiresThinkingAsText: Type.Optional(Type.Boolean()),
|
||||
requiresMistralToolIds: Type.Optional(Type.Boolean()),
|
||||
thinkingFormat: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai"),
|
||||
Type.Literal("zai"),
|
||||
Type.Literal("qwen"),
|
||||
]),
|
||||
),
|
||||
openRouterRouting: Type.Optional(OpenRouterRoutingSchema),
|
||||
vercelGatewayRouting: Type.Optional(VercelGatewayRoutingSchema),
|
||||
});
|
||||
|
||||
const OpenAIResponsesCompatSchema = Type.Object({
|
||||
// Reserved for future use
|
||||
});
|
||||
|
||||
const OpenAICompatSchema = Type.Union([
|
||||
OpenAICompletionsCompatSchema,
|
||||
OpenAIResponsesCompatSchema,
|
||||
]);
|
||||
|
||||
// Schema for custom model definition
|
||||
// Most fields are optional with sensible defaults for local models (Ollama, LM Studio, etc.)
|
||||
const ModelDefinitionSchema = Type.Object({
|
||||
id: Type.String({ minLength: 1 }),
|
||||
name: Type.Optional(Type.String({ minLength: 1 })),
|
||||
api: Type.Optional(Type.String({ minLength: 1 })),
|
||||
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
|
||||
reasoning: Type.Optional(Type.Boolean()),
|
||||
input: Type.Optional(
|
||||
Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
||||
),
|
||||
cost: Type.Optional(
|
||||
Type.Object({
|
||||
input: Type.Number(),
|
||||
output: Type.Number(),
|
||||
cacheRead: Type.Number(),
|
||||
cacheWrite: Type.Number(),
|
||||
}),
|
||||
),
|
||||
contextWindow: Type.Optional(Type.Number()),
|
||||
maxTokens: Type.Optional(Type.Number()),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
compat: Type.Optional(OpenAICompatSchema),
|
||||
});
|
||||
|
||||
// Schema for per-model overrides (all fields optional, merged with built-in model)
|
||||
const ModelOverrideSchema = Type.Object({
|
||||
name: Type.Optional(Type.String({ minLength: 1 })),
|
||||
reasoning: Type.Optional(Type.Boolean()),
|
||||
input: Type.Optional(
|
||||
Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
||||
),
|
||||
cost: Type.Optional(
|
||||
Type.Object({
|
||||
input: Type.Optional(Type.Number()),
|
||||
output: Type.Optional(Type.Number()),
|
||||
cacheRead: Type.Optional(Type.Number()),
|
||||
cacheWrite: Type.Optional(Type.Number()),
|
||||
}),
|
||||
),
|
||||
contextWindow: Type.Optional(Type.Number()),
|
||||
maxTokens: Type.Optional(Type.Number()),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
compat: Type.Optional(OpenAICompatSchema),
|
||||
});
|
||||
|
||||
type ModelOverride = Static<typeof ModelOverrideSchema>;
|
||||
|
||||
const ProviderConfigSchema = Type.Object({
|
||||
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
|
||||
apiKey: Type.Optional(Type.String({ minLength: 1 })),
|
||||
api: Type.Optional(Type.String({ minLength: 1 })),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
authHeader: Type.Optional(Type.Boolean()),
|
||||
models: Type.Optional(Type.Array(ModelDefinitionSchema)),
|
||||
modelOverrides: Type.Optional(
|
||||
Type.Record(Type.String(), ModelOverrideSchema),
|
||||
),
|
||||
});
|
||||
|
||||
const ModelsConfigSchema = Type.Object({
|
||||
providers: Type.Record(Type.String(), ProviderConfigSchema),
|
||||
});
|
||||
|
||||
ajv.addSchema(ModelsConfigSchema, "ModelsConfig");
|
||||
|
||||
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
||||
|
||||
/** Provider override config (baseUrl, headers, apiKey) without custom models */
|
||||
interface ProviderOverride {
|
||||
baseUrl?: string;
|
||||
headers?: Record<string, string>;
|
||||
apiKey?: string;
|
||||
}
|
||||
|
||||
/** Result of loading custom models from models.json */
|
||||
interface CustomModelsResult {
|
||||
models: Model<Api>[];
|
||||
/** Providers with baseUrl/headers/apiKey overrides for built-in models */
|
||||
overrides: Map<string, ProviderOverride>;
|
||||
/** Per-model overrides: provider -> modelId -> override */
|
||||
modelOverrides: Map<string, Map<string, ModelOverride>>;
|
||||
error: string | undefined;
|
||||
}
|
||||
|
||||
function emptyCustomModelsResult(error?: string): CustomModelsResult {
|
||||
return { models: [], overrides: new Map(), modelOverrides: new Map(), error };
|
||||
}
|
||||
|
||||
function mergeCompat(
|
||||
baseCompat: Model<Api>["compat"],
|
||||
overrideCompat: ModelOverride["compat"],
|
||||
): Model<Api>["compat"] | undefined {
|
||||
if (!overrideCompat) return baseCompat;
|
||||
|
||||
const base = baseCompat as
|
||||
| OpenAICompletionsCompat
|
||||
| OpenAIResponsesCompat
|
||||
| undefined;
|
||||
const override = overrideCompat as
|
||||
| OpenAICompletionsCompat
|
||||
| OpenAIResponsesCompat;
|
||||
const merged = { ...base, ...override } as
|
||||
| OpenAICompletionsCompat
|
||||
| OpenAIResponsesCompat;
|
||||
|
||||
const baseCompletions = base as OpenAICompletionsCompat | undefined;
|
||||
const overrideCompletions = override as OpenAICompletionsCompat;
|
||||
const mergedCompletions = merged as OpenAICompletionsCompat;
|
||||
|
||||
if (
|
||||
baseCompletions?.openRouterRouting ||
|
||||
overrideCompletions.openRouterRouting
|
||||
) {
|
||||
mergedCompletions.openRouterRouting = {
|
||||
...baseCompletions?.openRouterRouting,
|
||||
...overrideCompletions.openRouterRouting,
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
baseCompletions?.vercelGatewayRouting ||
|
||||
overrideCompletions.vercelGatewayRouting
|
||||
) {
|
||||
mergedCompletions.vercelGatewayRouting = {
|
||||
...baseCompletions?.vercelGatewayRouting,
|
||||
...overrideCompletions.vercelGatewayRouting,
|
||||
};
|
||||
}
|
||||
|
||||
return merged as Model<Api>["compat"];
|
||||
}
|
||||
|
||||
/**
|
||||
* Deep merge a model override into a model.
|
||||
* Handles nested objects (cost, compat) by merging rather than replacing.
|
||||
*/
|
||||
function applyModelOverride(
|
||||
model: Model<Api>,
|
||||
override: ModelOverride,
|
||||
): Model<Api> {
|
||||
const result = { ...model };
|
||||
|
||||
// Simple field overrides
|
||||
if (override.name !== undefined) result.name = override.name;
|
||||
if (override.reasoning !== undefined) result.reasoning = override.reasoning;
|
||||
if (override.input !== undefined)
|
||||
result.input = override.input as ("text" | "image")[];
|
||||
if (override.contextWindow !== undefined)
|
||||
result.contextWindow = override.contextWindow;
|
||||
if (override.maxTokens !== undefined) result.maxTokens = override.maxTokens;
|
||||
|
||||
// Merge cost (partial override)
|
||||
if (override.cost) {
|
||||
result.cost = {
|
||||
input: override.cost.input ?? model.cost.input,
|
||||
output: override.cost.output ?? model.cost.output,
|
||||
cacheRead: override.cost.cacheRead ?? model.cost.cacheRead,
|
||||
cacheWrite: override.cost.cacheWrite ?? model.cost.cacheWrite,
|
||||
};
|
||||
}
|
||||
|
||||
// Merge headers
|
||||
if (override.headers) {
|
||||
const resolvedHeaders = resolveHeaders(override.headers);
|
||||
result.headers = resolvedHeaders
|
||||
? { ...model.headers, ...resolvedHeaders }
|
||||
: model.headers;
|
||||
}
|
||||
|
||||
// Deep merge compat
|
||||
result.compat = mergeCompat(model.compat, override.compat);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Clear the config value command cache. Exported for testing. */
|
||||
export const clearApiKeyCache = clearConfigValueCache;
|
||||
|
||||
/**
|
||||
* Model registry - loads and manages models, resolves API keys via AuthStorage.
|
||||
*/
|
||||
export class ModelRegistry {
|
||||
private models: Model<Api>[] = [];
|
||||
private customProviderApiKeys: Map<string, string> = new Map();
|
||||
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
|
||||
private loadError: string | undefined = undefined;
|
||||
|
||||
constructor(
|
||||
readonly authStorage: AuthStorage,
|
||||
private modelsJsonPath: string | undefined = join(
|
||||
getAgentDir(),
|
||||
"models.json",
|
||||
),
|
||||
) {
|
||||
// Set up fallback resolver for custom provider API keys
|
||||
this.authStorage.setFallbackResolver((provider) => {
|
||||
const keyConfig = this.customProviderApiKeys.get(provider);
|
||||
if (keyConfig) {
|
||||
return resolveConfigValue(keyConfig);
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
|
||||
// Load models
|
||||
this.loadModels();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reload models from disk (built-in + custom from models.json).
|
||||
*/
|
||||
refresh(): void {
|
||||
this.customProviderApiKeys.clear();
|
||||
this.loadError = undefined;
|
||||
|
||||
// Ensure dynamic API/OAuth registrations are rebuilt from current provider state.
|
||||
resetApiProviders();
|
||||
resetOAuthProviders();
|
||||
|
||||
this.loadModels();
|
||||
|
||||
for (const [providerName, config] of this.registeredProviders.entries()) {
|
||||
this.applyProviderConfig(providerName, config);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get any error from loading models.json (undefined if no error).
|
||||
*/
|
||||
getError(): string | undefined {
|
||||
return this.loadError;
|
||||
}
|
||||
|
||||
private loadModels(): void {
|
||||
// Load custom models and overrides from models.json
|
||||
const {
|
||||
models: customModels,
|
||||
overrides,
|
||||
modelOverrides,
|
||||
error,
|
||||
} = this.modelsJsonPath
|
||||
? this.loadCustomModels(this.modelsJsonPath)
|
||||
: emptyCustomModelsResult();
|
||||
|
||||
if (error) {
|
||||
this.loadError = error;
|
||||
// Keep built-in models even if custom models failed to load
|
||||
}
|
||||
|
||||
const builtInModels = this.loadBuiltInModels(overrides, modelOverrides);
|
||||
let combined = this.mergeCustomModels(builtInModels, customModels);
|
||||
|
||||
// Let OAuth providers modify their models (e.g., update baseUrl)
|
||||
for (const oauthProvider of this.authStorage.getOAuthProviders()) {
|
||||
const cred = this.authStorage.get(oauthProvider.id);
|
||||
if (cred?.type === "oauth" && oauthProvider.modifyModels) {
|
||||
combined = oauthProvider.modifyModels(combined, cred);
|
||||
}
|
||||
}
|
||||
|
||||
this.models = combined;
|
||||
}
|
||||
|
||||
/** Load built-in models and apply provider/model overrides */
|
||||
private loadBuiltInModels(
|
||||
overrides: Map<string, ProviderOverride>,
|
||||
modelOverrides: Map<string, Map<string, ModelOverride>>,
|
||||
): Model<Api>[] {
|
||||
return getProviders().flatMap((provider) => {
|
||||
const models = getModels(provider as KnownProvider) as Model<Api>[];
|
||||
const providerOverride = overrides.get(provider);
|
||||
const perModelOverrides = modelOverrides.get(provider);
|
||||
|
||||
return models.map((m) => {
|
||||
let model = m;
|
||||
|
||||
// Apply provider-level baseUrl/headers override
|
||||
if (providerOverride) {
|
||||
const resolvedHeaders = resolveHeaders(providerOverride.headers);
|
||||
model = {
|
||||
...model,
|
||||
baseUrl: providerOverride.baseUrl ?? model.baseUrl,
|
||||
headers: resolvedHeaders
|
||||
? { ...model.headers, ...resolvedHeaders }
|
||||
: model.headers,
|
||||
};
|
||||
}
|
||||
|
||||
// Apply per-model override
|
||||
const modelOverride = perModelOverrides?.get(m.id);
|
||||
if (modelOverride) {
|
||||
model = applyModelOverride(model, modelOverride);
|
||||
}
|
||||
|
||||
return model;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/** Merge custom models into built-in list by provider+id (custom wins on conflicts). */
|
||||
private mergeCustomModels(
|
||||
builtInModels: Model<Api>[],
|
||||
customModels: Model<Api>[],
|
||||
): Model<Api>[] {
|
||||
const merged = [...builtInModels];
|
||||
for (const customModel of customModels) {
|
||||
const existingIndex = merged.findIndex(
|
||||
(m) => m.provider === customModel.provider && m.id === customModel.id,
|
||||
);
|
||||
if (existingIndex >= 0) {
|
||||
merged[existingIndex] = customModel;
|
||||
} else {
|
||||
merged.push(customModel);
|
||||
}
|
||||
}
|
||||
return merged;
|
||||
}
|
||||
|
||||
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
|
||||
if (!existsSync(modelsJsonPath)) {
|
||||
return emptyCustomModelsResult();
|
||||
}
|
||||
|
||||
try {
|
||||
const content = readFileSync(modelsJsonPath, "utf-8");
|
||||
const config: ModelsConfig = JSON.parse(content);
|
||||
|
||||
// Validate schema
|
||||
const validate = ajv.getSchema("ModelsConfig")!;
|
||||
if (!validate(config)) {
|
||||
const errors =
|
||||
validate.errors
|
||||
?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`)
|
||||
.join("\n") || "Unknown schema error";
|
||||
return emptyCustomModelsResult(
|
||||
`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Additional validation
|
||||
this.validateConfig(config);
|
||||
|
||||
const overrides = new Map<string, ProviderOverride>();
|
||||
const modelOverrides = new Map<string, Map<string, ModelOverride>>();
|
||||
|
||||
for (const [providerName, providerConfig] of Object.entries(
|
||||
config.providers,
|
||||
)) {
|
||||
// Apply provider-level baseUrl/headers/apiKey override to built-in models when configured.
|
||||
if (
|
||||
providerConfig.baseUrl ||
|
||||
providerConfig.headers ||
|
||||
providerConfig.apiKey
|
||||
) {
|
||||
overrides.set(providerName, {
|
||||
baseUrl: providerConfig.baseUrl,
|
||||
headers: providerConfig.headers,
|
||||
apiKey: providerConfig.apiKey,
|
||||
});
|
||||
}
|
||||
|
||||
// Store API key for fallback resolver.
|
||||
if (providerConfig.apiKey) {
|
||||
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||
}
|
||||
|
||||
if (providerConfig.modelOverrides) {
|
||||
modelOverrides.set(
|
||||
providerName,
|
||||
new Map(Object.entries(providerConfig.modelOverrides)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
models: this.parseModels(config),
|
||||
overrides,
|
||||
modelOverrides,
|
||||
error: undefined,
|
||||
};
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
return emptyCustomModelsResult(
|
||||
`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
|
||||
);
|
||||
}
|
||||
return emptyCustomModelsResult(
|
||||
`Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private validateConfig(config: ModelsConfig): void {
|
||||
for (const [providerName, providerConfig] of Object.entries(
|
||||
config.providers,
|
||||
)) {
|
||||
const hasProviderApi = !!providerConfig.api;
|
||||
const models = providerConfig.models ?? [];
|
||||
const hasModelOverrides =
|
||||
providerConfig.modelOverrides &&
|
||||
Object.keys(providerConfig.modelOverrides).length > 0;
|
||||
|
||||
if (models.length === 0) {
|
||||
// Override-only config: needs baseUrl OR modelOverrides (or both)
|
||||
if (!providerConfig.baseUrl && !hasModelOverrides) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Custom models are merged into provider models and require endpoint + auth.
|
||||
if (!providerConfig.baseUrl) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}: "baseUrl" is required when defining custom models.`,
|
||||
);
|
||||
}
|
||||
if (!providerConfig.apiKey) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}: "apiKey" is required when defining custom models.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for (const modelDef of models) {
|
||||
const hasModelApi = !!modelDef.api;
|
||||
|
||||
if (!hasProviderApi && !hasModelApi) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. Set at provider or model level.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!modelDef.id)
|
||||
throw new Error(`Provider ${providerName}: model missing "id"`);
|
||||
// Validate contextWindow/maxTokens only if provided (they have defaults)
|
||||
if (modelDef.contextWindow !== undefined && modelDef.contextWindow <= 0)
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`,
|
||||
);
|
||||
if (modelDef.maxTokens !== undefined && modelDef.maxTokens <= 0)
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private parseModels(config: ModelsConfig): Model<Api>[] {
|
||||
const models: Model<Api>[] = [];
|
||||
|
||||
for (const [providerName, providerConfig] of Object.entries(
|
||||
config.providers,
|
||||
)) {
|
||||
const modelDefs = providerConfig.models ?? [];
|
||||
if (modelDefs.length === 0) continue; // Override-only, no custom models
|
||||
|
||||
// Store API key config for fallback resolver
|
||||
if (providerConfig.apiKey) {
|
||||
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||
}
|
||||
|
||||
for (const modelDef of modelDefs) {
|
||||
const api = modelDef.api || providerConfig.api;
|
||||
if (!api) continue;
|
||||
|
||||
// Merge headers: provider headers are base, model headers override
|
||||
// Resolve env vars and shell commands in header values
|
||||
const providerHeaders = resolveHeaders(providerConfig.headers);
|
||||
const modelHeaders = resolveHeaders(modelDef.headers);
|
||||
let headers =
|
||||
providerHeaders || modelHeaders
|
||||
? { ...providerHeaders, ...modelHeaders }
|
||||
: undefined;
|
||||
|
||||
// If authHeader is true, add Authorization header with resolved API key
|
||||
if (providerConfig.authHeader && providerConfig.apiKey) {
|
||||
const resolvedKey = resolveConfigValue(providerConfig.apiKey);
|
||||
if (resolvedKey) {
|
||||
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
|
||||
}
|
||||
}
|
||||
|
||||
// Provider baseUrl is required when custom models are defined.
|
||||
// Individual models can override it with modelDef.baseUrl.
|
||||
const defaultCost = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
};
|
||||
models.push({
|
||||
id: modelDef.id,
|
||||
name: modelDef.name ?? modelDef.id,
|
||||
api: api as Api,
|
||||
provider: providerName,
|
||||
baseUrl: modelDef.baseUrl ?? providerConfig.baseUrl!,
|
||||
reasoning: modelDef.reasoning ?? false,
|
||||
input: (modelDef.input ?? ["text"]) as ("text" | "image")[],
|
||||
cost: modelDef.cost ?? defaultCost,
|
||||
contextWindow: modelDef.contextWindow ?? 128000,
|
||||
maxTokens: modelDef.maxTokens ?? 16384,
|
||||
headers,
|
||||
compat: modelDef.compat,
|
||||
} as Model<Api>);
|
||||
}
|
||||
}
|
||||
|
||||
return models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models (built-in + custom).
|
||||
* If models.json had errors, returns only built-in models.
|
||||
*/
|
||||
getAll(): Model<Api>[] {
|
||||
return this.models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get only models that have auth configured.
|
||||
* This is a fast check that doesn't refresh OAuth tokens.
|
||||
*/
|
||||
getAvailable(): Model<Api>[] {
|
||||
return this.models.filter((m) => this.authStorage.hasAuth(m.provider));
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a model by provider and ID.
|
||||
*/
|
||||
find(provider: string, modelId: string): Model<Api> | undefined {
|
||||
return this.models.find((m) => m.provider === provider && m.id === modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a model.
|
||||
*/
|
||||
async getApiKey(model: Model<Api>): Promise<string | undefined> {
|
||||
return this.authStorage.getApiKey(model.provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider.
|
||||
*/
|
||||
async getApiKeyForProvider(provider: string): Promise<string | undefined> {
|
||||
return this.authStorage.getApiKey(provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model is using OAuth credentials (subscription).
|
||||
*/
|
||||
isUsingOAuth(model: Model<Api>): boolean {
|
||||
const cred = this.authStorage.get(model.provider);
|
||||
return cred?.type === "oauth";
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a provider dynamically (from extensions).
|
||||
*
|
||||
* If provider has models: replaces all existing models for this provider.
|
||||
* If provider has only baseUrl/headers: overrides existing models' URLs.
|
||||
* If provider has oauth: registers OAuth provider for /login support.
|
||||
*/
|
||||
registerProvider(providerName: string, config: ProviderConfigInput): void {
|
||||
this.registeredProviders.set(providerName, config);
|
||||
this.applyProviderConfig(providerName, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unregister a previously registered provider.
|
||||
*
|
||||
* Removes the provider from the registry and reloads models from disk so that
|
||||
* built-in models overridden by this provider are restored to their original state.
|
||||
* Also resets dynamic OAuth and API stream registrations before reapplying
|
||||
* remaining dynamic providers.
|
||||
* Has no effect if the provider was never registered.
|
||||
*/
|
||||
unregisterProvider(providerName: string): void {
|
||||
if (!this.registeredProviders.has(providerName)) return;
|
||||
this.registeredProviders.delete(providerName);
|
||||
this.customProviderApiKeys.delete(providerName);
|
||||
this.refresh();
|
||||
}
|
||||
|
||||
private applyProviderConfig(
|
||||
providerName: string,
|
||||
config: ProviderConfigInput,
|
||||
): void {
|
||||
// Register OAuth provider if provided
|
||||
if (config.oauth) {
|
||||
// Ensure the OAuth provider ID matches the provider name
|
||||
const oauthProvider: OAuthProviderInterface = {
|
||||
...config.oauth,
|
||||
id: providerName,
|
||||
};
|
||||
registerOAuthProvider(oauthProvider);
|
||||
}
|
||||
|
||||
if (config.streamSimple) {
|
||||
if (!config.api) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}: "api" is required when registering streamSimple.`,
|
||||
);
|
||||
}
|
||||
const streamSimple = config.streamSimple;
|
||||
registerApiProvider(
|
||||
{
|
||||
api: config.api,
|
||||
stream: (model, context, options) =>
|
||||
streamSimple(model, context, options as SimpleStreamOptions),
|
||||
streamSimple,
|
||||
},
|
||||
`provider:${providerName}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Store API key for auth resolution
|
||||
if (config.apiKey) {
|
||||
this.customProviderApiKeys.set(providerName, config.apiKey);
|
||||
}
|
||||
|
||||
if (config.models && config.models.length > 0) {
|
||||
// Full replacement: remove existing models for this provider
|
||||
this.models = this.models.filter((m) => m.provider !== providerName);
|
||||
|
||||
// Validate required fields
|
||||
if (!config.baseUrl) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}: "baseUrl" is required when defining models.`,
|
||||
);
|
||||
}
|
||||
if (!config.apiKey && !config.oauth) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}: "apiKey" or "oauth" is required when defining models.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Parse and add new models
|
||||
for (const modelDef of config.models) {
|
||||
const api = modelDef.api || config.api;
|
||||
if (!api) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Merge headers
|
||||
const providerHeaders = resolveHeaders(config.headers);
|
||||
const modelHeaders = resolveHeaders(modelDef.headers);
|
||||
let headers =
|
||||
providerHeaders || modelHeaders
|
||||
? { ...providerHeaders, ...modelHeaders }
|
||||
: undefined;
|
||||
|
||||
// If authHeader is true, add Authorization header
|
||||
if (config.authHeader && config.apiKey) {
|
||||
const resolvedKey = resolveConfigValue(config.apiKey);
|
||||
if (resolvedKey) {
|
||||
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
|
||||
}
|
||||
}
|
||||
|
||||
this.models.push({
|
||||
id: modelDef.id,
|
||||
name: modelDef.name,
|
||||
api: api as Api,
|
||||
provider: providerName,
|
||||
baseUrl: config.baseUrl,
|
||||
reasoning: modelDef.reasoning,
|
||||
input: modelDef.input as ("text" | "image")[],
|
||||
cost: modelDef.cost,
|
||||
contextWindow: modelDef.contextWindow,
|
||||
maxTokens: modelDef.maxTokens,
|
||||
headers,
|
||||
compat: modelDef.compat,
|
||||
} as Model<Api>);
|
||||
}
|
||||
|
||||
// Apply OAuth modifyModels if credentials exist (e.g., to update baseUrl)
|
||||
if (config.oauth?.modifyModels) {
|
||||
const cred = this.authStorage.get(providerName);
|
||||
if (cred?.type === "oauth") {
|
||||
this.models = config.oauth.modifyModels(this.models, cred);
|
||||
}
|
||||
}
|
||||
} else if (config.baseUrl) {
|
||||
// Override-only: update baseUrl/headers for existing models
|
||||
const resolvedHeaders = resolveHeaders(config.headers);
|
||||
this.models = this.models.map((m) => {
|
||||
if (m.provider !== providerName) return m;
|
||||
return {
|
||||
...m,
|
||||
baseUrl: config.baseUrl ?? m.baseUrl,
|
||||
headers: resolvedHeaders
|
||||
? { ...m.headers, ...resolvedHeaders }
|
||||
: m.headers,
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Input type for registerProvider API.
|
||||
*/
|
||||
export interface ProviderConfigInput {
|
||||
baseUrl?: string;
|
||||
apiKey?: string;
|
||||
api?: Api;
|
||||
streamSimple?: (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
) => AssistantMessageEventStream;
|
||||
headers?: Record<string, string>;
|
||||
authHeader?: boolean;
|
||||
/** OAuth provider for /login support */
|
||||
oauth?: Omit<OAuthProviderInterface, "id">;
|
||||
models?: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
api?: Api;
|
||||
baseUrl?: string;
|
||||
reasoning: boolean;
|
||||
input: ("text" | "image")[];
|
||||
cost: {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
};
|
||||
contextWindow: number;
|
||||
maxTokens: number;
|
||||
headers?: Record<string, string>;
|
||||
compat?: Model<Api>["compat"];
|
||||
}>;
|
||||
}
|
||||
707
packages/coding-agent/src/core/model-resolver.ts
Normal file
707
packages/coding-agent/src/core/model-resolver.ts
Normal file
|
|
@ -0,0 +1,707 @@
|
|||
/**
|
||||
* Model resolution, scoping, and initial selection
|
||||
*/
|
||||
|
||||
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
|
||||
import {
|
||||
type Api,
|
||||
type KnownProvider,
|
||||
type Model,
|
||||
modelsAreEqual,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import chalk from "chalk";
|
||||
import { minimatch } from "minimatch";
|
||||
import { isValidThinkingLevel } from "../cli/args.js";
|
||||
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
|
||||
import type { ModelRegistry } from "./model-registry.js";
|
||||
|
||||
/** Default model IDs for each known provider */
|
||||
export const defaultModelPerProvider: Record<KnownProvider, string> = {
|
||||
"amazon-bedrock": "us.anthropic.claude-opus-4-6-v1",
|
||||
anthropic: "claude-opus-4-6",
|
||||
openai: "gpt-5.4",
|
||||
"azure-openai-responses": "gpt-5.2",
|
||||
"openai-codex": "gpt-5.4",
|
||||
google: "gemini-2.5-pro",
|
||||
"google-gemini-cli": "gemini-2.5-pro",
|
||||
"google-antigravity": "gemini-3.1-pro-high",
|
||||
"google-vertex": "gemini-3-pro-preview",
|
||||
"github-copilot": "gpt-4o",
|
||||
openrouter: "openai/gpt-5.1-codex",
|
||||
"vercel-ai-gateway": "anthropic/claude-opus-4-6",
|
||||
xai: "grok-4-fast-non-reasoning",
|
||||
groq: "openai/gpt-oss-120b",
|
||||
cerebras: "zai-glm-4.6",
|
||||
zai: "glm-4.6",
|
||||
mistral: "devstral-medium-latest",
|
||||
minimax: "MiniMax-M2.1",
|
||||
"minimax-cn": "MiniMax-M2.1",
|
||||
huggingface: "moonshotai/Kimi-K2.5",
|
||||
opencode: "claude-opus-4-6",
|
||||
"opencode-go": "kimi-k2.5",
|
||||
"kimi-coding": "kimi-k2-thinking",
|
||||
};
|
||||
|
||||
export interface ScopedModel {
|
||||
model: Model<Api>;
|
||||
/** Thinking level if explicitly specified in pattern (e.g., "model:high"), undefined otherwise */
|
||||
thinkingLevel?: ThinkingLevel;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to check if a model ID looks like an alias (no date suffix)
|
||||
* Dates are typically in format: -20241022 or -20250929
|
||||
*/
|
||||
function isAlias(id: string): boolean {
|
||||
// Check if ID ends with -latest
|
||||
if (id.endsWith("-latest")) return true;
|
||||
|
||||
// Check if ID ends with a date pattern (-YYYYMMDD)
|
||||
const datePattern = /-\d{8}$/;
|
||||
return !datePattern.test(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to match a pattern to a model from the available models list.
|
||||
* Returns the matched model or undefined if no match found.
|
||||
*/
|
||||
function tryMatchModel(
|
||||
modelPattern: string,
|
||||
availableModels: Model<Api>[],
|
||||
): Model<Api> | undefined {
|
||||
// Check for provider/modelId format (provider is everything before the first /)
|
||||
const slashIndex = modelPattern.indexOf("/");
|
||||
if (slashIndex !== -1) {
|
||||
const provider = modelPattern.substring(0, slashIndex);
|
||||
const modelId = modelPattern.substring(slashIndex + 1);
|
||||
const providerMatch = availableModels.find(
|
||||
(m) =>
|
||||
m.provider.toLowerCase() === provider.toLowerCase() &&
|
||||
m.id.toLowerCase() === modelId.toLowerCase(),
|
||||
);
|
||||
if (providerMatch) {
|
||||
return providerMatch;
|
||||
}
|
||||
// No exact provider/model match - fall through to other matching
|
||||
}
|
||||
|
||||
// Check for exact ID match (case-insensitive)
|
||||
const exactMatch = availableModels.find(
|
||||
(m) => m.id.toLowerCase() === modelPattern.toLowerCase(),
|
||||
);
|
||||
if (exactMatch) {
|
||||
return exactMatch;
|
||||
}
|
||||
|
||||
// No exact match - fall back to partial matching
|
||||
const matches = availableModels.filter(
|
||||
(m) =>
|
||||
m.id.toLowerCase().includes(modelPattern.toLowerCase()) ||
|
||||
m.name?.toLowerCase().includes(modelPattern.toLowerCase()),
|
||||
);
|
||||
|
||||
if (matches.length === 0) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Separate into aliases and dated versions
|
||||
const aliases = matches.filter((m) => isAlias(m.id));
|
||||
const datedVersions = matches.filter((m) => !isAlias(m.id));
|
||||
|
||||
if (aliases.length > 0) {
|
||||
// Prefer alias - if multiple aliases, pick the one that sorts highest
|
||||
aliases.sort((a, b) => b.id.localeCompare(a.id));
|
||||
return aliases[0];
|
||||
} else {
|
||||
// No alias found, pick latest dated version
|
||||
datedVersions.sort((a, b) => b.id.localeCompare(a.id));
|
||||
return datedVersions[0];
|
||||
}
|
||||
}
|
||||
|
||||
export interface ParsedModelResult {
|
||||
model: Model<Api> | undefined;
|
||||
/** Thinking level if explicitly specified in pattern, undefined otherwise */
|
||||
thinkingLevel?: ThinkingLevel;
|
||||
warning: string | undefined;
|
||||
}
|
||||
|
||||
function buildFallbackModel(
|
||||
provider: string,
|
||||
modelId: string,
|
||||
availableModels: Model<Api>[],
|
||||
): Model<Api> | undefined {
|
||||
const providerModels = availableModels.filter((m) => m.provider === provider);
|
||||
if (providerModels.length === 0) return undefined;
|
||||
|
||||
const defaultId = defaultModelPerProvider[provider as KnownProvider];
|
||||
const baseModel = defaultId
|
||||
? (providerModels.find((m) => m.id === defaultId) ?? providerModels[0])
|
||||
: providerModels[0];
|
||||
|
||||
return {
|
||||
...baseModel,
|
||||
id: modelId,
|
||||
name: modelId,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a pattern to extract model and thinking level.
|
||||
* Handles models with colons in their IDs (e.g., OpenRouter's :exacto suffix).
|
||||
*
|
||||
* Algorithm:
|
||||
* 1. Try to match full pattern as a model
|
||||
* 2. If found, return it with "off" thinking level
|
||||
* 3. If not found and has colons, split on last colon:
|
||||
* - If suffix is valid thinking level, use it and recurse on prefix
|
||||
* - If suffix is invalid, warn and recurse on prefix with "off"
|
||||
*
|
||||
* @internal Exported for testing
|
||||
*/
|
||||
export function parseModelPattern(
|
||||
pattern: string,
|
||||
availableModels: Model<Api>[],
|
||||
options?: { allowInvalidThinkingLevelFallback?: boolean },
|
||||
): ParsedModelResult {
|
||||
// Try exact match first
|
||||
const exactMatch = tryMatchModel(pattern, availableModels);
|
||||
if (exactMatch) {
|
||||
return { model: exactMatch, thinkingLevel: undefined, warning: undefined };
|
||||
}
|
||||
|
||||
// No match - try splitting on last colon if present
|
||||
const lastColonIndex = pattern.lastIndexOf(":");
|
||||
if (lastColonIndex === -1) {
|
||||
// No colons, pattern simply doesn't match any model
|
||||
return { model: undefined, thinkingLevel: undefined, warning: undefined };
|
||||
}
|
||||
|
||||
const prefix = pattern.substring(0, lastColonIndex);
|
||||
const suffix = pattern.substring(lastColonIndex + 1);
|
||||
|
||||
if (isValidThinkingLevel(suffix)) {
|
||||
// Valid thinking level - recurse on prefix and use this level
|
||||
const result = parseModelPattern(prefix, availableModels, options);
|
||||
if (result.model) {
|
||||
// Only use this thinking level if no warning from inner recursion
|
||||
return {
|
||||
model: result.model,
|
||||
thinkingLevel: result.warning ? undefined : suffix,
|
||||
warning: result.warning,
|
||||
};
|
||||
}
|
||||
return result;
|
||||
} else {
|
||||
// Invalid suffix
|
||||
const allowFallback = options?.allowInvalidThinkingLevelFallback ?? true;
|
||||
if (!allowFallback) {
|
||||
// In strict mode (CLI --model parsing), treat it as part of the model id and fail.
|
||||
// This avoids accidentally resolving to a different model.
|
||||
return { model: undefined, thinkingLevel: undefined, warning: undefined };
|
||||
}
|
||||
|
||||
// Scope mode: recurse on prefix and warn
|
||||
const result = parseModelPattern(prefix, availableModels, options);
|
||||
if (result.model) {
|
||||
return {
|
||||
model: result.model,
|
||||
thinkingLevel: undefined,
|
||||
warning: `Invalid thinking level "${suffix}" in pattern "${pattern}". Using default instead.`,
|
||||
};
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve model patterns to actual Model objects with optional thinking levels
|
||||
* Format: "pattern:level" where :level is optional
|
||||
* For each pattern, finds all matching models and picks the best version:
|
||||
* 1. Prefer alias (e.g., claude-sonnet-4-5) over dated versions (claude-sonnet-4-5-20250929)
|
||||
* 2. If no alias, pick the latest dated version
|
||||
*
|
||||
* Supports models with colons in their IDs (e.g., OpenRouter's model:exacto).
|
||||
* The algorithm tries to match the full pattern first, then progressively
|
||||
* strips colon-suffixes to find a match.
|
||||
*/
|
||||
export async function resolveModelScope(
|
||||
patterns: string[],
|
||||
modelRegistry: ModelRegistry,
|
||||
): Promise<ScopedModel[]> {
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
const scopedModels: ScopedModel[] = [];
|
||||
|
||||
for (const pattern of patterns) {
|
||||
// Check if pattern contains glob characters
|
||||
if (
|
||||
pattern.includes("*") ||
|
||||
pattern.includes("?") ||
|
||||
pattern.includes("[")
|
||||
) {
|
||||
// Extract optional thinking level suffix (e.g., "provider/*:high")
|
||||
const colonIdx = pattern.lastIndexOf(":");
|
||||
let globPattern = pattern;
|
||||
let thinkingLevel: ThinkingLevel | undefined;
|
||||
|
||||
if (colonIdx !== -1) {
|
||||
const suffix = pattern.substring(colonIdx + 1);
|
||||
if (isValidThinkingLevel(suffix)) {
|
||||
thinkingLevel = suffix;
|
||||
globPattern = pattern.substring(0, colonIdx);
|
||||
}
|
||||
}
|
||||
|
||||
// Match against "provider/modelId" format OR just model ID
|
||||
// This allows "*sonnet*" to match without requiring "anthropic/*sonnet*"
|
||||
const matchingModels = availableModels.filter((m) => {
|
||||
const fullId = `${m.provider}/${m.id}`;
|
||||
return (
|
||||
minimatch(fullId, globPattern, { nocase: true }) ||
|
||||
minimatch(m.id, globPattern, { nocase: true })
|
||||
);
|
||||
});
|
||||
|
||||
if (matchingModels.length === 0) {
|
||||
console.warn(
|
||||
chalk.yellow(`Warning: No models match pattern "${pattern}"`),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const model of matchingModels) {
|
||||
if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
|
||||
scopedModels.push({ model, thinkingLevel });
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const { model, thinkingLevel, warning } = parseModelPattern(
|
||||
pattern,
|
||||
availableModels,
|
||||
);
|
||||
|
||||
if (warning) {
|
||||
console.warn(chalk.yellow(`Warning: ${warning}`));
|
||||
}
|
||||
|
||||
if (!model) {
|
||||
console.warn(
|
||||
chalk.yellow(`Warning: No models match pattern "${pattern}"`),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Avoid duplicates
|
||||
if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
|
||||
scopedModels.push({ model, thinkingLevel });
|
||||
}
|
||||
}
|
||||
|
||||
return scopedModels;
|
||||
}
|
||||
|
||||
export interface ResolveCliModelResult {
|
||||
model: Model<Api> | undefined;
|
||||
thinkingLevel?: ThinkingLevel;
|
||||
warning: string | undefined;
|
||||
/**
|
||||
* Error message suitable for CLI display.
|
||||
* When set, model will be undefined.
|
||||
*/
|
||||
error: string | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve a single model from CLI flags.
|
||||
*
|
||||
* Supports:
|
||||
* - --provider <provider> --model <pattern>
|
||||
* - --model <provider>/<pattern>
|
||||
* - Fuzzy matching (same rules as model scoping: exact id, then partial id/name)
|
||||
*
|
||||
* Note: This does not apply the thinking level by itself, but it may *parse* and
|
||||
* return a thinking level from "<pattern>:<thinking>" so the caller can apply it.
|
||||
*/
|
||||
export function resolveCliModel(options: {
|
||||
cliProvider?: string;
|
||||
cliModel?: string;
|
||||
modelRegistry: ModelRegistry;
|
||||
}): ResolveCliModelResult {
|
||||
const { cliProvider, cliModel, modelRegistry } = options;
|
||||
|
||||
if (!cliModel) {
|
||||
return { model: undefined, warning: undefined, error: undefined };
|
||||
}
|
||||
|
||||
// Important: use *all* models here, not just models with pre-configured auth.
|
||||
// This allows "--api-key" to be used for first-time setup.
|
||||
const availableModels = modelRegistry.getAll();
|
||||
if (availableModels.length === 0) {
|
||||
return {
|
||||
model: undefined,
|
||||
warning: undefined,
|
||||
error:
|
||||
"No models available. Check your installation or add models to models.json.",
|
||||
};
|
||||
}
|
||||
|
||||
// Build canonical provider lookup (case-insensitive)
|
||||
const providerMap = new Map<string, string>();
|
||||
for (const m of availableModels) {
|
||||
providerMap.set(m.provider.toLowerCase(), m.provider);
|
||||
}
|
||||
|
||||
let provider = cliProvider
|
||||
? providerMap.get(cliProvider.toLowerCase())
|
||||
: undefined;
|
||||
if (cliProvider && !provider) {
|
||||
return {
|
||||
model: undefined,
|
||||
warning: undefined,
|
||||
error: `Unknown provider "${cliProvider}". Use --list-models to see available providers/models.`,
|
||||
};
|
||||
}
|
||||
|
||||
// If no explicit --provider, try to interpret "provider/model" format first.
|
||||
// When the prefix before the first slash matches a known provider, prefer that
|
||||
// interpretation over matching models whose IDs literally contain slashes
|
||||
// (e.g. "zai/glm-5" should resolve to provider=zai, model=glm-5, not to a
|
||||
// vercel-ai-gateway model with id "zai/glm-5").
|
||||
let pattern = cliModel;
|
||||
let inferredProvider = false;
|
||||
|
||||
if (!provider) {
|
||||
const slashIndex = cliModel.indexOf("/");
|
||||
if (slashIndex !== -1) {
|
||||
const maybeProvider = cliModel.substring(0, slashIndex);
|
||||
const canonical = providerMap.get(maybeProvider.toLowerCase());
|
||||
if (canonical) {
|
||||
provider = canonical;
|
||||
pattern = cliModel.substring(slashIndex + 1);
|
||||
inferredProvider = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no provider was inferred from the slash, try exact matches without provider inference.
|
||||
// This handles models whose IDs naturally contain slashes (e.g. OpenRouter-style IDs).
|
||||
if (!provider) {
|
||||
const lower = cliModel.toLowerCase();
|
||||
const exact = availableModels.find(
|
||||
(m) =>
|
||||
m.id.toLowerCase() === lower ||
|
||||
`${m.provider}/${m.id}`.toLowerCase() === lower,
|
||||
);
|
||||
if (exact) {
|
||||
return {
|
||||
model: exact,
|
||||
warning: undefined,
|
||||
thinkingLevel: undefined,
|
||||
error: undefined,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (cliProvider && provider) {
|
||||
// If both were provided, tolerate --model <provider>/<pattern> by stripping the provider prefix
|
||||
const prefix = `${provider}/`;
|
||||
if (cliModel.toLowerCase().startsWith(prefix.toLowerCase())) {
|
||||
pattern = cliModel.substring(prefix.length);
|
||||
}
|
||||
}
|
||||
|
||||
const candidates = provider
|
||||
? availableModels.filter((m) => m.provider === provider)
|
||||
: availableModels;
|
||||
const { model, thinkingLevel, warning } = parseModelPattern(
|
||||
pattern,
|
||||
candidates,
|
||||
{
|
||||
allowInvalidThinkingLevelFallback: false,
|
||||
},
|
||||
);
|
||||
|
||||
if (model) {
|
||||
return { model, thinkingLevel, warning, error: undefined };
|
||||
}
|
||||
|
||||
// If we inferred a provider from the slash but found no match within that provider,
|
||||
// fall back to matching the full input as a raw model id across all models.
|
||||
// This handles OpenRouter-style IDs like "openai/gpt-4o:extended" where "openai"
|
||||
// looks like a provider but the full string is actually a model id on openrouter.
|
||||
if (inferredProvider) {
|
||||
const lower = cliModel.toLowerCase();
|
||||
const exact = availableModels.find(
|
||||
(m) =>
|
||||
m.id.toLowerCase() === lower ||
|
||||
`${m.provider}/${m.id}`.toLowerCase() === lower,
|
||||
);
|
||||
if (exact) {
|
||||
return {
|
||||
model: exact,
|
||||
warning: undefined,
|
||||
thinkingLevel: undefined,
|
||||
error: undefined,
|
||||
};
|
||||
}
|
||||
// Also try parseModelPattern on the full input against all models
|
||||
const fallback = parseModelPattern(cliModel, availableModels, {
|
||||
allowInvalidThinkingLevelFallback: false,
|
||||
});
|
||||
if (fallback.model) {
|
||||
return {
|
||||
model: fallback.model,
|
||||
thinkingLevel: fallback.thinkingLevel,
|
||||
warning: fallback.warning,
|
||||
error: undefined,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (provider) {
|
||||
const fallbackModel = buildFallbackModel(
|
||||
provider,
|
||||
pattern,
|
||||
availableModels,
|
||||
);
|
||||
if (fallbackModel) {
|
||||
const fallbackWarning = warning
|
||||
? `${warning} Model "${pattern}" not found for provider "${provider}". Using custom model id.`
|
||||
: `Model "${pattern}" not found for provider "${provider}". Using custom model id.`;
|
||||
return {
|
||||
model: fallbackModel,
|
||||
thinkingLevel: undefined,
|
||||
warning: fallbackWarning,
|
||||
error: undefined,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const display = provider ? `${provider}/${pattern}` : cliModel;
|
||||
return {
|
||||
model: undefined,
|
||||
thinkingLevel: undefined,
|
||||
warning,
|
||||
error: `Model "${display}" not found. Use --list-models to see available models.`,
|
||||
};
|
||||
}
|
||||
|
||||
export interface InitialModelResult {
|
||||
model: Model<Api> | undefined;
|
||||
thinkingLevel: ThinkingLevel;
|
||||
fallbackMessage: string | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the initial model to use based on priority:
|
||||
* 1. CLI args (provider + model)
|
||||
* 2. First model from scoped models (if not continuing/resuming)
|
||||
* 3. Restored from session (if continuing/resuming)
|
||||
* 4. Saved default from settings
|
||||
* 5. First available model with valid API key
|
||||
*/
|
||||
export async function findInitialModel(options: {
|
||||
cliProvider?: string;
|
||||
cliModel?: string;
|
||||
scopedModels: ScopedModel[];
|
||||
isContinuing: boolean;
|
||||
defaultProvider?: string;
|
||||
defaultModelId?: string;
|
||||
defaultThinkingLevel?: ThinkingLevel;
|
||||
modelRegistry: ModelRegistry;
|
||||
}): Promise<InitialModelResult> {
|
||||
const {
|
||||
cliProvider,
|
||||
cliModel,
|
||||
scopedModels,
|
||||
isContinuing,
|
||||
defaultProvider,
|
||||
defaultModelId,
|
||||
defaultThinkingLevel,
|
||||
modelRegistry,
|
||||
} = options;
|
||||
|
||||
let model: Model<Api> | undefined;
|
||||
let thinkingLevel: ThinkingLevel = DEFAULT_THINKING_LEVEL;
|
||||
|
||||
// 1. CLI args take priority
|
||||
if (cliProvider && cliModel) {
|
||||
const resolved = resolveCliModel({
|
||||
cliProvider,
|
||||
cliModel,
|
||||
modelRegistry,
|
||||
});
|
||||
if (resolved.error) {
|
||||
console.error(chalk.red(resolved.error));
|
||||
process.exit(1);
|
||||
}
|
||||
if (resolved.model) {
|
||||
return {
|
||||
model: resolved.model,
|
||||
thinkingLevel: DEFAULT_THINKING_LEVEL,
|
||||
fallbackMessage: undefined,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Use first model from scoped models (skip if continuing/resuming)
|
||||
if (scopedModels.length > 0 && !isContinuing) {
|
||||
return {
|
||||
model: scopedModels[0].model,
|
||||
thinkingLevel:
|
||||
scopedModels[0].thinkingLevel ??
|
||||
defaultThinkingLevel ??
|
||||
DEFAULT_THINKING_LEVEL,
|
||||
fallbackMessage: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
// 3. Try saved default from settings
|
||||
if (defaultProvider && defaultModelId) {
|
||||
const found = modelRegistry.find(defaultProvider, defaultModelId);
|
||||
if (found) {
|
||||
model = found;
|
||||
if (defaultThinkingLevel) {
|
||||
thinkingLevel = defaultThinkingLevel;
|
||||
}
|
||||
return { model, thinkingLevel, fallbackMessage: undefined };
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Try first available model with valid API key
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
|
||||
if (availableModels.length > 0) {
|
||||
// Try to find a default model from known providers
|
||||
for (const provider of Object.keys(
|
||||
defaultModelPerProvider,
|
||||
) as KnownProvider[]) {
|
||||
const defaultId = defaultModelPerProvider[provider];
|
||||
const match = availableModels.find(
|
||||
(m) => m.provider === provider && m.id === defaultId,
|
||||
);
|
||||
if (match) {
|
||||
return {
|
||||
model: match,
|
||||
thinkingLevel: DEFAULT_THINKING_LEVEL,
|
||||
fallbackMessage: undefined,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// If no default found, use first available
|
||||
return {
|
||||
model: availableModels[0],
|
||||
thinkingLevel: DEFAULT_THINKING_LEVEL,
|
||||
fallbackMessage: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
// 5. No model found
|
||||
return {
|
||||
model: undefined,
|
||||
thinkingLevel: DEFAULT_THINKING_LEVEL,
|
||||
fallbackMessage: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Restore model from session, with fallback to available models
|
||||
*/
|
||||
export async function restoreModelFromSession(
|
||||
savedProvider: string,
|
||||
savedModelId: string,
|
||||
currentModel: Model<Api> | undefined,
|
||||
shouldPrintMessages: boolean,
|
||||
modelRegistry: ModelRegistry,
|
||||
): Promise<{
|
||||
model: Model<Api> | undefined;
|
||||
fallbackMessage: string | undefined;
|
||||
}> {
|
||||
const restoredModel = modelRegistry.find(savedProvider, savedModelId);
|
||||
|
||||
// Check if restored model exists and has a valid API key
|
||||
const hasApiKey = restoredModel
|
||||
? !!(await modelRegistry.getApiKey(restoredModel))
|
||||
: false;
|
||||
|
||||
if (restoredModel && hasApiKey) {
|
||||
if (shouldPrintMessages) {
|
||||
console.log(
|
||||
chalk.dim(`Restored model: ${savedProvider}/${savedModelId}`),
|
||||
);
|
||||
}
|
||||
return { model: restoredModel, fallbackMessage: undefined };
|
||||
}
|
||||
|
||||
// Model not found or no API key - fall back
|
||||
const reason = !restoredModel
|
||||
? "model no longer exists"
|
||||
: "no API key available";
|
||||
|
||||
if (shouldPrintMessages) {
|
||||
console.error(
|
||||
chalk.yellow(
|
||||
`Warning: Could not restore model ${savedProvider}/${savedModelId} (${reason}).`,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
// If we already have a model, use it as fallback
|
||||
if (currentModel) {
|
||||
if (shouldPrintMessages) {
|
||||
console.log(
|
||||
chalk.dim(
|
||||
`Falling back to: ${currentModel.provider}/${currentModel.id}`,
|
||||
),
|
||||
);
|
||||
}
|
||||
return {
|
||||
model: currentModel,
|
||||
fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${currentModel.provider}/${currentModel.id}.`,
|
||||
};
|
||||
}
|
||||
|
||||
// Try to find any available model
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
|
||||
if (availableModels.length > 0) {
|
||||
// Try to find a default model from known providers
|
||||
let fallbackModel: Model<Api> | undefined;
|
||||
for (const provider of Object.keys(
|
||||
defaultModelPerProvider,
|
||||
) as KnownProvider[]) {
|
||||
const defaultId = defaultModelPerProvider[provider];
|
||||
const match = availableModels.find(
|
||||
(m) => m.provider === provider && m.id === defaultId,
|
||||
);
|
||||
if (match) {
|
||||
fallbackModel = match;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If no default found, use first available
|
||||
if (!fallbackModel) {
|
||||
fallbackModel = availableModels[0];
|
||||
}
|
||||
|
||||
if (shouldPrintMessages) {
|
||||
console.log(
|
||||
chalk.dim(
|
||||
`Falling back to: ${fallbackModel.provider}/${fallbackModel.id}`,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
model: fallbackModel,
|
||||
fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${fallbackModel.provider}/${fallbackModel.id}.`,
|
||||
};
|
||||
}
|
||||
|
||||
// No models available
|
||||
return { model: undefined, fallbackMessage: undefined };
|
||||
}
|
||||
2087
packages/coding-agent/src/core/package-manager.ts
Normal file
2087
packages/coding-agent/src/core/package-manager.ts
Normal file
File diff suppressed because it is too large
Load diff
327
packages/coding-agent/src/core/prompt-templates.ts
Normal file
327
packages/coding-agent/src/core/prompt-templates.ts
Normal file
|
|
@ -0,0 +1,327 @@
|
|||
import { existsSync, readdirSync, readFileSync, statSync } from "fs";
|
||||
import { homedir } from "os";
|
||||
import { basename, isAbsolute, join, resolve, sep } from "path";
|
||||
import { CONFIG_DIR_NAME, getPromptsDir } from "../config.js";
|
||||
import { parseFrontmatter } from "../utils/frontmatter.js";
|
||||
|
||||
/**
|
||||
* Represents a prompt template loaded from a markdown file
|
||||
*/
|
||||
export interface PromptTemplate {
|
||||
name: string;
|
||||
description: string;
|
||||
content: string;
|
||||
source: string; // "user", "project", or "path"
|
||||
filePath: string; // Absolute path to the template file
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse command arguments respecting quoted strings (bash-style)
|
||||
* Returns array of arguments
|
||||
*/
|
||||
export function parseCommandArgs(argsString: string): string[] {
|
||||
const args: string[] = [];
|
||||
let current = "";
|
||||
let inQuote: string | null = null;
|
||||
|
||||
for (let i = 0; i < argsString.length; i++) {
|
||||
const char = argsString[i];
|
||||
|
||||
if (inQuote) {
|
||||
if (char === inQuote) {
|
||||
inQuote = null;
|
||||
} else {
|
||||
current += char;
|
||||
}
|
||||
} else if (char === '"' || char === "'") {
|
||||
inQuote = char;
|
||||
} else if (char === " " || char === "\t") {
|
||||
if (current) {
|
||||
args.push(current);
|
||||
current = "";
|
||||
}
|
||||
} else {
|
||||
current += char;
|
||||
}
|
||||
}
|
||||
|
||||
if (current) {
|
||||
args.push(current);
|
||||
}
|
||||
|
||||
return args;
|
||||
}
|
||||
|
||||
/**
|
||||
* Substitute argument placeholders in template content
|
||||
* Supports:
|
||||
* - $1, $2, ... for positional args
|
||||
* - $@ and $ARGUMENTS for all args
|
||||
* - ${@:N} for args from Nth onwards (bash-style slicing)
|
||||
* - ${@:N:L} for L args starting from Nth
|
||||
*
|
||||
* Note: Replacement happens on the template string only. Argument values
|
||||
* containing patterns like $1, $@, or $ARGUMENTS are NOT recursively substituted.
|
||||
*/
|
||||
export function substituteArgs(content: string, args: string[]): string {
|
||||
let result = content;
|
||||
|
||||
// Replace $1, $2, etc. with positional args FIRST (before wildcards)
|
||||
// This prevents wildcard replacement values containing $<digit> patterns from being re-substituted
|
||||
result = result.replace(/\$(\d+)/g, (_, num) => {
|
||||
const index = parseInt(num, 10) - 1;
|
||||
return args[index] ?? "";
|
||||
});
|
||||
|
||||
// Replace ${@:start} or ${@:start:length} with sliced args (bash-style)
|
||||
// Process BEFORE simple $@ to avoid conflicts
|
||||
result = result.replace(
|
||||
/\$\{@:(\d+)(?::(\d+))?\}/g,
|
||||
(_, startStr, lengthStr) => {
|
||||
let start = parseInt(startStr, 10) - 1; // Convert to 0-indexed (user provides 1-indexed)
|
||||
// Treat 0 as 1 (bash convention: args start at 1)
|
||||
if (start < 0) start = 0;
|
||||
|
||||
if (lengthStr) {
|
||||
const length = parseInt(lengthStr, 10);
|
||||
return args.slice(start, start + length).join(" ");
|
||||
}
|
||||
return args.slice(start).join(" ");
|
||||
},
|
||||
);
|
||||
|
||||
// Pre-compute all args joined (optimization)
|
||||
const allArgs = args.join(" ");
|
||||
|
||||
// Replace $ARGUMENTS with all args joined (new syntax, aligns with Claude, Codex, OpenCode)
|
||||
result = result.replace(/\$ARGUMENTS/g, allArgs);
|
||||
|
||||
// Replace $@ with all args joined (existing syntax)
|
||||
result = result.replace(/\$@/g, allArgs);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
function loadTemplateFromFile(
|
||||
filePath: string,
|
||||
source: string,
|
||||
sourceLabel: string,
|
||||
): PromptTemplate | null {
|
||||
try {
|
||||
const rawContent = readFileSync(filePath, "utf-8");
|
||||
const { frontmatter, body } =
|
||||
parseFrontmatter<Record<string, string>>(rawContent);
|
||||
|
||||
const name = basename(filePath).replace(/\.md$/, "");
|
||||
|
||||
// Get description from frontmatter or first non-empty line
|
||||
let description = frontmatter.description || "";
|
||||
if (!description) {
|
||||
const firstLine = body.split("\n").find((line) => line.trim());
|
||||
if (firstLine) {
|
||||
// Truncate if too long
|
||||
description = firstLine.slice(0, 60);
|
||||
if (firstLine.length > 60) description += "...";
|
||||
}
|
||||
}
|
||||
|
||||
// Append source to description
|
||||
description = description ? `${description} ${sourceLabel}` : sourceLabel;
|
||||
|
||||
return {
|
||||
name,
|
||||
description,
|
||||
content: body,
|
||||
source,
|
||||
filePath,
|
||||
};
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Scan a directory for .md files (non-recursive) and load them as prompt templates.
|
||||
*/
|
||||
function loadTemplatesFromDir(
|
||||
dir: string,
|
||||
source: string,
|
||||
sourceLabel: string,
|
||||
): PromptTemplate[] {
|
||||
const templates: PromptTemplate[] = [];
|
||||
|
||||
if (!existsSync(dir)) {
|
||||
return templates;
|
||||
}
|
||||
|
||||
try {
|
||||
const entries = readdirSync(dir, { withFileTypes: true });
|
||||
|
||||
for (const entry of entries) {
|
||||
const fullPath = join(dir, entry.name);
|
||||
|
||||
// For symlinks, check if they point to a file
|
||||
let isFile = entry.isFile();
|
||||
if (entry.isSymbolicLink()) {
|
||||
try {
|
||||
const stats = statSync(fullPath);
|
||||
isFile = stats.isFile();
|
||||
} catch {
|
||||
// Broken symlink, skip it
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (isFile && entry.name.endsWith(".md")) {
|
||||
const template = loadTemplateFromFile(fullPath, source, sourceLabel);
|
||||
if (template) {
|
||||
templates.push(template);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
return templates;
|
||||
}
|
||||
|
||||
return templates;
|
||||
}
|
||||
|
||||
export interface LoadPromptTemplatesOptions {
|
||||
/** Working directory for project-local templates. Default: process.cwd() */
|
||||
cwd?: string;
|
||||
/** Agent config directory for global templates. Default: from getPromptsDir() */
|
||||
agentDir?: string;
|
||||
/** Explicit prompt template paths (files or directories) */
|
||||
promptPaths?: string[];
|
||||
/** Include default prompt directories. Default: true */
|
||||
includeDefaults?: boolean;
|
||||
}
|
||||
|
||||
function normalizePath(input: string): string {
|
||||
const trimmed = input.trim();
|
||||
if (trimmed === "~") return homedir();
|
||||
if (trimmed.startsWith("~/")) return join(homedir(), trimmed.slice(2));
|
||||
if (trimmed.startsWith("~")) return join(homedir(), trimmed.slice(1));
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
function resolvePromptPath(p: string, cwd: string): string {
|
||||
const normalized = normalizePath(p);
|
||||
return isAbsolute(normalized) ? normalized : resolve(cwd, normalized);
|
||||
}
|
||||
|
||||
function buildPathSourceLabel(p: string): string {
|
||||
const base = basename(p).replace(/\.md$/, "") || "path";
|
||||
return `(path:${base})`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load all prompt templates from:
|
||||
* 1. Global: agentDir/prompts/
|
||||
* 2. Project: cwd/{CONFIG_DIR_NAME}/prompts/
|
||||
* 3. Explicit prompt paths
|
||||
*/
|
||||
export function loadPromptTemplates(
|
||||
options: LoadPromptTemplatesOptions = {},
|
||||
): PromptTemplate[] {
|
||||
const resolvedCwd = options.cwd ?? process.cwd();
|
||||
const resolvedAgentDir = options.agentDir ?? getPromptsDir();
|
||||
const promptPaths = options.promptPaths ?? [];
|
||||
const includeDefaults = options.includeDefaults ?? true;
|
||||
|
||||
const templates: PromptTemplate[] = [];
|
||||
|
||||
if (includeDefaults) {
|
||||
// 1. Load global templates from agentDir/prompts/
|
||||
// Note: if agentDir is provided, it should be the agent dir, not the prompts dir
|
||||
const globalPromptsDir = options.agentDir
|
||||
? join(options.agentDir, "prompts")
|
||||
: resolvedAgentDir;
|
||||
templates.push(...loadTemplatesFromDir(globalPromptsDir, "user", "(user)"));
|
||||
|
||||
// 2. Load project templates from cwd/{CONFIG_DIR_NAME}/prompts/
|
||||
const projectPromptsDir = resolve(resolvedCwd, CONFIG_DIR_NAME, "prompts");
|
||||
templates.push(
|
||||
...loadTemplatesFromDir(projectPromptsDir, "project", "(project)"),
|
||||
);
|
||||
}
|
||||
|
||||
const userPromptsDir = options.agentDir
|
||||
? join(options.agentDir, "prompts")
|
||||
: resolvedAgentDir;
|
||||
const projectPromptsDir = resolve(resolvedCwd, CONFIG_DIR_NAME, "prompts");
|
||||
|
||||
const isUnderPath = (target: string, root: string): boolean => {
|
||||
const normalizedRoot = resolve(root);
|
||||
if (target === normalizedRoot) {
|
||||
return true;
|
||||
}
|
||||
const prefix = normalizedRoot.endsWith(sep)
|
||||
? normalizedRoot
|
||||
: `${normalizedRoot}${sep}`;
|
||||
return target.startsWith(prefix);
|
||||
};
|
||||
|
||||
const getSourceInfo = (
|
||||
resolvedPath: string,
|
||||
): { source: string; label: string } => {
|
||||
if (!includeDefaults) {
|
||||
if (isUnderPath(resolvedPath, userPromptsDir)) {
|
||||
return { source: "user", label: "(user)" };
|
||||
}
|
||||
if (isUnderPath(resolvedPath, projectPromptsDir)) {
|
||||
return { source: "project", label: "(project)" };
|
||||
}
|
||||
}
|
||||
return { source: "path", label: buildPathSourceLabel(resolvedPath) };
|
||||
};
|
||||
|
||||
// 3. Load explicit prompt paths
|
||||
for (const rawPath of promptPaths) {
|
||||
const resolvedPath = resolvePromptPath(rawPath, resolvedCwd);
|
||||
if (!existsSync(resolvedPath)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const stats = statSync(resolvedPath);
|
||||
const { source, label } = getSourceInfo(resolvedPath);
|
||||
if (stats.isDirectory()) {
|
||||
templates.push(...loadTemplatesFromDir(resolvedPath, source, label));
|
||||
} else if (stats.isFile() && resolvedPath.endsWith(".md")) {
|
||||
const template = loadTemplateFromFile(resolvedPath, source, label);
|
||||
if (template) {
|
||||
templates.push(template);
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Ignore read failures
|
||||
}
|
||||
}
|
||||
|
||||
return templates;
|
||||
}
|
||||
|
||||
/**
|
||||
* Expand a prompt template if it matches a template name.
|
||||
* Returns the expanded content or the original text if not a template.
|
||||
*/
|
||||
export function expandPromptTemplate(
|
||||
text: string,
|
||||
templates: PromptTemplate[],
|
||||
): string {
|
||||
if (!text.startsWith("/")) return text;
|
||||
|
||||
const spaceIndex = text.indexOf(" ");
|
||||
const templateName =
|
||||
spaceIndex === -1 ? text.slice(1) : text.slice(1, spaceIndex);
|
||||
const argsString = spaceIndex === -1 ? "" : text.slice(spaceIndex + 1);
|
||||
|
||||
const template = templates.find((t) => t.name === templateName);
|
||||
if (template) {
|
||||
const args = parseCommandArgs(argsString);
|
||||
return substituteArgs(template.content, args);
|
||||
}
|
||||
|
||||
return text;
|
||||
}
|
||||
66
packages/coding-agent/src/core/resolve-config-value.ts
Normal file
66
packages/coding-agent/src/core/resolve-config-value.ts
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Resolve configuration values that may be shell commands, environment variables, or literals.
|
||||
* Used by auth-storage.ts and model-registry.ts.
|
||||
*/
|
||||
|
||||
import { execSync } from "child_process";
|
||||
|
||||
// Cache for shell command results (persists for process lifetime)
|
||||
const commandResultCache = new Map<string, string | undefined>();
|
||||
|
||||
/**
|
||||
* Resolve a config value (API key, header value, etc.) to an actual value.
|
||||
* - If starts with "!", executes the rest as a shell command and uses stdout (cached)
|
||||
* - Otherwise checks environment variable first, then treats as literal (not cached)
|
||||
*/
|
||||
export function resolveConfigValue(config: string): string | undefined {
|
||||
if (config.startsWith("!")) {
|
||||
return executeCommand(config);
|
||||
}
|
||||
const envValue = process.env[config];
|
||||
return envValue || config;
|
||||
}
|
||||
|
||||
function executeCommand(commandConfig: string): string | undefined {
|
||||
if (commandResultCache.has(commandConfig)) {
|
||||
return commandResultCache.get(commandConfig);
|
||||
}
|
||||
|
||||
const command = commandConfig.slice(1);
|
||||
let result: string | undefined;
|
||||
try {
|
||||
const output = execSync(command, {
|
||||
encoding: "utf-8",
|
||||
timeout: 10000,
|
||||
stdio: ["ignore", "pipe", "ignore"],
|
||||
});
|
||||
result = output.trim() || undefined;
|
||||
} catch {
|
||||
result = undefined;
|
||||
}
|
||||
|
||||
commandResultCache.set(commandConfig, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve all header values using the same resolution logic as API keys.
|
||||
*/
|
||||
export function resolveHeaders(
|
||||
headers: Record<string, string> | undefined,
|
||||
): Record<string, string> | undefined {
|
||||
if (!headers) return undefined;
|
||||
const resolved: Record<string, string> = {};
|
||||
for (const [key, value] of Object.entries(headers)) {
|
||||
const resolvedValue = resolveConfigValue(value);
|
||||
if (resolvedValue) {
|
||||
resolved[key] = resolvedValue;
|
||||
}
|
||||
}
|
||||
return Object.keys(resolved).length > 0 ? resolved : undefined;
|
||||
}
|
||||
|
||||
/** Clear the config value command cache. Exported for testing. */
|
||||
export function clearConfigValueCache(): void {
|
||||
commandResultCache.clear();
|
||||
}
|
||||
1094
packages/coding-agent/src/core/resource-loader.ts
Normal file
1094
packages/coding-agent/src/core/resource-loader.ts
Normal file
File diff suppressed because it is too large
Load diff
398
packages/coding-agent/src/core/sdk.ts
Normal file
398
packages/coding-agent/src/core/sdk.ts
Normal file
|
|
@ -0,0 +1,398 @@
|
|||
import { join } from "node:path";
|
||||
import {
|
||||
Agent,
|
||||
type AgentMessage,
|
||||
type ThinkingLevel,
|
||||
} from "@mariozechner/pi-agent-core";
|
||||
import type { Message, Model } from "@mariozechner/pi-ai";
|
||||
import { getAgentDir, getDocsPath } from "../config.js";
|
||||
import { AgentSession } from "./agent-session.js";
|
||||
import { AuthStorage } from "./auth-storage.js";
|
||||
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
|
||||
import type {
|
||||
ExtensionRunner,
|
||||
LoadExtensionsResult,
|
||||
ToolDefinition,
|
||||
} from "./extensions/index.js";
|
||||
import { convertToLlm } from "./messages.js";
|
||||
import { ModelRegistry } from "./model-registry.js";
|
||||
import { findInitialModel } from "./model-resolver.js";
|
||||
import type { ResourceLoader } from "./resource-loader.js";
|
||||
import { DefaultResourceLoader } from "./resource-loader.js";
|
||||
import { SessionManager } from "./session-manager.js";
|
||||
import { SettingsManager } from "./settings-manager.js";
|
||||
import { time } from "./timings.js";
|
||||
import {
|
||||
allTools,
|
||||
bashTool,
|
||||
codingTools,
|
||||
createBashTool,
|
||||
createCodingTools,
|
||||
createEditTool,
|
||||
createFindTool,
|
||||
createGrepTool,
|
||||
createLsTool,
|
||||
createReadOnlyTools,
|
||||
createReadTool,
|
||||
createWriteTool,
|
||||
editTool,
|
||||
findTool,
|
||||
grepTool,
|
||||
lsTool,
|
||||
readOnlyTools,
|
||||
readTool,
|
||||
type Tool,
|
||||
type ToolName,
|
||||
writeTool,
|
||||
} from "./tools/index.js";
|
||||
|
||||
export interface CreateAgentSessionOptions {
|
||||
/** Working directory for project-local discovery. Default: process.cwd() */
|
||||
cwd?: string;
|
||||
/** Global config directory. Default: ~/.pi/agent */
|
||||
agentDir?: string;
|
||||
|
||||
/** Auth storage for credentials. Default: AuthStorage.create(agentDir/auth.json) */
|
||||
authStorage?: AuthStorage;
|
||||
/** Model registry. Default: new ModelRegistry(authStorage, agentDir/models.json) */
|
||||
modelRegistry?: ModelRegistry;
|
||||
|
||||
/** Model to use. Default: from settings, else first available */
|
||||
model?: Model<any>;
|
||||
/** Thinking level. Default: from settings, else 'medium' (clamped to model capabilities) */
|
||||
thinkingLevel?: ThinkingLevel;
|
||||
/** Models available for cycling (Ctrl+P in interactive mode) */
|
||||
scopedModels?: Array<{ model: Model<any>; thinkingLevel?: ThinkingLevel }>;
|
||||
|
||||
/** Built-in tools to use. Default: codingTools [read, bash, edit, write] */
|
||||
tools?: Tool[];
|
||||
/** Custom tools to register (in addition to built-in tools). */
|
||||
customTools?: ToolDefinition[];
|
||||
|
||||
/** Resource loader. When omitted, DefaultResourceLoader is used. */
|
||||
resourceLoader?: ResourceLoader;
|
||||
|
||||
/** Session manager. Default: SessionManager.create(cwd) */
|
||||
sessionManager?: SessionManager;
|
||||
|
||||
/** Settings manager. Default: SettingsManager.create(cwd, agentDir) */
|
||||
settingsManager?: SettingsManager;
|
||||
}
|
||||
|
||||
/** Result from createAgentSession */
|
||||
export interface CreateAgentSessionResult {
|
||||
/** The created session */
|
||||
session: AgentSession;
|
||||
/** Extensions result (for UI context setup in interactive mode) */
|
||||
extensionsResult: LoadExtensionsResult;
|
||||
/** Warning if session was restored with a different model than saved */
|
||||
modelFallbackMessage?: string;
|
||||
}
|
||||
|
||||
// Re-exports
|
||||
|
||||
export type {
|
||||
ExtensionAPI,
|
||||
ExtensionCommandContext,
|
||||
ExtensionContext,
|
||||
ExtensionFactory,
|
||||
SlashCommandInfo,
|
||||
SlashCommandLocation,
|
||||
SlashCommandSource,
|
||||
ToolDefinition,
|
||||
} from "./extensions/index.js";
|
||||
export type { PromptTemplate } from "./prompt-templates.js";
|
||||
export type { Skill } from "./skills.js";
|
||||
export type { Tool } from "./tools/index.js";
|
||||
|
||||
export {
|
||||
// Pre-built tools (use process.cwd())
|
||||
readTool,
|
||||
bashTool,
|
||||
editTool,
|
||||
writeTool,
|
||||
grepTool,
|
||||
findTool,
|
||||
lsTool,
|
||||
codingTools,
|
||||
readOnlyTools,
|
||||
allTools as allBuiltInTools,
|
||||
// Tool factories (for custom cwd)
|
||||
createCodingTools,
|
||||
createReadOnlyTools,
|
||||
createReadTool,
|
||||
createBashTool,
|
||||
createEditTool,
|
||||
createWriteTool,
|
||||
createGrepTool,
|
||||
createFindTool,
|
||||
createLsTool,
|
||||
};
|
||||
|
||||
// Helper Functions
|
||||
|
||||
function getDefaultAgentDir(): string {
|
||||
return getAgentDir();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an AgentSession with the specified options.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* // Minimal - uses defaults
|
||||
* const { session } = await createAgentSession();
|
||||
*
|
||||
* // With explicit model
|
||||
* import { getModel } from '@mariozechner/pi-ai';
|
||||
* const { session } = await createAgentSession({
|
||||
* model: getModel('anthropic', 'claude-opus-4-5'),
|
||||
* thinkingLevel: 'high',
|
||||
* });
|
||||
*
|
||||
* // Continue previous session
|
||||
* const { session, modelFallbackMessage } = await createAgentSession({
|
||||
* continueSession: true,
|
||||
* });
|
||||
*
|
||||
* // Full control
|
||||
* const loader = new DefaultResourceLoader({
|
||||
* cwd: process.cwd(),
|
||||
* agentDir: getAgentDir(),
|
||||
* settingsManager: SettingsManager.create(),
|
||||
* });
|
||||
* await loader.reload();
|
||||
* const { session } = await createAgentSession({
|
||||
* model: myModel,
|
||||
* tools: [readTool, bashTool],
|
||||
* resourceLoader: loader,
|
||||
* sessionManager: SessionManager.inMemory(),
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
export async function createAgentSession(
|
||||
options: CreateAgentSessionOptions = {},
|
||||
): Promise<CreateAgentSessionResult> {
|
||||
const cwd = options.cwd ?? process.cwd();
|
||||
const agentDir = options.agentDir ?? getDefaultAgentDir();
|
||||
let resourceLoader = options.resourceLoader;
|
||||
|
||||
// Use provided or create AuthStorage and ModelRegistry
|
||||
const authPath = options.agentDir ? join(agentDir, "auth.json") : undefined;
|
||||
const modelsPath = options.agentDir
|
||||
? join(agentDir, "models.json")
|
||||
: undefined;
|
||||
const authStorage = options.authStorage ?? AuthStorage.create(authPath);
|
||||
const modelRegistry =
|
||||
options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath);
|
||||
|
||||
const settingsManager =
|
||||
options.settingsManager ?? SettingsManager.create(cwd, agentDir);
|
||||
const sessionManager = options.sessionManager ?? SessionManager.create(cwd);
|
||||
|
||||
if (!resourceLoader) {
|
||||
resourceLoader = new DefaultResourceLoader({
|
||||
cwd,
|
||||
agentDir,
|
||||
settingsManager,
|
||||
});
|
||||
await resourceLoader.reload();
|
||||
time("resourceLoader.reload");
|
||||
}
|
||||
|
||||
// Check if session has existing data to restore
|
||||
const existingSession = sessionManager.buildSessionContext();
|
||||
const hasExistingSession = existingSession.messages.length > 0;
|
||||
const hasThinkingEntry = sessionManager
|
||||
.getBranch()
|
||||
.some((entry) => entry.type === "thinking_level_change");
|
||||
|
||||
let model = options.model;
|
||||
let modelFallbackMessage: string | undefined;
|
||||
|
||||
// If session has data, try to restore model from it
|
||||
if (!model && hasExistingSession && existingSession.model) {
|
||||
const restoredModel = modelRegistry.find(
|
||||
existingSession.model.provider,
|
||||
existingSession.model.modelId,
|
||||
);
|
||||
if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
|
||||
model = restoredModel;
|
||||
}
|
||||
if (!model) {
|
||||
modelFallbackMessage = `Could not restore model ${existingSession.model.provider}/${existingSession.model.modelId}`;
|
||||
}
|
||||
}
|
||||
|
||||
// If still no model, use findInitialModel (checks settings default, then provider defaults)
|
||||
if (!model) {
|
||||
const result = await findInitialModel({
|
||||
scopedModels: [],
|
||||
isContinuing: hasExistingSession,
|
||||
defaultProvider: settingsManager.getDefaultProvider(),
|
||||
defaultModelId: settingsManager.getDefaultModel(),
|
||||
defaultThinkingLevel: settingsManager.getDefaultThinkingLevel(),
|
||||
modelRegistry,
|
||||
});
|
||||
model = result.model;
|
||||
if (!model) {
|
||||
modelFallbackMessage = `No models available. Use /login or set an API key environment variable. See ${join(getDocsPath(), "providers.md")}. Then use /model to select a model.`;
|
||||
} else if (modelFallbackMessage) {
|
||||
modelFallbackMessage += `. Using ${model.provider}/${model.id}`;
|
||||
}
|
||||
}
|
||||
|
||||
let thinkingLevel = options.thinkingLevel;
|
||||
|
||||
// If session has data, restore thinking level from it
|
||||
if (thinkingLevel === undefined && hasExistingSession) {
|
||||
thinkingLevel = hasThinkingEntry
|
||||
? (existingSession.thinkingLevel as ThinkingLevel)
|
||||
: (settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL);
|
||||
}
|
||||
|
||||
// Fall back to settings default
|
||||
if (thinkingLevel === undefined) {
|
||||
thinkingLevel =
|
||||
settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL;
|
||||
}
|
||||
|
||||
// Clamp to model capabilities
|
||||
if (!model || !model.reasoning) {
|
||||
thinkingLevel = "off";
|
||||
}
|
||||
|
||||
const defaultActiveToolNames: ToolName[] = ["read", "bash", "edit", "write"];
|
||||
const initialActiveToolNames: ToolName[] = options.tools
|
||||
? options.tools
|
||||
.map((t) => t.name)
|
||||
.filter((n): n is ToolName => n in allTools)
|
||||
: defaultActiveToolNames;
|
||||
|
||||
let agent: Agent;
|
||||
|
||||
// Create convertToLlm wrapper that filters images if blockImages is enabled (defense-in-depth)
|
||||
const convertToLlmWithBlockImages = (messages: AgentMessage[]): Message[] => {
|
||||
const converted = convertToLlm(messages);
|
||||
// Check setting dynamically so mid-session changes take effect
|
||||
if (!settingsManager.getBlockImages()) {
|
||||
return converted;
|
||||
}
|
||||
// Filter out ImageContent from all messages, replacing with text placeholder
|
||||
return converted.map((msg) => {
|
||||
if (msg.role === "user" || msg.role === "toolResult") {
|
||||
const content = msg.content;
|
||||
if (Array.isArray(content)) {
|
||||
const hasImages = content.some((c) => c.type === "image");
|
||||
if (hasImages) {
|
||||
const filteredContent = content
|
||||
.map((c) =>
|
||||
c.type === "image"
|
||||
? {
|
||||
type: "text" as const,
|
||||
text: "Image reading is disabled.",
|
||||
}
|
||||
: c,
|
||||
)
|
||||
.filter(
|
||||
(c, i, arr) =>
|
||||
// Dedupe consecutive "Image reading is disabled." texts
|
||||
!(
|
||||
c.type === "text" &&
|
||||
c.text === "Image reading is disabled." &&
|
||||
i > 0 &&
|
||||
arr[i - 1].type === "text" &&
|
||||
(arr[i - 1] as { type: "text"; text: string }).text ===
|
||||
"Image reading is disabled."
|
||||
),
|
||||
);
|
||||
return { ...msg, content: filteredContent };
|
||||
}
|
||||
}
|
||||
}
|
||||
return msg;
|
||||
});
|
||||
};
|
||||
|
||||
const extensionRunnerRef: { current?: ExtensionRunner } = {};
|
||||
|
||||
agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "",
|
||||
model,
|
||||
thinkingLevel,
|
||||
tools: [],
|
||||
},
|
||||
convertToLlm: convertToLlmWithBlockImages,
|
||||
sessionId: sessionManager.getSessionId(),
|
||||
transformContext: async (messages) => {
|
||||
const runner = extensionRunnerRef.current;
|
||||
if (!runner) return messages;
|
||||
return runner.emitContext(messages);
|
||||
},
|
||||
steeringMode: settingsManager.getSteeringMode(),
|
||||
followUpMode: settingsManager.getFollowUpMode(),
|
||||
transport: settingsManager.getTransport(),
|
||||
thinkingBudgets: settingsManager.getThinkingBudgets(),
|
||||
maxRetryDelayMs: settingsManager.getRetrySettings().maxDelayMs,
|
||||
getApiKey: async (provider) => {
|
||||
// Use the provider argument from the in-flight request;
|
||||
// agent.state.model may already be switched mid-turn.
|
||||
const resolvedProvider = provider || agent.state.model?.provider;
|
||||
if (!resolvedProvider) {
|
||||
throw new Error("No model selected");
|
||||
}
|
||||
const key = await modelRegistry.getApiKeyForProvider(resolvedProvider);
|
||||
if (!key) {
|
||||
const model = agent.state.model;
|
||||
const isOAuth = model && modelRegistry.isUsingOAuth(model);
|
||||
if (isOAuth) {
|
||||
throw new Error(
|
||||
`Authentication failed for "${resolvedProvider}". ` +
|
||||
`Credentials may have expired or network is unavailable. ` +
|
||||
`Run '/login ${resolvedProvider}' to re-authenticate.`,
|
||||
);
|
||||
}
|
||||
throw new Error(
|
||||
`No API key found for "${resolvedProvider}". ` +
|
||||
`Set an API key environment variable or run '/login ${resolvedProvider}'.`,
|
||||
);
|
||||
}
|
||||
return key;
|
||||
},
|
||||
});
|
||||
|
||||
// Restore messages if session has existing data
|
||||
if (hasExistingSession) {
|
||||
agent.replaceMessages(existingSession.messages);
|
||||
if (!hasThinkingEntry) {
|
||||
sessionManager.appendThinkingLevelChange(thinkingLevel);
|
||||
}
|
||||
} else {
|
||||
// Save initial model and thinking level for new sessions so they can be restored on resume
|
||||
if (model) {
|
||||
sessionManager.appendModelChange(model.provider, model.id);
|
||||
}
|
||||
sessionManager.appendThinkingLevelChange(thinkingLevel);
|
||||
}
|
||||
|
||||
const session = new AgentSession({
|
||||
agent,
|
||||
sessionManager,
|
||||
settingsManager,
|
||||
cwd,
|
||||
scopedModels: options.scopedModels,
|
||||
resourceLoader,
|
||||
customTools: options.customTools,
|
||||
modelRegistry,
|
||||
initialActiveToolNames,
|
||||
extensionRunnerRef,
|
||||
});
|
||||
const extensionsResult = resourceLoader.getExtensions();
|
||||
|
||||
return {
|
||||
session,
|
||||
extensionsResult,
|
||||
modelFallbackMessage,
|
||||
};
|
||||
}
|
||||
1514
packages/coding-agent/src/core/session-manager.ts
Normal file
1514
packages/coding-agent/src/core/session-manager.ts
Normal file
File diff suppressed because it is too large
Load diff
1057
packages/coding-agent/src/core/settings-manager.ts
Normal file
1057
packages/coding-agent/src/core/settings-manager.ts
Normal file
File diff suppressed because it is too large
Load diff
518
packages/coding-agent/src/core/skills.ts
Normal file
518
packages/coding-agent/src/core/skills.ts
Normal file
|
|
@ -0,0 +1,518 @@
|
|||
import {
|
||||
existsSync,
|
||||
readdirSync,
|
||||
readFileSync,
|
||||
realpathSync,
|
||||
statSync,
|
||||
} from "fs";
|
||||
import ignore from "ignore";
|
||||
import { homedir } from "os";
|
||||
import {
|
||||
basename,
|
||||
dirname,
|
||||
isAbsolute,
|
||||
join,
|
||||
relative,
|
||||
resolve,
|
||||
sep,
|
||||
} from "path";
|
||||
import { CONFIG_DIR_NAME, getAgentDir } from "../config.js";
|
||||
import { parseFrontmatter } from "../utils/frontmatter.js";
|
||||
import type { ResourceDiagnostic } from "./diagnostics.js";
|
||||
|
||||
/** Max name length per spec */
|
||||
const MAX_NAME_LENGTH = 64;
|
||||
|
||||
/** Max description length per spec */
|
||||
const MAX_DESCRIPTION_LENGTH = 1024;
|
||||
|
||||
const IGNORE_FILE_NAMES = [".gitignore", ".ignore", ".fdignore"];
|
||||
|
||||
type IgnoreMatcher = ReturnType<typeof ignore>;
|
||||
|
||||
function toPosixPath(p: string): string {
|
||||
return p.split(sep).join("/");
|
||||
}
|
||||
|
||||
function prefixIgnorePattern(line: string, prefix: string): string | null {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed) return null;
|
||||
if (trimmed.startsWith("#") && !trimmed.startsWith("\\#")) return null;
|
||||
|
||||
let pattern = line;
|
||||
let negated = false;
|
||||
|
||||
if (pattern.startsWith("!")) {
|
||||
negated = true;
|
||||
pattern = pattern.slice(1);
|
||||
} else if (pattern.startsWith("\\!")) {
|
||||
pattern = pattern.slice(1);
|
||||
}
|
||||
|
||||
if (pattern.startsWith("/")) {
|
||||
pattern = pattern.slice(1);
|
||||
}
|
||||
|
||||
const prefixed = prefix ? `${prefix}${pattern}` : pattern;
|
||||
return negated ? `!${prefixed}` : prefixed;
|
||||
}
|
||||
|
||||
function addIgnoreRules(ig: IgnoreMatcher, dir: string, rootDir: string): void {
|
||||
const relativeDir = relative(rootDir, dir);
|
||||
const prefix = relativeDir ? `${toPosixPath(relativeDir)}/` : "";
|
||||
|
||||
for (const filename of IGNORE_FILE_NAMES) {
|
||||
const ignorePath = join(dir, filename);
|
||||
if (!existsSync(ignorePath)) continue;
|
||||
try {
|
||||
const content = readFileSync(ignorePath, "utf-8");
|
||||
const patterns = content
|
||||
.split(/\r?\n/)
|
||||
.map((line) => prefixIgnorePattern(line, prefix))
|
||||
.filter((line): line is string => Boolean(line));
|
||||
if (patterns.length > 0) {
|
||||
ig.add(patterns);
|
||||
}
|
||||
} catch {}
|
||||
}
|
||||
}
|
||||
|
||||
export interface SkillFrontmatter {
|
||||
name?: string;
|
||||
description?: string;
|
||||
"disable-model-invocation"?: boolean;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export interface Skill {
|
||||
name: string;
|
||||
description: string;
|
||||
filePath: string;
|
||||
baseDir: string;
|
||||
source: string;
|
||||
disableModelInvocation: boolean;
|
||||
}
|
||||
|
||||
export interface LoadSkillsResult {
|
||||
skills: Skill[];
|
||||
diagnostics: ResourceDiagnostic[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate skill name per Agent Skills spec.
|
||||
* Returns array of validation error messages (empty if valid).
|
||||
*/
|
||||
function validateName(name: string, parentDirName: string): string[] {
|
||||
const errors: string[] = [];
|
||||
|
||||
if (name !== parentDirName) {
|
||||
errors.push(
|
||||
`name "${name}" does not match parent directory "${parentDirName}"`,
|
||||
);
|
||||
}
|
||||
|
||||
if (name.length > MAX_NAME_LENGTH) {
|
||||
errors.push(`name exceeds ${MAX_NAME_LENGTH} characters (${name.length})`);
|
||||
}
|
||||
|
||||
if (!/^[a-z0-9-]+$/.test(name)) {
|
||||
errors.push(
|
||||
`name contains invalid characters (must be lowercase a-z, 0-9, hyphens only)`,
|
||||
);
|
||||
}
|
||||
|
||||
if (name.startsWith("-") || name.endsWith("-")) {
|
||||
errors.push(`name must not start or end with a hyphen`);
|
||||
}
|
||||
|
||||
if (name.includes("--")) {
|
||||
errors.push(`name must not contain consecutive hyphens`);
|
||||
}
|
||||
|
||||
return errors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate description per Agent Skills spec.
|
||||
*/
|
||||
function validateDescription(description: string | undefined): string[] {
|
||||
const errors: string[] = [];
|
||||
|
||||
if (!description || description.trim() === "") {
|
||||
errors.push("description is required");
|
||||
} else if (description.length > MAX_DESCRIPTION_LENGTH) {
|
||||
errors.push(
|
||||
`description exceeds ${MAX_DESCRIPTION_LENGTH} characters (${description.length})`,
|
||||
);
|
||||
}
|
||||
|
||||
return errors;
|
||||
}
|
||||
|
||||
export interface LoadSkillsFromDirOptions {
|
||||
/** Directory to scan for skills */
|
||||
dir: string;
|
||||
/** Source identifier for these skills */
|
||||
source: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load skills from a directory.
|
||||
*
|
||||
* Discovery rules:
|
||||
* - direct .md children in the root
|
||||
* - recursive SKILL.md under subdirectories
|
||||
*/
|
||||
export function loadSkillsFromDir(
|
||||
options: LoadSkillsFromDirOptions,
|
||||
): LoadSkillsResult {
|
||||
const { dir, source } = options;
|
||||
return loadSkillsFromDirInternal(dir, source, true);
|
||||
}
|
||||
|
||||
function loadSkillsFromDirInternal(
|
||||
dir: string,
|
||||
source: string,
|
||||
includeRootFiles: boolean,
|
||||
ignoreMatcher?: IgnoreMatcher,
|
||||
rootDir?: string,
|
||||
): LoadSkillsResult {
|
||||
const skills: Skill[] = [];
|
||||
const diagnostics: ResourceDiagnostic[] = [];
|
||||
|
||||
if (!existsSync(dir)) {
|
||||
return { skills, diagnostics };
|
||||
}
|
||||
|
||||
const root = rootDir ?? dir;
|
||||
const ig = ignoreMatcher ?? ignore();
|
||||
addIgnoreRules(ig, dir, root);
|
||||
|
||||
try {
|
||||
const entries = readdirSync(dir, { withFileTypes: true });
|
||||
|
||||
for (const entry of entries) {
|
||||
if (entry.name.startsWith(".")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip node_modules to avoid scanning dependencies
|
||||
if (entry.name === "node_modules") {
|
||||
continue;
|
||||
}
|
||||
|
||||
const fullPath = join(dir, entry.name);
|
||||
|
||||
// For symlinks, check if they point to a directory and follow them
|
||||
let isDirectory = entry.isDirectory();
|
||||
let isFile = entry.isFile();
|
||||
if (entry.isSymbolicLink()) {
|
||||
try {
|
||||
const stats = statSync(fullPath);
|
||||
isDirectory = stats.isDirectory();
|
||||
isFile = stats.isFile();
|
||||
} catch {
|
||||
// Broken symlink, skip it
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const relPath = toPosixPath(relative(root, fullPath));
|
||||
const ignorePath = isDirectory ? `${relPath}/` : relPath;
|
||||
if (ig.ignores(ignorePath)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isDirectory) {
|
||||
const subResult = loadSkillsFromDirInternal(
|
||||
fullPath,
|
||||
source,
|
||||
false,
|
||||
ig,
|
||||
root,
|
||||
);
|
||||
skills.push(...subResult.skills);
|
||||
diagnostics.push(...subResult.diagnostics);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isFile) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const isRootMd = includeRootFiles && entry.name.endsWith(".md");
|
||||
const isSkillMd = !includeRootFiles && entry.name === "SKILL.md";
|
||||
if (!isRootMd && !isSkillMd) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const result = loadSkillFromFile(fullPath, source);
|
||||
if (result.skill) {
|
||||
skills.push(result.skill);
|
||||
}
|
||||
diagnostics.push(...result.diagnostics);
|
||||
}
|
||||
} catch {}
|
||||
|
||||
return { skills, diagnostics };
|
||||
}
|
||||
|
||||
function loadSkillFromFile(
|
||||
filePath: string,
|
||||
source: string,
|
||||
): { skill: Skill | null; diagnostics: ResourceDiagnostic[] } {
|
||||
const diagnostics: ResourceDiagnostic[] = [];
|
||||
|
||||
try {
|
||||
const rawContent = readFileSync(filePath, "utf-8");
|
||||
const { frontmatter } = parseFrontmatter<SkillFrontmatter>(rawContent);
|
||||
const skillDir = dirname(filePath);
|
||||
const parentDirName = basename(skillDir);
|
||||
|
||||
// Validate description
|
||||
const descErrors = validateDescription(frontmatter.description);
|
||||
for (const error of descErrors) {
|
||||
diagnostics.push({ type: "warning", message: error, path: filePath });
|
||||
}
|
||||
|
||||
// Use name from frontmatter, or fall back to parent directory name
|
||||
const name = frontmatter.name || parentDirName;
|
||||
|
||||
// Validate name
|
||||
const nameErrors = validateName(name, parentDirName);
|
||||
for (const error of nameErrors) {
|
||||
diagnostics.push({ type: "warning", message: error, path: filePath });
|
||||
}
|
||||
|
||||
// Still load the skill even with warnings (unless description is completely missing)
|
||||
if (!frontmatter.description || frontmatter.description.trim() === "") {
|
||||
return { skill: null, diagnostics };
|
||||
}
|
||||
|
||||
return {
|
||||
skill: {
|
||||
name,
|
||||
description: frontmatter.description,
|
||||
filePath,
|
||||
baseDir: skillDir,
|
||||
source,
|
||||
disableModelInvocation:
|
||||
frontmatter["disable-model-invocation"] === true,
|
||||
},
|
||||
diagnostics,
|
||||
};
|
||||
} catch (error) {
|
||||
const message =
|
||||
error instanceof Error ? error.message : "failed to parse skill file";
|
||||
diagnostics.push({ type: "warning", message, path: filePath });
|
||||
return { skill: null, diagnostics };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Format skills for inclusion in a system prompt.
|
||||
* Uses XML format per Agent Skills standard.
|
||||
* See: https://agentskills.io/integrate-skills
|
||||
*
|
||||
* Skills with disableModelInvocation=true are excluded from the prompt
|
||||
* (they can only be invoked explicitly via /skill:name commands).
|
||||
*/
|
||||
export function formatSkillsForPrompt(skills: Skill[]): string {
|
||||
const visibleSkills = skills.filter((s) => !s.disableModelInvocation);
|
||||
|
||||
if (visibleSkills.length === 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
const lines = [
|
||||
"\n\nThe following skills provide specialized instructions for specific tasks.",
|
||||
"Use the read tool to load a skill's file when the task matches its description.",
|
||||
"When a skill file references a relative path, resolve it against the skill directory (parent of SKILL.md / dirname of the path) and use that absolute path in tool commands.",
|
||||
"",
|
||||
"<available_skills>",
|
||||
];
|
||||
|
||||
for (const skill of visibleSkills) {
|
||||
lines.push(" <skill>");
|
||||
lines.push(` <name>${escapeXml(skill.name)}</name>`);
|
||||
lines.push(
|
||||
` <description>${escapeXml(skill.description)}</description>`,
|
||||
);
|
||||
lines.push(` <location>${escapeXml(skill.filePath)}</location>`);
|
||||
lines.push(" </skill>");
|
||||
}
|
||||
|
||||
lines.push("</available_skills>");
|
||||
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
function escapeXml(str: string): string {
|
||||
return str
|
||||
.replace(/&/g, "&")
|
||||
.replace(/</g, "<")
|
||||
.replace(/>/g, ">")
|
||||
.replace(/"/g, """)
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
|
||||
export interface LoadSkillsOptions {
|
||||
/** Working directory for project-local skills. Default: process.cwd() */
|
||||
cwd?: string;
|
||||
/** Agent config directory for global skills. Default: ~/.pi/agent */
|
||||
agentDir?: string;
|
||||
/** Explicit skill paths (files or directories) */
|
||||
skillPaths?: string[];
|
||||
/** Include default skills directories. Default: true */
|
||||
includeDefaults?: boolean;
|
||||
}
|
||||
|
||||
function normalizePath(input: string): string {
|
||||
const trimmed = input.trim();
|
||||
if (trimmed === "~") return homedir();
|
||||
if (trimmed.startsWith("~/")) return join(homedir(), trimmed.slice(2));
|
||||
if (trimmed.startsWith("~")) return join(homedir(), trimmed.slice(1));
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
function resolveSkillPath(p: string, cwd: string): string {
|
||||
const normalized = normalizePath(p);
|
||||
return isAbsolute(normalized) ? normalized : resolve(cwd, normalized);
|
||||
}
|
||||
|
||||
/**
|
||||
* Load skills from all configured locations.
|
||||
* Returns skills and any validation diagnostics.
|
||||
*/
|
||||
export function loadSkills(options: LoadSkillsOptions = {}): LoadSkillsResult {
|
||||
const {
|
||||
cwd = process.cwd(),
|
||||
agentDir,
|
||||
skillPaths = [],
|
||||
includeDefaults = true,
|
||||
} = options;
|
||||
|
||||
// Resolve agentDir - if not provided, use default from config
|
||||
const resolvedAgentDir = agentDir ?? getAgentDir();
|
||||
|
||||
const skillMap = new Map<string, Skill>();
|
||||
const realPathSet = new Set<string>();
|
||||
const allDiagnostics: ResourceDiagnostic[] = [];
|
||||
const collisionDiagnostics: ResourceDiagnostic[] = [];
|
||||
|
||||
function addSkills(result: LoadSkillsResult) {
|
||||
allDiagnostics.push(...result.diagnostics);
|
||||
for (const skill of result.skills) {
|
||||
// Resolve symlinks to detect duplicate files
|
||||
let realPath: string;
|
||||
try {
|
||||
realPath = realpathSync(skill.filePath);
|
||||
} catch {
|
||||
realPath = skill.filePath;
|
||||
}
|
||||
|
||||
// Skip silently if we've already loaded this exact file (via symlink)
|
||||
if (realPathSet.has(realPath)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const existing = skillMap.get(skill.name);
|
||||
if (existing) {
|
||||
collisionDiagnostics.push({
|
||||
type: "collision",
|
||||
message: `name "${skill.name}" collision`,
|
||||
path: skill.filePath,
|
||||
collision: {
|
||||
resourceType: "skill",
|
||||
name: skill.name,
|
||||
winnerPath: existing.filePath,
|
||||
loserPath: skill.filePath,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
skillMap.set(skill.name, skill);
|
||||
realPathSet.add(realPath);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (includeDefaults) {
|
||||
addSkills(
|
||||
loadSkillsFromDirInternal(join(resolvedAgentDir, "skills"), "user", true),
|
||||
);
|
||||
addSkills(
|
||||
loadSkillsFromDirInternal(
|
||||
resolve(cwd, CONFIG_DIR_NAME, "skills"),
|
||||
"project",
|
||||
true,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
const userSkillsDir = join(resolvedAgentDir, "skills");
|
||||
const projectSkillsDir = resolve(cwd, CONFIG_DIR_NAME, "skills");
|
||||
|
||||
const isUnderPath = (target: string, root: string): boolean => {
|
||||
const normalizedRoot = resolve(root);
|
||||
if (target === normalizedRoot) {
|
||||
return true;
|
||||
}
|
||||
const prefix = normalizedRoot.endsWith(sep)
|
||||
? normalizedRoot
|
||||
: `${normalizedRoot}${sep}`;
|
||||
return target.startsWith(prefix);
|
||||
};
|
||||
|
||||
const getSource = (resolvedPath: string): "user" | "project" | "path" => {
|
||||
if (!includeDefaults) {
|
||||
if (isUnderPath(resolvedPath, userSkillsDir)) return "user";
|
||||
if (isUnderPath(resolvedPath, projectSkillsDir)) return "project";
|
||||
}
|
||||
return "path";
|
||||
};
|
||||
|
||||
for (const rawPath of skillPaths) {
|
||||
const resolvedPath = resolveSkillPath(rawPath, cwd);
|
||||
if (!existsSync(resolvedPath)) {
|
||||
allDiagnostics.push({
|
||||
type: "warning",
|
||||
message: "skill path does not exist",
|
||||
path: resolvedPath,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const stats = statSync(resolvedPath);
|
||||
const source = getSource(resolvedPath);
|
||||
if (stats.isDirectory()) {
|
||||
addSkills(loadSkillsFromDirInternal(resolvedPath, source, true));
|
||||
} else if (stats.isFile() && resolvedPath.endsWith(".md")) {
|
||||
const result = loadSkillFromFile(resolvedPath, source);
|
||||
if (result.skill) {
|
||||
addSkills({
|
||||
skills: [result.skill],
|
||||
diagnostics: result.diagnostics,
|
||||
});
|
||||
} else {
|
||||
allDiagnostics.push(...result.diagnostics);
|
||||
}
|
||||
} else {
|
||||
allDiagnostics.push({
|
||||
type: "warning",
|
||||
message: "skill path is not a markdown file",
|
||||
path: resolvedPath,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
const message =
|
||||
error instanceof Error ? error.message : "failed to read skill path";
|
||||
allDiagnostics.push({ type: "warning", message, path: resolvedPath });
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
skills: Array.from(skillMap.values()),
|
||||
diagnostics: [...allDiagnostics, ...collisionDiagnostics],
|
||||
};
|
||||
}
|
||||
44
packages/coding-agent/src/core/slash-commands.ts
Normal file
44
packages/coding-agent/src/core/slash-commands.ts
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
export type SlashCommandSource = "extension" | "prompt" | "skill";
|
||||
|
||||
export type SlashCommandLocation = "user" | "project" | "path";
|
||||
|
||||
export interface SlashCommandInfo {
|
||||
name: string;
|
||||
description?: string;
|
||||
source: SlashCommandSource;
|
||||
location?: SlashCommandLocation;
|
||||
path?: string;
|
||||
}
|
||||
|
||||
export interface BuiltinSlashCommand {
|
||||
name: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export const BUILTIN_SLASH_COMMANDS: ReadonlyArray<BuiltinSlashCommand> = [
|
||||
{ name: "settings", description: "Open settings menu" },
|
||||
{ name: "model", description: "Select model (opens selector UI)" },
|
||||
{
|
||||
name: "scoped-models",
|
||||
description: "Enable/disable models for Ctrl+P cycling",
|
||||
},
|
||||
{ name: "export", description: "Export session to HTML file" },
|
||||
{ name: "share", description: "Share session as a secret GitHub gist" },
|
||||
{ name: "copy", description: "Copy last agent message to clipboard" },
|
||||
{ name: "name", description: "Set session display name" },
|
||||
{ name: "session", description: "Show session info and stats" },
|
||||
{ name: "changelog", description: "Show changelog entries" },
|
||||
{ name: "hotkeys", description: "Show all keyboard shortcuts" },
|
||||
{ name: "fork", description: "Create a new fork from a previous message" },
|
||||
{ name: "tree", description: "Navigate session tree (switch branches)" },
|
||||
{ name: "login", description: "Login with OAuth provider" },
|
||||
{ name: "logout", description: "Logout from OAuth provider" },
|
||||
{ name: "new", description: "Start a new session" },
|
||||
{ name: "compact", description: "Manually compact the session context" },
|
||||
{ name: "resume", description: "Resume a different session" },
|
||||
{
|
||||
name: "reload",
|
||||
description: "Reload extensions, skills, prompts, and themes",
|
||||
},
|
||||
{ name: "quit", description: "Quit pi" },
|
||||
];
|
||||
237
packages/coding-agent/src/core/system-prompt.ts
Normal file
237
packages/coding-agent/src/core/system-prompt.ts
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
/**
|
||||
* System prompt construction and project context loading
|
||||
*/
|
||||
|
||||
import { getDocsPath, getReadmePath } from "../config.js";
|
||||
import { formatSkillsForPrompt, type Skill } from "./skills.js";
|
||||
|
||||
/** Tool descriptions for system prompt */
|
||||
const toolDescriptions: Record<string, string> = {
|
||||
read: "Read file contents",
|
||||
bash: "Execute bash commands (ls, grep, find, etc.)",
|
||||
edit: "Make surgical edits to files (find exact text and replace)",
|
||||
write: "Create or overwrite files",
|
||||
grep: "Search file contents for patterns (respects .gitignore)",
|
||||
find: "Find files by glob pattern (respects .gitignore)",
|
||||
ls: "List directory contents",
|
||||
};
|
||||
|
||||
export interface BuildSystemPromptOptions {
|
||||
/** Custom system prompt (replaces default). */
|
||||
customPrompt?: string;
|
||||
/** Tools to include in prompt. Default: [read, bash, edit, write] */
|
||||
selectedTools?: string[];
|
||||
/** Optional one-line tool snippets keyed by tool name. */
|
||||
toolSnippets?: Record<string, string>;
|
||||
/** Additional guideline bullets appended to the default system prompt guidelines. */
|
||||
promptGuidelines?: string[];
|
||||
/** Text to append to system prompt. */
|
||||
appendSystemPrompt?: string;
|
||||
/** Working directory. Default: process.cwd() */
|
||||
cwd?: string;
|
||||
/** Pre-loaded context files. */
|
||||
contextFiles?: Array<{ path: string; content: string }>;
|
||||
/** Pre-loaded skills. */
|
||||
skills?: Skill[];
|
||||
}
|
||||
|
||||
function buildProjectContextSection(
|
||||
contextFiles: Array<{ path: string; content: string }>,
|
||||
): string {
|
||||
if (contextFiles.length === 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
const hasSoulFile = contextFiles.some(
|
||||
({ path }) =>
|
||||
path.replaceAll("\\", "/").endsWith("/SOUL.md") || path === "SOUL.md",
|
||||
);
|
||||
let section = "\n\n# Project Context\n\n";
|
||||
section += "Project-specific instructions and guidelines:\n";
|
||||
if (hasSoulFile) {
|
||||
section +=
|
||||
"\nIf SOUL.md is present, embody its persona and tone. Avoid generic assistant filler and follow its guidance unless higher-priority instructions override it.\n";
|
||||
}
|
||||
section += "\n";
|
||||
for (const { path: filePath, content } of contextFiles) {
|
||||
section += `## ${filePath}\n\n${content}\n\n`;
|
||||
}
|
||||
|
||||
return section;
|
||||
}
|
||||
|
||||
/** Build the system prompt with tools, guidelines, and context */
|
||||
export function buildSystemPrompt(
|
||||
options: BuildSystemPromptOptions = {},
|
||||
): string {
|
||||
const {
|
||||
customPrompt,
|
||||
selectedTools,
|
||||
toolSnippets,
|
||||
promptGuidelines,
|
||||
appendSystemPrompt,
|
||||
cwd,
|
||||
contextFiles: providedContextFiles,
|
||||
skills: providedSkills,
|
||||
} = options;
|
||||
const resolvedCwd = cwd ?? process.cwd();
|
||||
|
||||
const now = new Date();
|
||||
const dateTime = now.toLocaleString("en-US", {
|
||||
weekday: "long",
|
||||
year: "numeric",
|
||||
month: "long",
|
||||
day: "numeric",
|
||||
hour: "2-digit",
|
||||
minute: "2-digit",
|
||||
second: "2-digit",
|
||||
timeZoneName: "short",
|
||||
});
|
||||
|
||||
const appendSection = appendSystemPrompt ? `\n\n${appendSystemPrompt}` : "";
|
||||
|
||||
const contextFiles = providedContextFiles ?? [];
|
||||
const skills = providedSkills ?? [];
|
||||
|
||||
if (customPrompt) {
|
||||
let prompt = customPrompt;
|
||||
|
||||
if (appendSection) {
|
||||
prompt += appendSection;
|
||||
}
|
||||
|
||||
// Append project context files
|
||||
prompt += buildProjectContextSection(contextFiles);
|
||||
|
||||
// Append skills section (only if read tool is available)
|
||||
const customPromptHasRead =
|
||||
!selectedTools || selectedTools.includes("read");
|
||||
if (customPromptHasRead && skills.length > 0) {
|
||||
prompt += formatSkillsForPrompt(skills);
|
||||
}
|
||||
|
||||
// Add date/time and working directory last
|
||||
prompt += `\nCurrent date and time: ${dateTime}`;
|
||||
prompt += `\nCurrent working directory: ${resolvedCwd}`;
|
||||
|
||||
return prompt;
|
||||
}
|
||||
|
||||
// Get absolute paths to documentation
|
||||
const readmePath = getReadmePath();
|
||||
const docsPath = getDocsPath();
|
||||
|
||||
// Build tools list based on selected tools.
|
||||
// Built-ins use toolDescriptions. Custom tools can provide one-line snippets.
|
||||
const tools = selectedTools || ["read", "bash", "edit", "write"];
|
||||
const toolsList =
|
||||
tools.length > 0
|
||||
? tools
|
||||
.map((name) => {
|
||||
const snippet =
|
||||
toolSnippets?.[name] ?? toolDescriptions[name] ?? name;
|
||||
return `- ${name}: ${snippet}`;
|
||||
})
|
||||
.join("\n")
|
||||
: "(none)";
|
||||
|
||||
// Build guidelines based on which tools are actually available
|
||||
const guidelinesList: string[] = [];
|
||||
const guidelinesSet = new Set<string>();
|
||||
const addGuideline = (guideline: string): void => {
|
||||
if (guidelinesSet.has(guideline)) {
|
||||
return;
|
||||
}
|
||||
guidelinesSet.add(guideline);
|
||||
guidelinesList.push(guideline);
|
||||
};
|
||||
|
||||
const hasBash = tools.includes("bash");
|
||||
const hasEdit = tools.includes("edit");
|
||||
const hasWrite = tools.includes("write");
|
||||
const hasGrep = tools.includes("grep");
|
||||
const hasFind = tools.includes("find");
|
||||
const hasLs = tools.includes("ls");
|
||||
const hasRead = tools.includes("read");
|
||||
|
||||
// File exploration guidelines
|
||||
if (hasBash && !hasGrep && !hasFind && !hasLs) {
|
||||
addGuideline("Use bash for file operations like ls, rg, find");
|
||||
} else if (hasBash && (hasGrep || hasFind || hasLs)) {
|
||||
addGuideline(
|
||||
"Prefer grep/find/ls tools over bash for file exploration (faster, respects .gitignore)",
|
||||
);
|
||||
}
|
||||
|
||||
// Read before edit guideline
|
||||
if (hasRead && hasEdit) {
|
||||
addGuideline(
|
||||
"Use read to examine files before editing. You must use this tool instead of cat or sed.",
|
||||
);
|
||||
}
|
||||
|
||||
// Edit guideline
|
||||
if (hasEdit) {
|
||||
addGuideline("Use edit for precise changes (old text must match exactly)");
|
||||
}
|
||||
|
||||
// Write guideline
|
||||
if (hasWrite) {
|
||||
addGuideline("Use write only for new files or complete rewrites");
|
||||
}
|
||||
|
||||
// Output guideline (only when actually writing or executing)
|
||||
if (hasEdit || hasWrite) {
|
||||
addGuideline(
|
||||
"When summarizing your actions, output plain text directly - do NOT use cat or bash to display what you did",
|
||||
);
|
||||
}
|
||||
|
||||
for (const guideline of promptGuidelines ?? []) {
|
||||
const normalized = guideline.trim();
|
||||
if (normalized.length > 0) {
|
||||
addGuideline(normalized);
|
||||
}
|
||||
}
|
||||
|
||||
// Always include these
|
||||
addGuideline("Be concise in your responses");
|
||||
addGuideline("Show file paths clearly when working with files");
|
||||
|
||||
const guidelines = guidelinesList.map((g) => `- ${g}`).join("\n");
|
||||
|
||||
let prompt = `You are an expert coding assistant operating inside pi, a coding agent harness. You help users by reading files, executing commands, editing code, and writing new files.
|
||||
|
||||
Available tools:
|
||||
${toolsList}
|
||||
|
||||
In addition to the tools above, you may have access to other custom tools depending on the project.
|
||||
|
||||
Guidelines:
|
||||
${guidelines}
|
||||
|
||||
Pi documentation (read only when the user asks about pi itself, its SDK, extensions, themes, skills, or TUI):
|
||||
- Main documentation: ${readmePath}
|
||||
- Additional docs: ${docsPath}
|
||||
- When asked about: extensions (docs/extensions.md), themes (docs/themes.md), skills (docs/skills.md), prompt templates (docs/prompt-templates.md), TUI components (docs/tui.md), keybindings (docs/keybindings.md), SDK integrations (docs/sdk.md), custom providers (docs/custom-provider.md), adding models (docs/models.md), pi packages (docs/packages.md)
|
||||
- When working on pi topics, read the docs and follow .md cross-references before implementing
|
||||
- Always read pi .md files completely and follow links to related docs (e.g., tui.md for TUI API details)`;
|
||||
|
||||
if (appendSection) {
|
||||
prompt += appendSection;
|
||||
}
|
||||
|
||||
// Append project context files
|
||||
prompt += buildProjectContextSection(contextFiles);
|
||||
|
||||
// Append skills section (only if read tool is available)
|
||||
if (hasRead && skills.length > 0) {
|
||||
prompt += formatSkillsForPrompt(skills);
|
||||
}
|
||||
|
||||
// Add date/time and working directory last
|
||||
prompt += `\nCurrent date and time: ${dateTime}`;
|
||||
prompt += `\nCurrent working directory: ${resolvedCwd}`;
|
||||
|
||||
return prompt;
|
||||
}
|
||||
25
packages/coding-agent/src/core/timings.ts
Normal file
25
packages/coding-agent/src/core/timings.ts
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Central timing instrumentation for startup profiling.
|
||||
* Enable with PI_TIMING=1 environment variable.
|
||||
*/
|
||||
|
||||
const ENABLED = process.env.PI_TIMING === "1";
|
||||
const timings: Array<{ label: string; ms: number }> = [];
|
||||
let lastTime = Date.now();
|
||||
|
||||
export function time(label: string): void {
|
||||
if (!ENABLED) return;
|
||||
const now = Date.now();
|
||||
timings.push({ label, ms: now - lastTime });
|
||||
lastTime = now;
|
||||
}
|
||||
|
||||
export function printTimings(): void {
|
||||
if (!ENABLED || timings.length === 0) return;
|
||||
console.error("\n--- Startup Timings ---");
|
||||
for (const t of timings) {
|
||||
console.error(` ${t.label}: ${t.ms}ms`);
|
||||
}
|
||||
console.error(` TOTAL: ${timings.reduce((a, b) => a + b.ms, 0)}ms`);
|
||||
console.error("------------------------\n");
|
||||
}
|
||||
358
packages/coding-agent/src/core/tools/bash.ts
Normal file
358
packages/coding-agent/src/core/tools/bash.ts
Normal file
|
|
@ -0,0 +1,358 @@
|
|||
import { randomBytes } from "node:crypto";
|
||||
import { createWriteStream, existsSync } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { spawn } from "child_process";
|
||||
import {
|
||||
getShellConfig,
|
||||
getShellEnv,
|
||||
killProcessTree,
|
||||
} from "../../utils/shell.js";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
formatSize,
|
||||
type TruncationResult,
|
||||
truncateTail,
|
||||
} from "./truncate.js";
|
||||
|
||||
/**
|
||||
* Generate a unique temp file path for bash output
|
||||
*/
|
||||
function getTempFilePath(): string {
|
||||
const id = randomBytes(8).toString("hex");
|
||||
return join(tmpdir(), `pi-bash-${id}.log`);
|
||||
}
|
||||
|
||||
const bashSchema = Type.Object({
|
||||
command: Type.String({ description: "Bash command to execute" }),
|
||||
timeout: Type.Optional(
|
||||
Type.Number({
|
||||
description: "Timeout in seconds (optional, no default timeout)",
|
||||
}),
|
||||
),
|
||||
});
|
||||
|
||||
export type BashToolInput = Static<typeof bashSchema>;
|
||||
|
||||
export interface BashToolDetails {
|
||||
truncation?: TruncationResult;
|
||||
fullOutputPath?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pluggable operations for the bash tool.
|
||||
* Override these to delegate command execution to remote systems (e.g., SSH).
|
||||
*/
|
||||
export interface BashOperations {
|
||||
/**
|
||||
* Execute a command and stream output.
|
||||
* @param command - The command to execute
|
||||
* @param cwd - Working directory
|
||||
* @param options - Execution options
|
||||
* @returns Promise resolving to exit code (null if killed)
|
||||
*/
|
||||
exec: (
|
||||
command: string,
|
||||
cwd: string,
|
||||
options: {
|
||||
onData: (data: Buffer) => void;
|
||||
signal?: AbortSignal;
|
||||
timeout?: number;
|
||||
env?: NodeJS.ProcessEnv;
|
||||
},
|
||||
) => Promise<{ exitCode: number | null }>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Default bash operations using local shell
|
||||
*/
|
||||
const defaultBashOperations: BashOperations = {
|
||||
exec: (command, cwd, { onData, signal, timeout, env }) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
const { shell, args } = getShellConfig();
|
||||
|
||||
if (!existsSync(cwd)) {
|
||||
reject(
|
||||
new Error(
|
||||
`Working directory does not exist: ${cwd}\nCannot execute bash commands.`,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const child = spawn(shell, [...args, command], {
|
||||
cwd,
|
||||
detached: true,
|
||||
env: env ?? getShellEnv(),
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
});
|
||||
|
||||
let timedOut = false;
|
||||
|
||||
// Set timeout if provided
|
||||
let timeoutHandle: NodeJS.Timeout | undefined;
|
||||
if (timeout !== undefined && timeout > 0) {
|
||||
timeoutHandle = setTimeout(() => {
|
||||
timedOut = true;
|
||||
if (child.pid) {
|
||||
killProcessTree(child.pid);
|
||||
}
|
||||
}, timeout * 1000);
|
||||
}
|
||||
|
||||
// Stream stdout and stderr
|
||||
if (child.stdout) {
|
||||
child.stdout.on("data", onData);
|
||||
}
|
||||
if (child.stderr) {
|
||||
child.stderr.on("data", onData);
|
||||
}
|
||||
|
||||
// Handle shell spawn errors
|
||||
child.on("error", (err) => {
|
||||
if (timeoutHandle) clearTimeout(timeoutHandle);
|
||||
if (signal) signal.removeEventListener("abort", onAbort);
|
||||
reject(err);
|
||||
});
|
||||
|
||||
// Handle abort signal - kill entire process tree
|
||||
const onAbort = () => {
|
||||
if (child.pid) {
|
||||
killProcessTree(child.pid);
|
||||
}
|
||||
};
|
||||
|
||||
if (signal) {
|
||||
if (signal.aborted) {
|
||||
onAbort();
|
||||
} else {
|
||||
signal.addEventListener("abort", onAbort, { once: true });
|
||||
}
|
||||
}
|
||||
|
||||
// Handle process exit
|
||||
child.on("close", (code) => {
|
||||
if (timeoutHandle) clearTimeout(timeoutHandle);
|
||||
if (signal) signal.removeEventListener("abort", onAbort);
|
||||
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("aborted"));
|
||||
return;
|
||||
}
|
||||
|
||||
if (timedOut) {
|
||||
reject(new Error(`timeout:${timeout}`));
|
||||
return;
|
||||
}
|
||||
|
||||
resolve({ exitCode: code });
|
||||
});
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
export interface BashSpawnContext {
|
||||
command: string;
|
||||
cwd: string;
|
||||
env: NodeJS.ProcessEnv;
|
||||
}
|
||||
|
||||
export type BashSpawnHook = (context: BashSpawnContext) => BashSpawnContext;
|
||||
|
||||
function resolveSpawnContext(
|
||||
command: string,
|
||||
cwd: string,
|
||||
spawnHook?: BashSpawnHook,
|
||||
): BashSpawnContext {
|
||||
const baseContext: BashSpawnContext = {
|
||||
command,
|
||||
cwd,
|
||||
env: { ...getShellEnv() },
|
||||
};
|
||||
|
||||
return spawnHook ? spawnHook(baseContext) : baseContext;
|
||||
}
|
||||
|
||||
export interface BashToolOptions {
|
||||
/** Custom operations for command execution. Default: local shell */
|
||||
operations?: BashOperations;
|
||||
/** Command prefix prepended to every command (e.g., "shopt -s expand_aliases" for alias support) */
|
||||
commandPrefix?: string;
|
||||
/** Hook to adjust command, cwd, or env before execution */
|
||||
spawnHook?: BashSpawnHook;
|
||||
}
|
||||
|
||||
export function createBashTool(
|
||||
cwd: string,
|
||||
options?: BashToolOptions,
|
||||
): AgentTool<typeof bashSchema> {
|
||||
const ops = options?.operations ?? defaultBashOperations;
|
||||
const commandPrefix = options?.commandPrefix;
|
||||
const spawnHook = options?.spawnHook;
|
||||
|
||||
return {
|
||||
name: "bash",
|
||||
label: "bash",
|
||||
description: `Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last ${DEFAULT_MAX_LINES} lines or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file. Optionally provide a timeout in seconds.`,
|
||||
parameters: bashSchema,
|
||||
execute: async (
|
||||
_toolCallId: string,
|
||||
{ command, timeout }: { command: string; timeout?: number },
|
||||
signal?: AbortSignal,
|
||||
onUpdate?,
|
||||
) => {
|
||||
// Apply command prefix if configured (e.g., "shopt -s expand_aliases" for alias support)
|
||||
const resolvedCommand = commandPrefix
|
||||
? `${commandPrefix}\n${command}`
|
||||
: command;
|
||||
const spawnContext = resolveSpawnContext(resolvedCommand, cwd, spawnHook);
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
// We'll stream to a temp file if output gets large
|
||||
let tempFilePath: string | undefined;
|
||||
let tempFileStream: ReturnType<typeof createWriteStream> | undefined;
|
||||
let totalBytes = 0;
|
||||
|
||||
// Keep a rolling buffer of the last chunk for tail truncation
|
||||
const chunks: Buffer[] = [];
|
||||
let chunksBytes = 0;
|
||||
// Keep more than we need so we have enough for truncation
|
||||
const maxChunksBytes = DEFAULT_MAX_BYTES * 2;
|
||||
|
||||
const handleData = (data: Buffer) => {
|
||||
totalBytes += data.length;
|
||||
|
||||
// Start writing to temp file once we exceed the threshold
|
||||
if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
|
||||
tempFilePath = getTempFilePath();
|
||||
tempFileStream = createWriteStream(tempFilePath);
|
||||
// Write all buffered chunks to the file
|
||||
for (const chunk of chunks) {
|
||||
tempFileStream.write(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
// Write to temp file if we have one
|
||||
if (tempFileStream) {
|
||||
tempFileStream.write(data);
|
||||
}
|
||||
|
||||
// Keep rolling buffer of recent data
|
||||
chunks.push(data);
|
||||
chunksBytes += data.length;
|
||||
|
||||
// Trim old chunks if buffer is too large
|
||||
while (chunksBytes > maxChunksBytes && chunks.length > 1) {
|
||||
const removed = chunks.shift()!;
|
||||
chunksBytes -= removed.length;
|
||||
}
|
||||
|
||||
// Stream partial output to callback (truncated rolling buffer)
|
||||
if (onUpdate) {
|
||||
const fullBuffer = Buffer.concat(chunks);
|
||||
const fullText = fullBuffer.toString("utf-8");
|
||||
const truncation = truncateTail(fullText);
|
||||
onUpdate({
|
||||
content: [{ type: "text", text: truncation.content || "" }],
|
||||
details: {
|
||||
truncation: truncation.truncated ? truncation : undefined,
|
||||
fullOutputPath: tempFilePath,
|
||||
},
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
ops
|
||||
.exec(spawnContext.command, spawnContext.cwd, {
|
||||
onData: handleData,
|
||||
signal,
|
||||
timeout,
|
||||
env: spawnContext.env,
|
||||
})
|
||||
.then(({ exitCode }) => {
|
||||
// Close temp file stream
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
// Combine all buffered chunks
|
||||
const fullBuffer = Buffer.concat(chunks);
|
||||
const fullOutput = fullBuffer.toString("utf-8");
|
||||
|
||||
// Apply tail truncation
|
||||
const truncation = truncateTail(fullOutput);
|
||||
let outputText = truncation.content || "(no output)";
|
||||
|
||||
// Build details with truncation info
|
||||
let details: BashToolDetails | undefined;
|
||||
|
||||
if (truncation.truncated) {
|
||||
details = {
|
||||
truncation,
|
||||
fullOutputPath: tempFilePath,
|
||||
};
|
||||
|
||||
// Build actionable notice
|
||||
const startLine =
|
||||
truncation.totalLines - truncation.outputLines + 1;
|
||||
const endLine = truncation.totalLines;
|
||||
|
||||
if (truncation.lastLinePartial) {
|
||||
// Edge case: last line alone > 30KB
|
||||
const lastLineSize = formatSize(
|
||||
Buffer.byteLength(
|
||||
fullOutput.split("\n").pop() || "",
|
||||
"utf-8",
|
||||
),
|
||||
);
|
||||
outputText += `\n\n[Showing last ${formatSize(truncation.outputBytes)} of line ${endLine} (line is ${lastLineSize}). Full output: ${tempFilePath}]`;
|
||||
} else if (truncation.truncatedBy === "lines") {
|
||||
outputText += `\n\n[Showing lines ${startLine}-${endLine} of ${truncation.totalLines}. Full output: ${tempFilePath}]`;
|
||||
} else {
|
||||
outputText += `\n\n[Showing lines ${startLine}-${endLine} of ${truncation.totalLines} (${formatSize(DEFAULT_MAX_BYTES)} limit). Full output: ${tempFilePath}]`;
|
||||
}
|
||||
}
|
||||
|
||||
if (exitCode !== 0 && exitCode !== null) {
|
||||
outputText += `\n\nCommand exited with code ${exitCode}`;
|
||||
reject(new Error(outputText));
|
||||
} else {
|
||||
resolve({
|
||||
content: [{ type: "text", text: outputText }],
|
||||
details,
|
||||
});
|
||||
}
|
||||
})
|
||||
.catch((err: Error) => {
|
||||
// Close temp file stream
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
// Combine all buffered chunks for error output
|
||||
const fullBuffer = Buffer.concat(chunks);
|
||||
let output = fullBuffer.toString("utf-8");
|
||||
|
||||
if (err.message === "aborted") {
|
||||
if (output) output += "\n\n";
|
||||
output += "Command aborted";
|
||||
reject(new Error(output));
|
||||
} else if (err.message.startsWith("timeout:")) {
|
||||
const timeoutSecs = err.message.split(":")[1];
|
||||
if (output) output += "\n\n";
|
||||
output += `Command timed out after ${timeoutSecs} seconds`;
|
||||
reject(new Error(output));
|
||||
} else {
|
||||
reject(err);
|
||||
}
|
||||
});
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/** Default bash tool using process.cwd() - for backwards compatibility */
|
||||
export const bashTool = createBashTool(process.cwd());
|
||||
317
packages/coding-agent/src/core/tools/edit-diff.ts
Normal file
317
packages/coding-agent/src/core/tools/edit-diff.ts
Normal file
|
|
@ -0,0 +1,317 @@
|
|||
/**
|
||||
* Shared diff computation utilities for the edit tool.
|
||||
* Used by both edit.ts (for execution) and tool-execution.ts (for preview rendering).
|
||||
*/
|
||||
|
||||
import * as Diff from "diff";
|
||||
import { constants } from "fs";
|
||||
import { access, readFile } from "fs/promises";
|
||||
import { resolveToCwd } from "./path-utils.js";
|
||||
|
||||
export function detectLineEnding(content: string): "\r\n" | "\n" {
|
||||
const crlfIdx = content.indexOf("\r\n");
|
||||
const lfIdx = content.indexOf("\n");
|
||||
if (lfIdx === -1) return "\n";
|
||||
if (crlfIdx === -1) return "\n";
|
||||
return crlfIdx < lfIdx ? "\r\n" : "\n";
|
||||
}
|
||||
|
||||
export function normalizeToLF(text: string): string {
|
||||
return text.replace(/\r\n/g, "\n").replace(/\r/g, "\n");
|
||||
}
|
||||
|
||||
export function restoreLineEndings(
|
||||
text: string,
|
||||
ending: "\r\n" | "\n",
|
||||
): string {
|
||||
return ending === "\r\n" ? text.replace(/\n/g, "\r\n") : text;
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize text for fuzzy matching. Applies progressive transformations:
|
||||
* - Strip trailing whitespace from each line
|
||||
* - Normalize smart quotes to ASCII equivalents
|
||||
* - Normalize Unicode dashes/hyphens to ASCII hyphen
|
||||
* - Normalize special Unicode spaces to regular space
|
||||
*/
|
||||
export function normalizeForFuzzyMatch(text: string): string {
|
||||
return (
|
||||
text
|
||||
// Strip trailing whitespace per line
|
||||
.split("\n")
|
||||
.map((line) => line.trimEnd())
|
||||
.join("\n")
|
||||
// Smart single quotes → '
|
||||
.replace(/[\u2018\u2019\u201A\u201B]/g, "'")
|
||||
// Smart double quotes → "
|
||||
.replace(/[\u201C\u201D\u201E\u201F]/g, '"')
|
||||
// Various dashes/hyphens → -
|
||||
// U+2010 hyphen, U+2011 non-breaking hyphen, U+2012 figure dash,
|
||||
// U+2013 en-dash, U+2014 em-dash, U+2015 horizontal bar, U+2212 minus
|
||||
.replace(/[\u2010\u2011\u2012\u2013\u2014\u2015\u2212]/g, "-")
|
||||
// Special spaces → regular space
|
||||
// U+00A0 NBSP, U+2002-U+200A various spaces, U+202F narrow NBSP,
|
||||
// U+205F medium math space, U+3000 ideographic space
|
||||
.replace(/[\u00A0\u2002-\u200A\u202F\u205F\u3000]/g, " ")
|
||||
);
|
||||
}
|
||||
|
||||
export interface FuzzyMatchResult {
|
||||
/** Whether a match was found */
|
||||
found: boolean;
|
||||
/** The index where the match starts (in the content that should be used for replacement) */
|
||||
index: number;
|
||||
/** Length of the matched text */
|
||||
matchLength: number;
|
||||
/** Whether fuzzy matching was used (false = exact match) */
|
||||
usedFuzzyMatch: boolean;
|
||||
/**
|
||||
* The content to use for replacement operations.
|
||||
* When exact match: original content. When fuzzy match: normalized content.
|
||||
*/
|
||||
contentForReplacement: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find oldText in content, trying exact match first, then fuzzy match.
|
||||
* When fuzzy matching is used, the returned contentForReplacement is the
|
||||
* fuzzy-normalized version of the content (trailing whitespace stripped,
|
||||
* Unicode quotes/dashes normalized to ASCII).
|
||||
*/
|
||||
export function fuzzyFindText(
|
||||
content: string,
|
||||
oldText: string,
|
||||
): FuzzyMatchResult {
|
||||
// Try exact match first
|
||||
const exactIndex = content.indexOf(oldText);
|
||||
if (exactIndex !== -1) {
|
||||
return {
|
||||
found: true,
|
||||
index: exactIndex,
|
||||
matchLength: oldText.length,
|
||||
usedFuzzyMatch: false,
|
||||
contentForReplacement: content,
|
||||
};
|
||||
}
|
||||
|
||||
// Try fuzzy match - work entirely in normalized space
|
||||
const fuzzyContent = normalizeForFuzzyMatch(content);
|
||||
const fuzzyOldText = normalizeForFuzzyMatch(oldText);
|
||||
const fuzzyIndex = fuzzyContent.indexOf(fuzzyOldText);
|
||||
|
||||
if (fuzzyIndex === -1) {
|
||||
return {
|
||||
found: false,
|
||||
index: -1,
|
||||
matchLength: 0,
|
||||
usedFuzzyMatch: false,
|
||||
contentForReplacement: content,
|
||||
};
|
||||
}
|
||||
|
||||
// When fuzzy matching, we work in the normalized space for replacement.
|
||||
// This means the output will have normalized whitespace/quotes/dashes,
|
||||
// which is acceptable since we're fixing minor formatting differences anyway.
|
||||
return {
|
||||
found: true,
|
||||
index: fuzzyIndex,
|
||||
matchLength: fuzzyOldText.length,
|
||||
usedFuzzyMatch: true,
|
||||
contentForReplacement: fuzzyContent,
|
||||
};
|
||||
}
|
||||
|
||||
/** Strip UTF-8 BOM if present, return both the BOM (if any) and the text without it */
|
||||
export function stripBom(content: string): { bom: string; text: string } {
|
||||
return content.startsWith("\uFEFF")
|
||||
? { bom: "\uFEFF", text: content.slice(1) }
|
||||
: { bom: "", text: content };
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a unified diff string with line numbers and context.
|
||||
* Returns both the diff string and the first changed line number (in the new file).
|
||||
*/
|
||||
export function generateDiffString(
|
||||
oldContent: string,
|
||||
newContent: string,
|
||||
contextLines = 4,
|
||||
): { diff: string; firstChangedLine: number | undefined } {
|
||||
const parts = Diff.diffLines(oldContent, newContent);
|
||||
const output: string[] = [];
|
||||
|
||||
const oldLines = oldContent.split("\n");
|
||||
const newLines = newContent.split("\n");
|
||||
const maxLineNum = Math.max(oldLines.length, newLines.length);
|
||||
const lineNumWidth = String(maxLineNum).length;
|
||||
|
||||
let oldLineNum = 1;
|
||||
let newLineNum = 1;
|
||||
let lastWasChange = false;
|
||||
let firstChangedLine: number | undefined;
|
||||
|
||||
for (let i = 0; i < parts.length; i++) {
|
||||
const part = parts[i];
|
||||
const raw = part.value.split("\n");
|
||||
if (raw[raw.length - 1] === "") {
|
||||
raw.pop();
|
||||
}
|
||||
|
||||
if (part.added || part.removed) {
|
||||
// Capture the first changed line (in the new file)
|
||||
if (firstChangedLine === undefined) {
|
||||
firstChangedLine = newLineNum;
|
||||
}
|
||||
|
||||
// Show the change
|
||||
for (const line of raw) {
|
||||
if (part.added) {
|
||||
const lineNum = String(newLineNum).padStart(lineNumWidth, " ");
|
||||
output.push(`+${lineNum} ${line}`);
|
||||
newLineNum++;
|
||||
} else {
|
||||
// removed
|
||||
const lineNum = String(oldLineNum).padStart(lineNumWidth, " ");
|
||||
output.push(`-${lineNum} ${line}`);
|
||||
oldLineNum++;
|
||||
}
|
||||
}
|
||||
lastWasChange = true;
|
||||
} else {
|
||||
// Context lines - only show a few before/after changes
|
||||
const nextPartIsChange =
|
||||
i < parts.length - 1 && (parts[i + 1].added || parts[i + 1].removed);
|
||||
|
||||
if (lastWasChange || nextPartIsChange) {
|
||||
// Show context
|
||||
let linesToShow = raw;
|
||||
let skipStart = 0;
|
||||
let skipEnd = 0;
|
||||
|
||||
if (!lastWasChange) {
|
||||
// Show only last N lines as leading context
|
||||
skipStart = Math.max(0, raw.length - contextLines);
|
||||
linesToShow = raw.slice(skipStart);
|
||||
}
|
||||
|
||||
if (!nextPartIsChange && linesToShow.length > contextLines) {
|
||||
// Show only first N lines as trailing context
|
||||
skipEnd = linesToShow.length - contextLines;
|
||||
linesToShow = linesToShow.slice(0, contextLines);
|
||||
}
|
||||
|
||||
// Add ellipsis if we skipped lines at start
|
||||
if (skipStart > 0) {
|
||||
output.push(` ${"".padStart(lineNumWidth, " ")} ...`);
|
||||
// Update line numbers for the skipped leading context
|
||||
oldLineNum += skipStart;
|
||||
newLineNum += skipStart;
|
||||
}
|
||||
|
||||
for (const line of linesToShow) {
|
||||
const lineNum = String(oldLineNum).padStart(lineNumWidth, " ");
|
||||
output.push(` ${lineNum} ${line}`);
|
||||
oldLineNum++;
|
||||
newLineNum++;
|
||||
}
|
||||
|
||||
// Add ellipsis if we skipped lines at end
|
||||
if (skipEnd > 0) {
|
||||
output.push(` ${"".padStart(lineNumWidth, " ")} ...`);
|
||||
// Update line numbers for the skipped trailing context
|
||||
oldLineNum += skipEnd;
|
||||
newLineNum += skipEnd;
|
||||
}
|
||||
} else {
|
||||
// Skip these context lines entirely
|
||||
oldLineNum += raw.length;
|
||||
newLineNum += raw.length;
|
||||
}
|
||||
|
||||
lastWasChange = false;
|
||||
}
|
||||
}
|
||||
|
||||
return { diff: output.join("\n"), firstChangedLine };
|
||||
}
|
||||
|
||||
export interface EditDiffResult {
|
||||
diff: string;
|
||||
firstChangedLine: number | undefined;
|
||||
}
|
||||
|
||||
export interface EditDiffError {
|
||||
error: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the diff for an edit operation without applying it.
|
||||
* Used for preview rendering in the TUI before the tool executes.
|
||||
*/
|
||||
export async function computeEditDiff(
|
||||
path: string,
|
||||
oldText: string,
|
||||
newText: string,
|
||||
cwd: string,
|
||||
): Promise<EditDiffResult | EditDiffError> {
|
||||
const absolutePath = resolveToCwd(path, cwd);
|
||||
|
||||
try {
|
||||
// Check if file exists and is readable
|
||||
try {
|
||||
await access(absolutePath, constants.R_OK);
|
||||
} catch {
|
||||
return { error: `File not found: ${path}` };
|
||||
}
|
||||
|
||||
// Read the file
|
||||
const rawContent = await readFile(absolutePath, "utf-8");
|
||||
|
||||
// Strip BOM before matching (LLM won't include invisible BOM in oldText)
|
||||
const { text: content } = stripBom(rawContent);
|
||||
|
||||
const normalizedContent = normalizeToLF(content);
|
||||
const normalizedOldText = normalizeToLF(oldText);
|
||||
const normalizedNewText = normalizeToLF(newText);
|
||||
|
||||
// Find the old text using fuzzy matching (tries exact match first, then fuzzy)
|
||||
const matchResult = fuzzyFindText(normalizedContent, normalizedOldText);
|
||||
|
||||
if (!matchResult.found) {
|
||||
return {
|
||||
error: `Could not find the exact text in ${path}. The old text must match exactly including all whitespace and newlines.`,
|
||||
};
|
||||
}
|
||||
|
||||
// Count occurrences using fuzzy-normalized content for consistency
|
||||
const fuzzyContent = normalizeForFuzzyMatch(normalizedContent);
|
||||
const fuzzyOldText = normalizeForFuzzyMatch(normalizedOldText);
|
||||
const occurrences = fuzzyContent.split(fuzzyOldText).length - 1;
|
||||
|
||||
if (occurrences > 1) {
|
||||
return {
|
||||
error: `Found ${occurrences} occurrences of the text in ${path}. The text must be unique. Please provide more context to make it unique.`,
|
||||
};
|
||||
}
|
||||
|
||||
// Compute the new content using the matched position
|
||||
// When fuzzy matching was used, contentForReplacement is the normalized version
|
||||
const baseContent = matchResult.contentForReplacement;
|
||||
const newContent =
|
||||
baseContent.substring(0, matchResult.index) +
|
||||
normalizedNewText +
|
||||
baseContent.substring(matchResult.index + matchResult.matchLength);
|
||||
|
||||
// Check if it would actually change anything
|
||||
if (baseContent === newContent) {
|
||||
return {
|
||||
error: `No changes would be made to ${path}. The replacement produces identical content.`,
|
||||
};
|
||||
}
|
||||
|
||||
// Generate the diff
|
||||
return generateDiffString(baseContent, newContent);
|
||||
} catch (err) {
|
||||
return { error: err instanceof Error ? err.message : String(err) };
|
||||
}
|
||||
}
|
||||
253
packages/coding-agent/src/core/tools/edit.ts
Normal file
253
packages/coding-agent/src/core/tools/edit.ts
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { constants } from "fs";
|
||||
import {
|
||||
access as fsAccess,
|
||||
readFile as fsReadFile,
|
||||
writeFile as fsWriteFile,
|
||||
} from "fs/promises";
|
||||
import {
|
||||
detectLineEnding,
|
||||
fuzzyFindText,
|
||||
generateDiffString,
|
||||
normalizeForFuzzyMatch,
|
||||
normalizeToLF,
|
||||
restoreLineEndings,
|
||||
stripBom,
|
||||
} from "./edit-diff.js";
|
||||
import { resolveToCwd } from "./path-utils.js";
|
||||
|
||||
const editSchema = Type.Object({
|
||||
path: Type.String({
|
||||
description: "Path to the file to edit (relative or absolute)",
|
||||
}),
|
||||
oldText: Type.String({
|
||||
description: "Exact text to find and replace (must match exactly)",
|
||||
}),
|
||||
newText: Type.String({
|
||||
description: "New text to replace the old text with",
|
||||
}),
|
||||
});
|
||||
|
||||
export type EditToolInput = Static<typeof editSchema>;
|
||||
|
||||
export interface EditToolDetails {
|
||||
/** Unified diff of the changes made */
|
||||
diff: string;
|
||||
/** Line number of the first change in the new file (for editor navigation) */
|
||||
firstChangedLine?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pluggable operations for the edit tool.
|
||||
* Override these to delegate file editing to remote systems (e.g., SSH).
|
||||
*/
|
||||
export interface EditOperations {
|
||||
/** Read file contents as a Buffer */
|
||||
readFile: (absolutePath: string) => Promise<Buffer>;
|
||||
/** Write content to a file */
|
||||
writeFile: (absolutePath: string, content: string) => Promise<void>;
|
||||
/** Check if file is readable and writable (throw if not) */
|
||||
access: (absolutePath: string) => Promise<void>;
|
||||
}
|
||||
|
||||
const defaultEditOperations: EditOperations = {
|
||||
readFile: (path) => fsReadFile(path),
|
||||
writeFile: (path, content) => fsWriteFile(path, content, "utf-8"),
|
||||
access: (path) => fsAccess(path, constants.R_OK | constants.W_OK),
|
||||
};
|
||||
|
||||
export interface EditToolOptions {
|
||||
/** Custom operations for file editing. Default: local filesystem */
|
||||
operations?: EditOperations;
|
||||
}
|
||||
|
||||
export function createEditTool(
|
||||
cwd: string,
|
||||
options?: EditToolOptions,
|
||||
): AgentTool<typeof editSchema> {
|
||||
const ops = options?.operations ?? defaultEditOperations;
|
||||
|
||||
return {
|
||||
name: "edit",
|
||||
label: "edit",
|
||||
description:
|
||||
"Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits.",
|
||||
parameters: editSchema,
|
||||
execute: async (
|
||||
_toolCallId: string,
|
||||
{
|
||||
path,
|
||||
oldText,
|
||||
newText,
|
||||
}: { path: string; oldText: string; newText: string },
|
||||
signal?: AbortSignal,
|
||||
) => {
|
||||
const absolutePath = resolveToCwd(path, cwd);
|
||||
|
||||
return new Promise<{
|
||||
content: Array<{ type: "text"; text: string }>;
|
||||
details: EditToolDetails | undefined;
|
||||
}>((resolve, reject) => {
|
||||
// Check if already aborted
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Operation aborted"));
|
||||
return;
|
||||
}
|
||||
|
||||
let aborted = false;
|
||||
|
||||
// Set up abort handler
|
||||
const onAbort = () => {
|
||||
aborted = true;
|
||||
reject(new Error("Operation aborted"));
|
||||
};
|
||||
|
||||
if (signal) {
|
||||
signal.addEventListener("abort", onAbort, { once: true });
|
||||
}
|
||||
|
||||
// Perform the edit operation
|
||||
(async () => {
|
||||
try {
|
||||
// Check if file exists
|
||||
try {
|
||||
await ops.access(absolutePath);
|
||||
} catch {
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
reject(new Error(`File not found: ${path}`));
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if aborted before reading
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Read the file
|
||||
const buffer = await ops.readFile(absolutePath);
|
||||
const rawContent = buffer.toString("utf-8");
|
||||
|
||||
// Check if aborted after reading
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Strip BOM before matching (LLM won't include invisible BOM in oldText)
|
||||
const { bom, text: content } = stripBom(rawContent);
|
||||
|
||||
const originalEnding = detectLineEnding(content);
|
||||
const normalizedContent = normalizeToLF(content);
|
||||
const normalizedOldText = normalizeToLF(oldText);
|
||||
const normalizedNewText = normalizeToLF(newText);
|
||||
|
||||
// Find the old text using fuzzy matching (tries exact match first, then fuzzy)
|
||||
const matchResult = fuzzyFindText(
|
||||
normalizedContent,
|
||||
normalizedOldText,
|
||||
);
|
||||
|
||||
if (!matchResult.found) {
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
reject(
|
||||
new Error(
|
||||
`Could not find the exact text in ${path}. The old text must match exactly including all whitespace and newlines.`,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Count occurrences using fuzzy-normalized content for consistency
|
||||
const fuzzyContent = normalizeForFuzzyMatch(normalizedContent);
|
||||
const fuzzyOldText = normalizeForFuzzyMatch(normalizedOldText);
|
||||
const occurrences = fuzzyContent.split(fuzzyOldText).length - 1;
|
||||
|
||||
if (occurrences > 1) {
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
reject(
|
||||
new Error(
|
||||
`Found ${occurrences} occurrences of the text in ${path}. The text must be unique. Please provide more context to make it unique.`,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if aborted before writing
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Perform replacement using the matched text position
|
||||
// When fuzzy matching was used, contentForReplacement is the normalized version
|
||||
const baseContent = matchResult.contentForReplacement;
|
||||
const newContent =
|
||||
baseContent.substring(0, matchResult.index) +
|
||||
normalizedNewText +
|
||||
baseContent.substring(
|
||||
matchResult.index + matchResult.matchLength,
|
||||
);
|
||||
|
||||
// Verify the replacement actually changed something
|
||||
if (baseContent === newContent) {
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
reject(
|
||||
new Error(
|
||||
`No changes made to ${path}. The replacement produced identical content. This might indicate an issue with special characters or the text not existing as expected.`,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const finalContent =
|
||||
bom + restoreLineEndings(newContent, originalEnding);
|
||||
await ops.writeFile(absolutePath, finalContent);
|
||||
|
||||
// Check if aborted after writing
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Clean up abort handler
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
const diffResult = generateDiffString(baseContent, newContent);
|
||||
resolve({
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Successfully replaced text in ${path}.`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
diff: diffResult.diff,
|
||||
firstChangedLine: diffResult.firstChangedLine,
|
||||
},
|
||||
});
|
||||
} catch (error: any) {
|
||||
// Clean up abort handler
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
if (!aborted) {
|
||||
reject(error);
|
||||
}
|
||||
}
|
||||
})();
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/** Default edit tool using process.cwd() - for backwards compatibility */
|
||||
export const editTool = createEditTool(process.cwd());
|
||||
308
packages/coding-agent/src/core/tools/find.ts
Normal file
308
packages/coding-agent/src/core/tools/find.ts
Normal file
|
|
@ -0,0 +1,308 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { spawnSync } from "child_process";
|
||||
import { existsSync } from "fs";
|
||||
import { globSync } from "glob";
|
||||
import path from "path";
|
||||
import { ensureTool } from "../../utils/tools-manager.js";
|
||||
import { resolveToCwd } from "./path-utils.js";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
formatSize,
|
||||
type TruncationResult,
|
||||
truncateHead,
|
||||
} from "./truncate.js";
|
||||
|
||||
const findSchema = Type.Object({
|
||||
pattern: Type.String({
|
||||
description:
|
||||
"Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'",
|
||||
}),
|
||||
path: Type.Optional(
|
||||
Type.String({
|
||||
description: "Directory to search in (default: current directory)",
|
||||
}),
|
||||
),
|
||||
limit: Type.Optional(
|
||||
Type.Number({ description: "Maximum number of results (default: 1000)" }),
|
||||
),
|
||||
});
|
||||
|
||||
export type FindToolInput = Static<typeof findSchema>;
|
||||
|
||||
const DEFAULT_LIMIT = 1000;
|
||||
|
||||
export interface FindToolDetails {
|
||||
truncation?: TruncationResult;
|
||||
resultLimitReached?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pluggable operations for the find tool.
|
||||
* Override these to delegate file search to remote systems (e.g., SSH).
|
||||
*/
|
||||
export interface FindOperations {
|
||||
/** Check if path exists */
|
||||
exists: (absolutePath: string) => Promise<boolean> | boolean;
|
||||
/** Find files matching glob pattern. Returns relative paths. */
|
||||
glob: (
|
||||
pattern: string,
|
||||
cwd: string,
|
||||
options: { ignore: string[]; limit: number },
|
||||
) => Promise<string[]> | string[];
|
||||
}
|
||||
|
||||
const defaultFindOperations: FindOperations = {
|
||||
exists: existsSync,
|
||||
glob: (_pattern, _searchCwd, _options) => {
|
||||
// This is a placeholder - actual fd execution happens in execute
|
||||
return [];
|
||||
},
|
||||
};
|
||||
|
||||
export interface FindToolOptions {
|
||||
/** Custom operations for find. Default: local filesystem + fd */
|
||||
operations?: FindOperations;
|
||||
}
|
||||
|
||||
export function createFindTool(
|
||||
cwd: string,
|
||||
options?: FindToolOptions,
|
||||
): AgentTool<typeof findSchema> {
|
||||
const customOps = options?.operations;
|
||||
|
||||
return {
|
||||
name: "find",
|
||||
label: "find",
|
||||
description: `Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to ${DEFAULT_LIMIT} results or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first).`,
|
||||
parameters: findSchema,
|
||||
execute: async (
|
||||
_toolCallId: string,
|
||||
{
|
||||
pattern,
|
||||
path: searchDir,
|
||||
limit,
|
||||
}: { pattern: string; path?: string; limit?: number },
|
||||
signal?: AbortSignal,
|
||||
) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Operation aborted"));
|
||||
return;
|
||||
}
|
||||
|
||||
const onAbort = () => reject(new Error("Operation aborted"));
|
||||
signal?.addEventListener("abort", onAbort, { once: true });
|
||||
|
||||
(async () => {
|
||||
try {
|
||||
const searchPath = resolveToCwd(searchDir || ".", cwd);
|
||||
const effectiveLimit = limit ?? DEFAULT_LIMIT;
|
||||
const ops = customOps ?? defaultFindOperations;
|
||||
|
||||
// If custom operations provided with glob, use that
|
||||
if (customOps?.glob) {
|
||||
if (!(await ops.exists(searchPath))) {
|
||||
reject(new Error(`Path not found: ${searchPath}`));
|
||||
return;
|
||||
}
|
||||
|
||||
const results = await ops.glob(pattern, searchPath, {
|
||||
ignore: ["**/node_modules/**", "**/.git/**"],
|
||||
limit: effectiveLimit,
|
||||
});
|
||||
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
|
||||
if (results.length === 0) {
|
||||
resolve({
|
||||
content: [
|
||||
{ type: "text", text: "No files found matching pattern" },
|
||||
],
|
||||
details: undefined,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Relativize paths
|
||||
const relativized = results.map((p) => {
|
||||
if (p.startsWith(searchPath)) {
|
||||
return p.slice(searchPath.length + 1);
|
||||
}
|
||||
return path.relative(searchPath, p);
|
||||
});
|
||||
|
||||
const resultLimitReached = relativized.length >= effectiveLimit;
|
||||
const rawOutput = relativized.join("\n");
|
||||
const truncation = truncateHead(rawOutput, {
|
||||
maxLines: Number.MAX_SAFE_INTEGER,
|
||||
});
|
||||
|
||||
let resultOutput = truncation.content;
|
||||
const details: FindToolDetails = {};
|
||||
const notices: string[] = [];
|
||||
|
||||
if (resultLimitReached) {
|
||||
notices.push(`${effectiveLimit} results limit reached`);
|
||||
details.resultLimitReached = effectiveLimit;
|
||||
}
|
||||
|
||||
if (truncation.truncated) {
|
||||
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
|
||||
details.truncation = truncation;
|
||||
}
|
||||
|
||||
if (notices.length > 0) {
|
||||
resultOutput += `\n\n[${notices.join(". ")}]`;
|
||||
}
|
||||
|
||||
resolve({
|
||||
content: [{ type: "text", text: resultOutput }],
|
||||
details: Object.keys(details).length > 0 ? details : undefined,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Default: use fd
|
||||
const fdPath = await ensureTool("fd", true);
|
||||
if (!fdPath) {
|
||||
reject(
|
||||
new Error("fd is not available and could not be downloaded"),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Build fd arguments
|
||||
const args: string[] = [
|
||||
"--glob",
|
||||
"--color=never",
|
||||
"--hidden",
|
||||
"--max-results",
|
||||
String(effectiveLimit),
|
||||
];
|
||||
|
||||
// Include .gitignore files
|
||||
const gitignoreFiles = new Set<string>();
|
||||
const rootGitignore = path.join(searchPath, ".gitignore");
|
||||
if (existsSync(rootGitignore)) {
|
||||
gitignoreFiles.add(rootGitignore);
|
||||
}
|
||||
|
||||
try {
|
||||
const nestedGitignores = globSync("**/.gitignore", {
|
||||
cwd: searchPath,
|
||||
dot: true,
|
||||
absolute: true,
|
||||
ignore: ["**/node_modules/**", "**/.git/**"],
|
||||
});
|
||||
for (const file of nestedGitignores) {
|
||||
gitignoreFiles.add(file);
|
||||
}
|
||||
} catch {
|
||||
// Ignore glob errors
|
||||
}
|
||||
|
||||
for (const gitignorePath of gitignoreFiles) {
|
||||
args.push("--ignore-file", gitignorePath);
|
||||
}
|
||||
|
||||
args.push(pattern, searchPath);
|
||||
|
||||
const result = spawnSync(fdPath, args, {
|
||||
encoding: "utf-8",
|
||||
maxBuffer: 10 * 1024 * 1024,
|
||||
});
|
||||
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
|
||||
if (result.error) {
|
||||
reject(new Error(`Failed to run fd: ${result.error.message}`));
|
||||
return;
|
||||
}
|
||||
|
||||
const output = result.stdout?.trim() || "";
|
||||
|
||||
if (result.status !== 0) {
|
||||
const errorMsg =
|
||||
result.stderr?.trim() || `fd exited with code ${result.status}`;
|
||||
if (!output) {
|
||||
reject(new Error(errorMsg));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!output) {
|
||||
resolve({
|
||||
content: [
|
||||
{ type: "text", text: "No files found matching pattern" },
|
||||
],
|
||||
details: undefined,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const lines = output.split("\n");
|
||||
const relativized: string[] = [];
|
||||
|
||||
for (const rawLine of lines) {
|
||||
const line = rawLine.replace(/\r$/, "").trim();
|
||||
if (!line) continue;
|
||||
|
||||
const hadTrailingSlash =
|
||||
line.endsWith("/") || line.endsWith("\\");
|
||||
let relativePath = line;
|
||||
if (line.startsWith(searchPath)) {
|
||||
relativePath = line.slice(searchPath.length + 1);
|
||||
} else {
|
||||
relativePath = path.relative(searchPath, line);
|
||||
}
|
||||
|
||||
if (hadTrailingSlash && !relativePath.endsWith("/")) {
|
||||
relativePath += "/";
|
||||
}
|
||||
|
||||
relativized.push(relativePath);
|
||||
}
|
||||
|
||||
const resultLimitReached = relativized.length >= effectiveLimit;
|
||||
const rawOutput = relativized.join("\n");
|
||||
const truncation = truncateHead(rawOutput, {
|
||||
maxLines: Number.MAX_SAFE_INTEGER,
|
||||
});
|
||||
|
||||
let resultOutput = truncation.content;
|
||||
const details: FindToolDetails = {};
|
||||
const notices: string[] = [];
|
||||
|
||||
if (resultLimitReached) {
|
||||
notices.push(
|
||||
`${effectiveLimit} results limit reached. Use limit=${effectiveLimit * 2} for more, or refine pattern`,
|
||||
);
|
||||
details.resultLimitReached = effectiveLimit;
|
||||
}
|
||||
|
||||
if (truncation.truncated) {
|
||||
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
|
||||
details.truncation = truncation;
|
||||
}
|
||||
|
||||
if (notices.length > 0) {
|
||||
resultOutput += `\n\n[${notices.join(". ")}]`;
|
||||
}
|
||||
|
||||
resolve({
|
||||
content: [{ type: "text", text: resultOutput }],
|
||||
details: Object.keys(details).length > 0 ? details : undefined,
|
||||
});
|
||||
} catch (e: any) {
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
reject(e);
|
||||
}
|
||||
})();
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/** Default find tool using process.cwd() - for backwards compatibility */
|
||||
export const findTool = createFindTool(process.cwd());
|
||||
412
packages/coding-agent/src/core/tools/grep.ts
Normal file
412
packages/coding-agent/src/core/tools/grep.ts
Normal file
|
|
@ -0,0 +1,412 @@
|
|||
import { createInterface } from "node:readline";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { spawn } from "child_process";
|
||||
import { readFileSync, statSync } from "fs";
|
||||
import path from "path";
|
||||
import { ensureTool } from "../../utils/tools-manager.js";
|
||||
import { resolveToCwd } from "./path-utils.js";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
formatSize,
|
||||
GREP_MAX_LINE_LENGTH,
|
||||
type TruncationResult,
|
||||
truncateHead,
|
||||
truncateLine,
|
||||
} from "./truncate.js";
|
||||
|
||||
const grepSchema = Type.Object({
|
||||
pattern: Type.String({
|
||||
description: "Search pattern (regex or literal string)",
|
||||
}),
|
||||
path: Type.Optional(
|
||||
Type.String({
|
||||
description: "Directory or file to search (default: current directory)",
|
||||
}),
|
||||
),
|
||||
glob: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'",
|
||||
}),
|
||||
),
|
||||
ignoreCase: Type.Optional(
|
||||
Type.Boolean({ description: "Case-insensitive search (default: false)" }),
|
||||
),
|
||||
literal: Type.Optional(
|
||||
Type.Boolean({
|
||||
description:
|
||||
"Treat pattern as literal string instead of regex (default: false)",
|
||||
}),
|
||||
),
|
||||
context: Type.Optional(
|
||||
Type.Number({
|
||||
description:
|
||||
"Number of lines to show before and after each match (default: 0)",
|
||||
}),
|
||||
),
|
||||
limit: Type.Optional(
|
||||
Type.Number({
|
||||
description: "Maximum number of matches to return (default: 100)",
|
||||
}),
|
||||
),
|
||||
});
|
||||
|
||||
export type GrepToolInput = Static<typeof grepSchema>;
|
||||
|
||||
const DEFAULT_LIMIT = 100;
|
||||
|
||||
export interface GrepToolDetails {
|
||||
truncation?: TruncationResult;
|
||||
matchLimitReached?: number;
|
||||
linesTruncated?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pluggable operations for the grep tool.
|
||||
* Override these to delegate search to remote systems (e.g., SSH).
|
||||
*/
|
||||
export interface GrepOperations {
|
||||
/** Check if path is a directory. Throws if path doesn't exist. */
|
||||
isDirectory: (absolutePath: string) => Promise<boolean> | boolean;
|
||||
/** Read file contents for context lines */
|
||||
readFile: (absolutePath: string) => Promise<string> | string;
|
||||
}
|
||||
|
||||
const defaultGrepOperations: GrepOperations = {
|
||||
isDirectory: (p) => statSync(p).isDirectory(),
|
||||
readFile: (p) => readFileSync(p, "utf-8"),
|
||||
};
|
||||
|
||||
export interface GrepToolOptions {
|
||||
/** Custom operations for grep. Default: local filesystem + ripgrep */
|
||||
operations?: GrepOperations;
|
||||
}
|
||||
|
||||
export function createGrepTool(
|
||||
cwd: string,
|
||||
options?: GrepToolOptions,
|
||||
): AgentTool<typeof grepSchema> {
|
||||
const customOps = options?.operations;
|
||||
|
||||
return {
|
||||
name: "grep",
|
||||
label: "grep",
|
||||
description: `Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to ${DEFAULT_LIMIT} matches or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). Long lines are truncated to ${GREP_MAX_LINE_LENGTH} chars.`,
|
||||
parameters: grepSchema,
|
||||
execute: async (
|
||||
_toolCallId: string,
|
||||
{
|
||||
pattern,
|
||||
path: searchDir,
|
||||
glob,
|
||||
ignoreCase,
|
||||
literal,
|
||||
context,
|
||||
limit,
|
||||
}: {
|
||||
pattern: string;
|
||||
path?: string;
|
||||
glob?: string;
|
||||
ignoreCase?: boolean;
|
||||
literal?: boolean;
|
||||
context?: number;
|
||||
limit?: number;
|
||||
},
|
||||
signal?: AbortSignal,
|
||||
) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Operation aborted"));
|
||||
return;
|
||||
}
|
||||
|
||||
let settled = false;
|
||||
const settle = (fn: () => void) => {
|
||||
if (!settled) {
|
||||
settled = true;
|
||||
fn();
|
||||
}
|
||||
};
|
||||
|
||||
(async () => {
|
||||
try {
|
||||
const rgPath = await ensureTool("rg", true);
|
||||
if (!rgPath) {
|
||||
settle(() =>
|
||||
reject(
|
||||
new Error(
|
||||
"ripgrep (rg) is not available and could not be downloaded",
|
||||
),
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const searchPath = resolveToCwd(searchDir || ".", cwd);
|
||||
const ops = customOps ?? defaultGrepOperations;
|
||||
|
||||
let isDirectory: boolean;
|
||||
try {
|
||||
isDirectory = await ops.isDirectory(searchPath);
|
||||
} catch (_err) {
|
||||
settle(() => reject(new Error(`Path not found: ${searchPath}`)));
|
||||
return;
|
||||
}
|
||||
const contextValue = context && context > 0 ? context : 0;
|
||||
const effectiveLimit = Math.max(1, limit ?? DEFAULT_LIMIT);
|
||||
|
||||
const formatPath = (filePath: string): string => {
|
||||
if (isDirectory) {
|
||||
const relative = path.relative(searchPath, filePath);
|
||||
if (relative && !relative.startsWith("..")) {
|
||||
return relative.replace(/\\/g, "/");
|
||||
}
|
||||
}
|
||||
return path.basename(filePath);
|
||||
};
|
||||
|
||||
const fileCache = new Map<string, string[]>();
|
||||
const getFileLines = async (
|
||||
filePath: string,
|
||||
): Promise<string[]> => {
|
||||
let lines = fileCache.get(filePath);
|
||||
if (!lines) {
|
||||
try {
|
||||
const content = await ops.readFile(filePath);
|
||||
lines = content
|
||||
.replace(/\r\n/g, "\n")
|
||||
.replace(/\r/g, "\n")
|
||||
.split("\n");
|
||||
} catch {
|
||||
lines = [];
|
||||
}
|
||||
fileCache.set(filePath, lines);
|
||||
}
|
||||
return lines;
|
||||
};
|
||||
|
||||
const args: string[] = [
|
||||
"--json",
|
||||
"--line-number",
|
||||
"--color=never",
|
||||
"--hidden",
|
||||
];
|
||||
|
||||
if (ignoreCase) {
|
||||
args.push("--ignore-case");
|
||||
}
|
||||
|
||||
if (literal) {
|
||||
args.push("--fixed-strings");
|
||||
}
|
||||
|
||||
if (glob) {
|
||||
args.push("--glob", glob);
|
||||
}
|
||||
|
||||
args.push(pattern, searchPath);
|
||||
|
||||
const child = spawn(rgPath, args, {
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
});
|
||||
const rl = createInterface({ input: child.stdout });
|
||||
let stderr = "";
|
||||
let matchCount = 0;
|
||||
let matchLimitReached = false;
|
||||
let linesTruncated = false;
|
||||
let aborted = false;
|
||||
let killedDueToLimit = false;
|
||||
const outputLines: string[] = [];
|
||||
|
||||
const cleanup = () => {
|
||||
rl.close();
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
};
|
||||
|
||||
const stopChild = (dueToLimit: boolean = false) => {
|
||||
if (!child.killed) {
|
||||
killedDueToLimit = dueToLimit;
|
||||
child.kill();
|
||||
}
|
||||
};
|
||||
|
||||
const onAbort = () => {
|
||||
aborted = true;
|
||||
stopChild();
|
||||
};
|
||||
|
||||
signal?.addEventListener("abort", onAbort, { once: true });
|
||||
|
||||
child.stderr?.on("data", (chunk) => {
|
||||
stderr += chunk.toString();
|
||||
});
|
||||
|
||||
const formatBlock = async (
|
||||
filePath: string,
|
||||
lineNumber: number,
|
||||
): Promise<string[]> => {
|
||||
const relativePath = formatPath(filePath);
|
||||
const lines = await getFileLines(filePath);
|
||||
if (!lines.length) {
|
||||
return [`${relativePath}:${lineNumber}: (unable to read file)`];
|
||||
}
|
||||
|
||||
const block: string[] = [];
|
||||
const start =
|
||||
contextValue > 0
|
||||
? Math.max(1, lineNumber - contextValue)
|
||||
: lineNumber;
|
||||
const end =
|
||||
contextValue > 0
|
||||
? Math.min(lines.length, lineNumber + contextValue)
|
||||
: lineNumber;
|
||||
|
||||
for (let current = start; current <= end; current++) {
|
||||
const lineText = lines[current - 1] ?? "";
|
||||
const sanitized = lineText.replace(/\r/g, "");
|
||||
const isMatchLine = current === lineNumber;
|
||||
|
||||
// Truncate long lines
|
||||
const { text: truncatedText, wasTruncated } =
|
||||
truncateLine(sanitized);
|
||||
if (wasTruncated) {
|
||||
linesTruncated = true;
|
||||
}
|
||||
|
||||
if (isMatchLine) {
|
||||
block.push(`${relativePath}:${current}: ${truncatedText}`);
|
||||
} else {
|
||||
block.push(`${relativePath}-${current}- ${truncatedText}`);
|
||||
}
|
||||
}
|
||||
|
||||
return block;
|
||||
};
|
||||
|
||||
// Collect matches during streaming, format after
|
||||
const matches: Array<{ filePath: string; lineNumber: number }> = [];
|
||||
|
||||
rl.on("line", (line) => {
|
||||
if (!line.trim() || matchCount >= effectiveLimit) {
|
||||
return;
|
||||
}
|
||||
|
||||
let event: any;
|
||||
try {
|
||||
event = JSON.parse(line);
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.type === "match") {
|
||||
matchCount++;
|
||||
const filePath = event.data?.path?.text;
|
||||
const lineNumber = event.data?.line_number;
|
||||
|
||||
if (filePath && typeof lineNumber === "number") {
|
||||
matches.push({ filePath, lineNumber });
|
||||
}
|
||||
|
||||
if (matchCount >= effectiveLimit) {
|
||||
matchLimitReached = true;
|
||||
stopChild(true);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
child.on("error", (error) => {
|
||||
cleanup();
|
||||
settle(() =>
|
||||
reject(new Error(`Failed to run ripgrep: ${error.message}`)),
|
||||
);
|
||||
});
|
||||
|
||||
child.on("close", async (code) => {
|
||||
cleanup();
|
||||
|
||||
if (aborted) {
|
||||
settle(() => reject(new Error("Operation aborted")));
|
||||
return;
|
||||
}
|
||||
|
||||
if (!killedDueToLimit && code !== 0 && code !== 1) {
|
||||
const errorMsg =
|
||||
stderr.trim() || `ripgrep exited with code ${code}`;
|
||||
settle(() => reject(new Error(errorMsg)));
|
||||
return;
|
||||
}
|
||||
|
||||
if (matchCount === 0) {
|
||||
settle(() =>
|
||||
resolve({
|
||||
content: [{ type: "text", text: "No matches found" }],
|
||||
details: undefined,
|
||||
}),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Format matches (async to support remote file reading)
|
||||
for (const match of matches) {
|
||||
const block = await formatBlock(
|
||||
match.filePath,
|
||||
match.lineNumber,
|
||||
);
|
||||
outputLines.push(...block);
|
||||
}
|
||||
|
||||
// Apply byte truncation (no line limit since we already have match limit)
|
||||
const rawOutput = outputLines.join("\n");
|
||||
const truncation = truncateHead(rawOutput, {
|
||||
maxLines: Number.MAX_SAFE_INTEGER,
|
||||
});
|
||||
|
||||
let output = truncation.content;
|
||||
const details: GrepToolDetails = {};
|
||||
|
||||
// Build notices
|
||||
const notices: string[] = [];
|
||||
|
||||
if (matchLimitReached) {
|
||||
notices.push(
|
||||
`${effectiveLimit} matches limit reached. Use limit=${effectiveLimit * 2} for more, or refine pattern`,
|
||||
);
|
||||
details.matchLimitReached = effectiveLimit;
|
||||
}
|
||||
|
||||
if (truncation.truncated) {
|
||||
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
|
||||
details.truncation = truncation;
|
||||
}
|
||||
|
||||
if (linesTruncated) {
|
||||
notices.push(
|
||||
`Some lines truncated to ${GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines`,
|
||||
);
|
||||
details.linesTruncated = true;
|
||||
}
|
||||
|
||||
if (notices.length > 0) {
|
||||
output += `\n\n[${notices.join(". ")}]`;
|
||||
}
|
||||
|
||||
settle(() =>
|
||||
resolve({
|
||||
content: [{ type: "text", text: output }],
|
||||
details:
|
||||
Object.keys(details).length > 0 ? details : undefined,
|
||||
}),
|
||||
);
|
||||
});
|
||||
} catch (err) {
|
||||
settle(() => reject(err as Error));
|
||||
}
|
||||
})();
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/** Default grep tool using process.cwd() - for backwards compatibility */
|
||||
export const grepTool = createGrepTool(process.cwd());
|
||||
150
packages/coding-agent/src/core/tools/index.ts
Normal file
150
packages/coding-agent/src/core/tools/index.ts
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
export {
|
||||
type BashOperations,
|
||||
type BashSpawnContext,
|
||||
type BashSpawnHook,
|
||||
type BashToolDetails,
|
||||
type BashToolInput,
|
||||
type BashToolOptions,
|
||||
bashTool,
|
||||
createBashTool,
|
||||
} from "./bash.js";
|
||||
export {
|
||||
createEditTool,
|
||||
type EditOperations,
|
||||
type EditToolDetails,
|
||||
type EditToolInput,
|
||||
type EditToolOptions,
|
||||
editTool,
|
||||
} from "./edit.js";
|
||||
export {
|
||||
createFindTool,
|
||||
type FindOperations,
|
||||
type FindToolDetails,
|
||||
type FindToolInput,
|
||||
type FindToolOptions,
|
||||
findTool,
|
||||
} from "./find.js";
|
||||
export {
|
||||
createGrepTool,
|
||||
type GrepOperations,
|
||||
type GrepToolDetails,
|
||||
type GrepToolInput,
|
||||
type GrepToolOptions,
|
||||
grepTool,
|
||||
} from "./grep.js";
|
||||
export {
|
||||
createLsTool,
|
||||
type LsOperations,
|
||||
type LsToolDetails,
|
||||
type LsToolInput,
|
||||
type LsToolOptions,
|
||||
lsTool,
|
||||
} from "./ls.js";
|
||||
export {
|
||||
createReadTool,
|
||||
type ReadOperations,
|
||||
type ReadToolDetails,
|
||||
type ReadToolInput,
|
||||
type ReadToolOptions,
|
||||
readTool,
|
||||
} from "./read.js";
|
||||
export {
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
formatSize,
|
||||
type TruncationOptions,
|
||||
type TruncationResult,
|
||||
truncateHead,
|
||||
truncateLine,
|
||||
truncateTail,
|
||||
} from "./truncate.js";
|
||||
export {
|
||||
createWriteTool,
|
||||
type WriteOperations,
|
||||
type WriteToolInput,
|
||||
type WriteToolOptions,
|
||||
writeTool,
|
||||
} from "./write.js";
|
||||
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type BashToolOptions, bashTool, createBashTool } from "./bash.js";
|
||||
import { createEditTool, editTool } from "./edit.js";
|
||||
import { createFindTool, findTool } from "./find.js";
|
||||
import { createGrepTool, grepTool } from "./grep.js";
|
||||
import { createLsTool, lsTool } from "./ls.js";
|
||||
import { createReadTool, type ReadToolOptions, readTool } from "./read.js";
|
||||
import { createWriteTool, writeTool } from "./write.js";
|
||||
|
||||
/** Tool type (AgentTool from pi-ai) */
|
||||
export type Tool = AgentTool<any>;
|
||||
|
||||
// Default tools for full access mode (using process.cwd())
|
||||
export const codingTools: Tool[] = [readTool, bashTool, editTool, writeTool];
|
||||
|
||||
// Read-only tools for exploration without modification (using process.cwd())
|
||||
export const readOnlyTools: Tool[] = [readTool, grepTool, findTool, lsTool];
|
||||
|
||||
// All available tools (using process.cwd())
|
||||
export const allTools = {
|
||||
read: readTool,
|
||||
bash: bashTool,
|
||||
edit: editTool,
|
||||
write: writeTool,
|
||||
grep: grepTool,
|
||||
find: findTool,
|
||||
ls: lsTool,
|
||||
};
|
||||
|
||||
export type ToolName = keyof typeof allTools;
|
||||
|
||||
export interface ToolsOptions {
|
||||
/** Options for the read tool */
|
||||
read?: ReadToolOptions;
|
||||
/** Options for the bash tool */
|
||||
bash?: BashToolOptions;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create coding tools configured for a specific working directory.
|
||||
*/
|
||||
export function createCodingTools(cwd: string, options?: ToolsOptions): Tool[] {
|
||||
return [
|
||||
createReadTool(cwd, options?.read),
|
||||
createBashTool(cwd, options?.bash),
|
||||
createEditTool(cwd),
|
||||
createWriteTool(cwd),
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* Create read-only tools configured for a specific working directory.
|
||||
*/
|
||||
export function createReadOnlyTools(
|
||||
cwd: string,
|
||||
options?: ToolsOptions,
|
||||
): Tool[] {
|
||||
return [
|
||||
createReadTool(cwd, options?.read),
|
||||
createGrepTool(cwd),
|
||||
createFindTool(cwd),
|
||||
createLsTool(cwd),
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* Create all tools configured for a specific working directory.
|
||||
*/
|
||||
export function createAllTools(
|
||||
cwd: string,
|
||||
options?: ToolsOptions,
|
||||
): Record<ToolName, Tool> {
|
||||
return {
|
||||
read: createReadTool(cwd, options?.read),
|
||||
bash: createBashTool(cwd, options?.bash),
|
||||
edit: createEditTool(cwd),
|
||||
write: createWriteTool(cwd),
|
||||
grep: createGrepTool(cwd),
|
||||
find: createFindTool(cwd),
|
||||
ls: createLsTool(cwd),
|
||||
};
|
||||
}
|
||||
197
packages/coding-agent/src/core/tools/ls.ts
Normal file
197
packages/coding-agent/src/core/tools/ls.ts
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { existsSync, readdirSync, statSync } from "fs";
|
||||
import nodePath from "path";
|
||||
import { resolveToCwd } from "./path-utils.js";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
formatSize,
|
||||
type TruncationResult,
|
||||
truncateHead,
|
||||
} from "./truncate.js";
|
||||
|
||||
const lsSchema = Type.Object({
|
||||
path: Type.Optional(
|
||||
Type.String({
|
||||
description: "Directory to list (default: current directory)",
|
||||
}),
|
||||
),
|
||||
limit: Type.Optional(
|
||||
Type.Number({
|
||||
description: "Maximum number of entries to return (default: 500)",
|
||||
}),
|
||||
),
|
||||
});
|
||||
|
||||
export type LsToolInput = Static<typeof lsSchema>;
|
||||
|
||||
const DEFAULT_LIMIT = 500;
|
||||
|
||||
export interface LsToolDetails {
|
||||
truncation?: TruncationResult;
|
||||
entryLimitReached?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pluggable operations for the ls tool.
|
||||
* Override these to delegate directory listing to remote systems (e.g., SSH).
|
||||
*/
|
||||
export interface LsOperations {
|
||||
/** Check if path exists */
|
||||
exists: (absolutePath: string) => Promise<boolean> | boolean;
|
||||
/** Get file/directory stats. Throws if not found. */
|
||||
stat: (
|
||||
absolutePath: string,
|
||||
) => Promise<{ isDirectory: () => boolean }> | { isDirectory: () => boolean };
|
||||
/** Read directory entries */
|
||||
readdir: (absolutePath: string) => Promise<string[]> | string[];
|
||||
}
|
||||
|
||||
const defaultLsOperations: LsOperations = {
|
||||
exists: existsSync,
|
||||
stat: statSync,
|
||||
readdir: readdirSync,
|
||||
};
|
||||
|
||||
export interface LsToolOptions {
|
||||
/** Custom operations for directory listing. Default: local filesystem */
|
||||
operations?: LsOperations;
|
||||
}
|
||||
|
||||
export function createLsTool(
|
||||
cwd: string,
|
||||
options?: LsToolOptions,
|
||||
): AgentTool<typeof lsSchema> {
|
||||
const ops = options?.operations ?? defaultLsOperations;
|
||||
|
||||
return {
|
||||
name: "ls",
|
||||
label: "ls",
|
||||
description: `List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to ${DEFAULT_LIMIT} entries or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first).`,
|
||||
parameters: lsSchema,
|
||||
execute: async (
|
||||
_toolCallId: string,
|
||||
{ path, limit }: { path?: string; limit?: number },
|
||||
signal?: AbortSignal,
|
||||
) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Operation aborted"));
|
||||
return;
|
||||
}
|
||||
|
||||
const onAbort = () => reject(new Error("Operation aborted"));
|
||||
signal?.addEventListener("abort", onAbort, { once: true });
|
||||
|
||||
(async () => {
|
||||
try {
|
||||
const dirPath = resolveToCwd(path || ".", cwd);
|
||||
const effectiveLimit = limit ?? DEFAULT_LIMIT;
|
||||
|
||||
// Check if path exists
|
||||
if (!(await ops.exists(dirPath))) {
|
||||
reject(new Error(`Path not found: ${dirPath}`));
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if path is a directory
|
||||
const stat = await ops.stat(dirPath);
|
||||
if (!stat.isDirectory()) {
|
||||
reject(new Error(`Not a directory: ${dirPath}`));
|
||||
return;
|
||||
}
|
||||
|
||||
// Read directory entries
|
||||
let entries: string[];
|
||||
try {
|
||||
entries = await ops.readdir(dirPath);
|
||||
} catch (e: any) {
|
||||
reject(new Error(`Cannot read directory: ${e.message}`));
|
||||
return;
|
||||
}
|
||||
|
||||
// Sort alphabetically (case-insensitive)
|
||||
entries.sort((a, b) =>
|
||||
a.toLowerCase().localeCompare(b.toLowerCase()),
|
||||
);
|
||||
|
||||
// Format entries with directory indicators
|
||||
const results: string[] = [];
|
||||
let entryLimitReached = false;
|
||||
|
||||
for (const entry of entries) {
|
||||
if (results.length >= effectiveLimit) {
|
||||
entryLimitReached = true;
|
||||
break;
|
||||
}
|
||||
|
||||
const fullPath = nodePath.join(dirPath, entry);
|
||||
let suffix = "";
|
||||
|
||||
try {
|
||||
const entryStat = await ops.stat(fullPath);
|
||||
if (entryStat.isDirectory()) {
|
||||
suffix = "/";
|
||||
}
|
||||
} catch {
|
||||
// Skip entries we can't stat
|
||||
continue;
|
||||
}
|
||||
|
||||
results.push(entry + suffix);
|
||||
}
|
||||
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
|
||||
if (results.length === 0) {
|
||||
resolve({
|
||||
content: [{ type: "text", text: "(empty directory)" }],
|
||||
details: undefined,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Apply byte truncation (no line limit since we already have entry limit)
|
||||
const rawOutput = results.join("\n");
|
||||
const truncation = truncateHead(rawOutput, {
|
||||
maxLines: Number.MAX_SAFE_INTEGER,
|
||||
});
|
||||
|
||||
let output = truncation.content;
|
||||
const details: LsToolDetails = {};
|
||||
|
||||
// Build notices
|
||||
const notices: string[] = [];
|
||||
|
||||
if (entryLimitReached) {
|
||||
notices.push(
|
||||
`${effectiveLimit} entries limit reached. Use limit=${effectiveLimit * 2} for more`,
|
||||
);
|
||||
details.entryLimitReached = effectiveLimit;
|
||||
}
|
||||
|
||||
if (truncation.truncated) {
|
||||
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
|
||||
details.truncation = truncation;
|
||||
}
|
||||
|
||||
if (notices.length > 0) {
|
||||
output += `\n\n[${notices.join(". ")}]`;
|
||||
}
|
||||
|
||||
resolve({
|
||||
content: [{ type: "text", text: output }],
|
||||
details: Object.keys(details).length > 0 ? details : undefined,
|
||||
});
|
||||
} catch (e: any) {
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
reject(e);
|
||||
}
|
||||
})();
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/** Default ls tool using process.cwd() - for backwards compatibility */
|
||||
export const lsTool = createLsTool(process.cwd());
|
||||
94
packages/coding-agent/src/core/tools/path-utils.ts
Normal file
94
packages/coding-agent/src/core/tools/path-utils.ts
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
import { accessSync, constants } from "node:fs";
|
||||
import * as os from "node:os";
|
||||
import { isAbsolute, resolve as resolvePath } from "node:path";
|
||||
|
||||
const UNICODE_SPACES = /[\u00A0\u2000-\u200A\u202F\u205F\u3000]/g;
|
||||
const NARROW_NO_BREAK_SPACE = "\u202F";
|
||||
function normalizeUnicodeSpaces(str: string): string {
|
||||
return str.replace(UNICODE_SPACES, " ");
|
||||
}
|
||||
|
||||
function tryMacOSScreenshotPath(filePath: string): string {
|
||||
return filePath.replace(/ (AM|PM)\./g, `${NARROW_NO_BREAK_SPACE}$1.`);
|
||||
}
|
||||
|
||||
function tryNFDVariant(filePath: string): string {
|
||||
// macOS stores filenames in NFD (decomposed) form, try converting user input to NFD
|
||||
return filePath.normalize("NFD");
|
||||
}
|
||||
|
||||
function tryCurlyQuoteVariant(filePath: string): string {
|
||||
// macOS uses U+2019 (right single quotation mark) in screenshot names like "Capture d'écran"
|
||||
// Users typically type U+0027 (straight apostrophe)
|
||||
return filePath.replace(/'/g, "\u2019");
|
||||
}
|
||||
|
||||
function fileExists(filePath: string): boolean {
|
||||
try {
|
||||
accessSync(filePath, constants.F_OK);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeAtPrefix(filePath: string): string {
|
||||
return filePath.startsWith("@") ? filePath.slice(1) : filePath;
|
||||
}
|
||||
|
||||
export function expandPath(filePath: string): string {
|
||||
const normalized = normalizeUnicodeSpaces(normalizeAtPrefix(filePath));
|
||||
if (normalized === "~") {
|
||||
return os.homedir();
|
||||
}
|
||||
if (normalized.startsWith("~/")) {
|
||||
return os.homedir() + normalized.slice(1);
|
||||
}
|
||||
return normalized;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve a path relative to the given cwd.
|
||||
* Handles ~ expansion and absolute paths.
|
||||
*/
|
||||
export function resolveToCwd(filePath: string, cwd: string): string {
|
||||
const expanded = expandPath(filePath);
|
||||
if (isAbsolute(expanded)) {
|
||||
return expanded;
|
||||
}
|
||||
return resolvePath(cwd, expanded);
|
||||
}
|
||||
|
||||
export function resolveReadPath(filePath: string, cwd: string): string {
|
||||
const resolved = resolveToCwd(filePath, cwd);
|
||||
|
||||
if (fileExists(resolved)) {
|
||||
return resolved;
|
||||
}
|
||||
|
||||
// Try macOS AM/PM variant (narrow no-break space before AM/PM)
|
||||
const amPmVariant = tryMacOSScreenshotPath(resolved);
|
||||
if (amPmVariant !== resolved && fileExists(amPmVariant)) {
|
||||
return amPmVariant;
|
||||
}
|
||||
|
||||
// Try NFD variant (macOS stores filenames in NFD form)
|
||||
const nfdVariant = tryNFDVariant(resolved);
|
||||
if (nfdVariant !== resolved && fileExists(nfdVariant)) {
|
||||
return nfdVariant;
|
||||
}
|
||||
|
||||
// Try curly quote variant (macOS uses U+2019 in screenshot names)
|
||||
const curlyVariant = tryCurlyQuoteVariant(resolved);
|
||||
if (curlyVariant !== resolved && fileExists(curlyVariant)) {
|
||||
return curlyVariant;
|
||||
}
|
||||
|
||||
// Try combined NFD + curly quote (for French macOS screenshots like "Capture d'écran")
|
||||
const nfdCurlyVariant = tryCurlyQuoteVariant(nfdVariant);
|
||||
if (nfdCurlyVariant !== resolved && fileExists(nfdCurlyVariant)) {
|
||||
return nfdCurlyVariant;
|
||||
}
|
||||
|
||||
return resolved;
|
||||
}
|
||||
265
packages/coding-agent/src/core/tools/read.ts
Normal file
265
packages/coding-agent/src/core/tools/read.ts
Normal file
|
|
@ -0,0 +1,265 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import type { ImageContent, TextContent } from "@mariozechner/pi-ai";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { constants } from "fs";
|
||||
import { access as fsAccess, readFile as fsReadFile } from "fs/promises";
|
||||
import { formatDimensionNote, resizeImage } from "../../utils/image-resize.js";
|
||||
import { detectSupportedImageMimeTypeFromFile } from "../../utils/mime.js";
|
||||
import { resolveReadPath } from "./path-utils.js";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
formatSize,
|
||||
type TruncationResult,
|
||||
truncateHead,
|
||||
} from "./truncate.js";
|
||||
|
||||
const readSchema = Type.Object({
|
||||
path: Type.String({
|
||||
description: "Path to the file to read (relative or absolute)",
|
||||
}),
|
||||
offset: Type.Optional(
|
||||
Type.Number({
|
||||
description: "Line number to start reading from (1-indexed)",
|
||||
}),
|
||||
),
|
||||
limit: Type.Optional(
|
||||
Type.Number({ description: "Maximum number of lines to read" }),
|
||||
),
|
||||
});
|
||||
|
||||
export type ReadToolInput = Static<typeof readSchema>;
|
||||
|
||||
export interface ReadToolDetails {
|
||||
truncation?: TruncationResult;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pluggable operations for the read tool.
|
||||
* Override these to delegate file reading to remote systems (e.g., SSH).
|
||||
*/
|
||||
export interface ReadOperations {
|
||||
/** Read file contents as a Buffer */
|
||||
readFile: (absolutePath: string) => Promise<Buffer>;
|
||||
/** Check if file is readable (throw if not) */
|
||||
access: (absolutePath: string) => Promise<void>;
|
||||
/** Detect image MIME type, return null/undefined for non-images */
|
||||
detectImageMimeType?: (
|
||||
absolutePath: string,
|
||||
) => Promise<string | null | undefined>;
|
||||
}
|
||||
|
||||
const defaultReadOperations: ReadOperations = {
|
||||
readFile: (path) => fsReadFile(path),
|
||||
access: (path) => fsAccess(path, constants.R_OK),
|
||||
detectImageMimeType: detectSupportedImageMimeTypeFromFile,
|
||||
};
|
||||
|
||||
export interface ReadToolOptions {
|
||||
/** Whether to auto-resize images to 2000x2000 max. Default: true */
|
||||
autoResizeImages?: boolean;
|
||||
/** Custom operations for file reading. Default: local filesystem */
|
||||
operations?: ReadOperations;
|
||||
}
|
||||
|
||||
export function createReadTool(
|
||||
cwd: string,
|
||||
options?: ReadToolOptions,
|
||||
): AgentTool<typeof readSchema> {
|
||||
const autoResizeImages = options?.autoResizeImages ?? true;
|
||||
const ops = options?.operations ?? defaultReadOperations;
|
||||
|
||||
return {
|
||||
name: "read",
|
||||
label: "read",
|
||||
description: `Read the contents of a file. Supports text files and images (jpg, png, gif, webp). Images are sent as attachments. For text files, output is truncated to ${DEFAULT_MAX_LINES} lines or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). Use offset/limit for large files. When you need the full file, continue with offset until complete.`,
|
||||
parameters: readSchema,
|
||||
execute: async (
|
||||
_toolCallId: string,
|
||||
{
|
||||
path,
|
||||
offset,
|
||||
limit,
|
||||
}: { path: string; offset?: number; limit?: number },
|
||||
signal?: AbortSignal,
|
||||
) => {
|
||||
const absolutePath = resolveReadPath(path, cwd);
|
||||
|
||||
return new Promise<{
|
||||
content: (TextContent | ImageContent)[];
|
||||
details: ReadToolDetails | undefined;
|
||||
}>((resolve, reject) => {
|
||||
// Check if already aborted
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Operation aborted"));
|
||||
return;
|
||||
}
|
||||
|
||||
let aborted = false;
|
||||
|
||||
// Set up abort handler
|
||||
const onAbort = () => {
|
||||
aborted = true;
|
||||
reject(new Error("Operation aborted"));
|
||||
};
|
||||
|
||||
if (signal) {
|
||||
signal.addEventListener("abort", onAbort, { once: true });
|
||||
}
|
||||
|
||||
// Perform the read operation
|
||||
(async () => {
|
||||
try {
|
||||
// Check if file exists
|
||||
await ops.access(absolutePath);
|
||||
|
||||
// Check if aborted before reading
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
const mimeType = ops.detectImageMimeType
|
||||
? await ops.detectImageMimeType(absolutePath)
|
||||
: undefined;
|
||||
|
||||
// Read the file based on type
|
||||
let content: (TextContent | ImageContent)[];
|
||||
let details: ReadToolDetails | undefined;
|
||||
|
||||
if (mimeType) {
|
||||
// Read as image (binary)
|
||||
const buffer = await ops.readFile(absolutePath);
|
||||
const base64 = buffer.toString("base64");
|
||||
|
||||
if (autoResizeImages) {
|
||||
// Resize image if needed
|
||||
const resized = await resizeImage({
|
||||
type: "image",
|
||||
data: base64,
|
||||
mimeType,
|
||||
});
|
||||
const dimensionNote = formatDimensionNote(resized);
|
||||
|
||||
let textNote = `Read image file [${resized.mimeType}]`;
|
||||
if (dimensionNote) {
|
||||
textNote += `\n${dimensionNote}`;
|
||||
}
|
||||
|
||||
content = [
|
||||
{ type: "text", text: textNote },
|
||||
{
|
||||
type: "image",
|
||||
data: resized.data,
|
||||
mimeType: resized.mimeType,
|
||||
},
|
||||
];
|
||||
} else {
|
||||
const textNote = `Read image file [${mimeType}]`;
|
||||
content = [
|
||||
{ type: "text", text: textNote },
|
||||
{ type: "image", data: base64, mimeType },
|
||||
];
|
||||
}
|
||||
} else {
|
||||
// Read as text
|
||||
const buffer = await ops.readFile(absolutePath);
|
||||
const textContent = buffer.toString("utf-8");
|
||||
const allLines = textContent.split("\n");
|
||||
const totalFileLines = allLines.length;
|
||||
|
||||
// Apply offset if specified (1-indexed to 0-indexed)
|
||||
const startLine = offset ? Math.max(0, offset - 1) : 0;
|
||||
const startLineDisplay = startLine + 1; // For display (1-indexed)
|
||||
|
||||
// Check if offset is out of bounds
|
||||
if (startLine >= allLines.length) {
|
||||
throw new Error(
|
||||
`Offset ${offset} is beyond end of file (${allLines.length} lines total)`,
|
||||
);
|
||||
}
|
||||
|
||||
// If limit is specified by user, use it; otherwise we'll let truncateHead decide
|
||||
let selectedContent: string;
|
||||
let userLimitedLines: number | undefined;
|
||||
if (limit !== undefined) {
|
||||
const endLine = Math.min(startLine + limit, allLines.length);
|
||||
selectedContent = allLines.slice(startLine, endLine).join("\n");
|
||||
userLimitedLines = endLine - startLine;
|
||||
} else {
|
||||
selectedContent = allLines.slice(startLine).join("\n");
|
||||
}
|
||||
|
||||
// Apply truncation (respects both line and byte limits)
|
||||
const truncation = truncateHead(selectedContent);
|
||||
|
||||
let outputText: string;
|
||||
|
||||
if (truncation.firstLineExceedsLimit) {
|
||||
// First line at offset exceeds 30KB - tell model to use bash
|
||||
const firstLineSize = formatSize(
|
||||
Buffer.byteLength(allLines[startLine], "utf-8"),
|
||||
);
|
||||
outputText = `[Line ${startLineDisplay} is ${firstLineSize}, exceeds ${formatSize(DEFAULT_MAX_BYTES)} limit. Use bash: sed -n '${startLineDisplay}p' ${path} | head -c ${DEFAULT_MAX_BYTES}]`;
|
||||
details = { truncation };
|
||||
} else if (truncation.truncated) {
|
||||
// Truncation occurred - build actionable notice
|
||||
const endLineDisplay =
|
||||
startLineDisplay + truncation.outputLines - 1;
|
||||
const nextOffset = endLineDisplay + 1;
|
||||
|
||||
outputText = truncation.content;
|
||||
|
||||
if (truncation.truncatedBy === "lines") {
|
||||
outputText += `\n\n[Showing lines ${startLineDisplay}-${endLineDisplay} of ${totalFileLines}. Use offset=${nextOffset} to continue.]`;
|
||||
} else {
|
||||
outputText += `\n\n[Showing lines ${startLineDisplay}-${endLineDisplay} of ${totalFileLines} (${formatSize(DEFAULT_MAX_BYTES)} limit). Use offset=${nextOffset} to continue.]`;
|
||||
}
|
||||
details = { truncation };
|
||||
} else if (
|
||||
userLimitedLines !== undefined &&
|
||||
startLine + userLimitedLines < allLines.length
|
||||
) {
|
||||
// User specified limit, there's more content, but no truncation
|
||||
const remaining =
|
||||
allLines.length - (startLine + userLimitedLines);
|
||||
const nextOffset = startLine + userLimitedLines + 1;
|
||||
|
||||
outputText = truncation.content;
|
||||
outputText += `\n\n[${remaining} more lines in file. Use offset=${nextOffset} to continue.]`;
|
||||
} else {
|
||||
// No truncation, no user limit exceeded
|
||||
outputText = truncation.content;
|
||||
}
|
||||
|
||||
content = [{ type: "text", text: outputText }];
|
||||
}
|
||||
|
||||
// Check if aborted after reading
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Clean up abort handler
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
resolve({ content, details });
|
||||
} catch (error: any) {
|
||||
// Clean up abort handler
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
if (!aborted) {
|
||||
reject(error);
|
||||
}
|
||||
}
|
||||
})();
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/** Default read tool using process.cwd() - for backwards compatibility */
|
||||
export const readTool = createReadTool(process.cwd());
|
||||
279
packages/coding-agent/src/core/tools/truncate.ts
Normal file
279
packages/coding-agent/src/core/tools/truncate.ts
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
/**
|
||||
* Shared truncation utilities for tool outputs.
|
||||
*
|
||||
* Truncation is based on two independent limits - whichever is hit first wins:
|
||||
* - Line limit (default: 2000 lines)
|
||||
* - Byte limit (default: 50KB)
|
||||
*
|
||||
* Never returns partial lines (except bash tail truncation edge case).
|
||||
*/
|
||||
|
||||
export const DEFAULT_MAX_LINES = 2000;
|
||||
export const DEFAULT_MAX_BYTES = 50 * 1024; // 50KB
|
||||
export const GREP_MAX_LINE_LENGTH = 500; // Max chars per grep match line
|
||||
|
||||
export interface TruncationResult {
|
||||
/** The truncated content */
|
||||
content: string;
|
||||
/** Whether truncation occurred */
|
||||
truncated: boolean;
|
||||
/** Which limit was hit: "lines", "bytes", or null if not truncated */
|
||||
truncatedBy: "lines" | "bytes" | null;
|
||||
/** Total number of lines in the original content */
|
||||
totalLines: number;
|
||||
/** Total number of bytes in the original content */
|
||||
totalBytes: number;
|
||||
/** Number of complete lines in the truncated output */
|
||||
outputLines: number;
|
||||
/** Number of bytes in the truncated output */
|
||||
outputBytes: number;
|
||||
/** Whether the last line was partially truncated (only for tail truncation edge case) */
|
||||
lastLinePartial: boolean;
|
||||
/** Whether the first line exceeded the byte limit (for head truncation) */
|
||||
firstLineExceedsLimit: boolean;
|
||||
/** The max lines limit that was applied */
|
||||
maxLines: number;
|
||||
/** The max bytes limit that was applied */
|
||||
maxBytes: number;
|
||||
}
|
||||
|
||||
export interface TruncationOptions {
|
||||
/** Maximum number of lines (default: 2000) */
|
||||
maxLines?: number;
|
||||
/** Maximum number of bytes (default: 50KB) */
|
||||
maxBytes?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format bytes as human-readable size.
|
||||
*/
|
||||
export function formatSize(bytes: number): string {
|
||||
if (bytes < 1024) {
|
||||
return `${bytes}B`;
|
||||
} else if (bytes < 1024 * 1024) {
|
||||
return `${(bytes / 1024).toFixed(1)}KB`;
|
||||
} else {
|
||||
return `${(bytes / (1024 * 1024)).toFixed(1)}MB`;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncate content from the head (keep first N lines/bytes).
|
||||
* Suitable for file reads where you want to see the beginning.
|
||||
*
|
||||
* Never returns partial lines. If first line exceeds byte limit,
|
||||
* returns empty content with firstLineExceedsLimit=true.
|
||||
*/
|
||||
export function truncateHead(
|
||||
content: string,
|
||||
options: TruncationOptions = {},
|
||||
): TruncationResult {
|
||||
const maxLines = options.maxLines ?? DEFAULT_MAX_LINES;
|
||||
const maxBytes = options.maxBytes ?? DEFAULT_MAX_BYTES;
|
||||
|
||||
const totalBytes = Buffer.byteLength(content, "utf-8");
|
||||
const lines = content.split("\n");
|
||||
const totalLines = lines.length;
|
||||
|
||||
// Check if no truncation needed
|
||||
if (totalLines <= maxLines && totalBytes <= maxBytes) {
|
||||
return {
|
||||
content,
|
||||
truncated: false,
|
||||
truncatedBy: null,
|
||||
totalLines,
|
||||
totalBytes,
|
||||
outputLines: totalLines,
|
||||
outputBytes: totalBytes,
|
||||
lastLinePartial: false,
|
||||
firstLineExceedsLimit: false,
|
||||
maxLines,
|
||||
maxBytes,
|
||||
};
|
||||
}
|
||||
|
||||
// Check if first line alone exceeds byte limit
|
||||
const firstLineBytes = Buffer.byteLength(lines[0], "utf-8");
|
||||
if (firstLineBytes > maxBytes) {
|
||||
return {
|
||||
content: "",
|
||||
truncated: true,
|
||||
truncatedBy: "bytes",
|
||||
totalLines,
|
||||
totalBytes,
|
||||
outputLines: 0,
|
||||
outputBytes: 0,
|
||||
lastLinePartial: false,
|
||||
firstLineExceedsLimit: true,
|
||||
maxLines,
|
||||
maxBytes,
|
||||
};
|
||||
}
|
||||
|
||||
// Collect complete lines that fit
|
||||
const outputLinesArr: string[] = [];
|
||||
let outputBytesCount = 0;
|
||||
let truncatedBy: "lines" | "bytes" = "lines";
|
||||
|
||||
for (let i = 0; i < lines.length && i < maxLines; i++) {
|
||||
const line = lines[i];
|
||||
const lineBytes = Buffer.byteLength(line, "utf-8") + (i > 0 ? 1 : 0); // +1 for newline
|
||||
|
||||
if (outputBytesCount + lineBytes > maxBytes) {
|
||||
truncatedBy = "bytes";
|
||||
break;
|
||||
}
|
||||
|
||||
outputLinesArr.push(line);
|
||||
outputBytesCount += lineBytes;
|
||||
}
|
||||
|
||||
// If we exited due to line limit
|
||||
if (outputLinesArr.length >= maxLines && outputBytesCount <= maxBytes) {
|
||||
truncatedBy = "lines";
|
||||
}
|
||||
|
||||
const outputContent = outputLinesArr.join("\n");
|
||||
const finalOutputBytes = Buffer.byteLength(outputContent, "utf-8");
|
||||
|
||||
return {
|
||||
content: outputContent,
|
||||
truncated: true,
|
||||
truncatedBy,
|
||||
totalLines,
|
||||
totalBytes,
|
||||
outputLines: outputLinesArr.length,
|
||||
outputBytes: finalOutputBytes,
|
||||
lastLinePartial: false,
|
||||
firstLineExceedsLimit: false,
|
||||
maxLines,
|
||||
maxBytes,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncate content from the tail (keep last N lines/bytes).
|
||||
* Suitable for bash output where you want to see the end (errors, final results).
|
||||
*
|
||||
* May return partial first line if the last line of original content exceeds byte limit.
|
||||
*/
|
||||
export function truncateTail(
|
||||
content: string,
|
||||
options: TruncationOptions = {},
|
||||
): TruncationResult {
|
||||
const maxLines = options.maxLines ?? DEFAULT_MAX_LINES;
|
||||
const maxBytes = options.maxBytes ?? DEFAULT_MAX_BYTES;
|
||||
|
||||
const totalBytes = Buffer.byteLength(content, "utf-8");
|
||||
const lines = content.split("\n");
|
||||
const totalLines = lines.length;
|
||||
|
||||
// Check if no truncation needed
|
||||
if (totalLines <= maxLines && totalBytes <= maxBytes) {
|
||||
return {
|
||||
content,
|
||||
truncated: false,
|
||||
truncatedBy: null,
|
||||
totalLines,
|
||||
totalBytes,
|
||||
outputLines: totalLines,
|
||||
outputBytes: totalBytes,
|
||||
lastLinePartial: false,
|
||||
firstLineExceedsLimit: false,
|
||||
maxLines,
|
||||
maxBytes,
|
||||
};
|
||||
}
|
||||
|
||||
// Work backwards from the end
|
||||
const outputLinesArr: string[] = [];
|
||||
let outputBytesCount = 0;
|
||||
let truncatedBy: "lines" | "bytes" = "lines";
|
||||
let lastLinePartial = false;
|
||||
|
||||
for (
|
||||
let i = lines.length - 1;
|
||||
i >= 0 && outputLinesArr.length < maxLines;
|
||||
i--
|
||||
) {
|
||||
const line = lines[i];
|
||||
const lineBytes =
|
||||
Buffer.byteLength(line, "utf-8") + (outputLinesArr.length > 0 ? 1 : 0); // +1 for newline
|
||||
|
||||
if (outputBytesCount + lineBytes > maxBytes) {
|
||||
truncatedBy = "bytes";
|
||||
// Edge case: if we haven't added ANY lines yet and this line exceeds maxBytes,
|
||||
// take the end of the line (partial)
|
||||
if (outputLinesArr.length === 0) {
|
||||
const truncatedLine = truncateStringToBytesFromEnd(line, maxBytes);
|
||||
outputLinesArr.unshift(truncatedLine);
|
||||
outputBytesCount = Buffer.byteLength(truncatedLine, "utf-8");
|
||||
lastLinePartial = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
outputLinesArr.unshift(line);
|
||||
outputBytesCount += lineBytes;
|
||||
}
|
||||
|
||||
// If we exited due to line limit
|
||||
if (outputLinesArr.length >= maxLines && outputBytesCount <= maxBytes) {
|
||||
truncatedBy = "lines";
|
||||
}
|
||||
|
||||
const outputContent = outputLinesArr.join("\n");
|
||||
const finalOutputBytes = Buffer.byteLength(outputContent, "utf-8");
|
||||
|
||||
return {
|
||||
content: outputContent,
|
||||
truncated: true,
|
||||
truncatedBy,
|
||||
totalLines,
|
||||
totalBytes,
|
||||
outputLines: outputLinesArr.length,
|
||||
outputBytes: finalOutputBytes,
|
||||
lastLinePartial,
|
||||
firstLineExceedsLimit: false,
|
||||
maxLines,
|
||||
maxBytes,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncate a string to fit within a byte limit (from the end).
|
||||
* Handles multi-byte UTF-8 characters correctly.
|
||||
*/
|
||||
function truncateStringToBytesFromEnd(str: string, maxBytes: number): string {
|
||||
const buf = Buffer.from(str, "utf-8");
|
||||
if (buf.length <= maxBytes) {
|
||||
return str;
|
||||
}
|
||||
|
||||
// Start from the end, skip maxBytes back
|
||||
let start = buf.length - maxBytes;
|
||||
|
||||
// Find a valid UTF-8 boundary (start of a character)
|
||||
while (start < buf.length && (buf[start] & 0xc0) === 0x80) {
|
||||
start++;
|
||||
}
|
||||
|
||||
return buf.slice(start).toString("utf-8");
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncate a single line to max characters, adding [truncated] suffix.
|
||||
* Used for grep match lines.
|
||||
*/
|
||||
export function truncateLine(
|
||||
line: string,
|
||||
maxChars: number = GREP_MAX_LINE_LENGTH,
|
||||
): { text: string; wasTruncated: boolean } {
|
||||
if (line.length <= maxChars) {
|
||||
return { text: line, wasTruncated: false };
|
||||
}
|
||||
return {
|
||||
text: `${line.slice(0, maxChars)}... [truncated]`,
|
||||
wasTruncated: true,
|
||||
};
|
||||
}
|
||||
129
packages/coding-agent/src/core/tools/write.ts
Normal file
129
packages/coding-agent/src/core/tools/write.ts
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { mkdir as fsMkdir, writeFile as fsWriteFile } from "fs/promises";
|
||||
import { dirname } from "path";
|
||||
import { resolveToCwd } from "./path-utils.js";
|
||||
|
||||
const writeSchema = Type.Object({
|
||||
path: Type.String({
|
||||
description: "Path to the file to write (relative or absolute)",
|
||||
}),
|
||||
content: Type.String({ description: "Content to write to the file" }),
|
||||
});
|
||||
|
||||
export type WriteToolInput = Static<typeof writeSchema>;
|
||||
|
||||
/**
|
||||
* Pluggable operations for the write tool.
|
||||
* Override these to delegate file writing to remote systems (e.g., SSH).
|
||||
*/
|
||||
export interface WriteOperations {
|
||||
/** Write content to a file */
|
||||
writeFile: (absolutePath: string, content: string) => Promise<void>;
|
||||
/** Create directory (recursively) */
|
||||
mkdir: (dir: string) => Promise<void>;
|
||||
}
|
||||
|
||||
const defaultWriteOperations: WriteOperations = {
|
||||
writeFile: (path, content) => fsWriteFile(path, content, "utf-8"),
|
||||
mkdir: (dir) => fsMkdir(dir, { recursive: true }).then(() => {}),
|
||||
};
|
||||
|
||||
export interface WriteToolOptions {
|
||||
/** Custom operations for file writing. Default: local filesystem */
|
||||
operations?: WriteOperations;
|
||||
}
|
||||
|
||||
export function createWriteTool(
|
||||
cwd: string,
|
||||
options?: WriteToolOptions,
|
||||
): AgentTool<typeof writeSchema> {
|
||||
const ops = options?.operations ?? defaultWriteOperations;
|
||||
|
||||
return {
|
||||
name: "write",
|
||||
label: "write",
|
||||
description:
|
||||
"Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories.",
|
||||
parameters: writeSchema,
|
||||
execute: async (
|
||||
_toolCallId: string,
|
||||
{ path, content }: { path: string; content: string },
|
||||
signal?: AbortSignal,
|
||||
) => {
|
||||
const absolutePath = resolveToCwd(path, cwd);
|
||||
const dir = dirname(absolutePath);
|
||||
|
||||
return new Promise<{
|
||||
content: Array<{ type: "text"; text: string }>;
|
||||
details: undefined;
|
||||
}>((resolve, reject) => {
|
||||
// Check if already aborted
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Operation aborted"));
|
||||
return;
|
||||
}
|
||||
|
||||
let aborted = false;
|
||||
|
||||
// Set up abort handler
|
||||
const onAbort = () => {
|
||||
aborted = true;
|
||||
reject(new Error("Operation aborted"));
|
||||
};
|
||||
|
||||
if (signal) {
|
||||
signal.addEventListener("abort", onAbort, { once: true });
|
||||
}
|
||||
|
||||
// Perform the write operation
|
||||
(async () => {
|
||||
try {
|
||||
// Create parent directories if needed
|
||||
await ops.mkdir(dir);
|
||||
|
||||
// Check if aborted before writing
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Write the file
|
||||
await ops.writeFile(absolutePath, content);
|
||||
|
||||
// Check if aborted after writing
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Clean up abort handler
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
resolve({
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Successfully wrote ${content.length} bytes to ${path}`,
|
||||
},
|
||||
],
|
||||
details: undefined,
|
||||
});
|
||||
} catch (error: any) {
|
||||
// Clean up abort handler
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
if (!aborted) {
|
||||
reject(error);
|
||||
}
|
||||
}
|
||||
})();
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/** Default write tool using process.cwd() - for backwards compatibility */
|
||||
export const writeTool = createWriteTool(process.cwd());
|
||||
205
packages/coding-agent/src/core/vercel-ai-stream.ts
Normal file
205
packages/coding-agent/src/core/vercel-ai-stream.ts
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
import { randomUUID } from "node:crypto";
|
||||
import type { ServerResponse } from "node:http";
|
||||
import type { AgentSessionEvent } from "./agent-session.js";
|
||||
|
||||
/**
|
||||
* Write a single Vercel AI SDK v5+ SSE chunk to the response.
|
||||
* Format: `data: <JSON>\n\n`
|
||||
* For the terminal [DONE] sentinel: `data: [DONE]\n\n`
|
||||
*/
|
||||
function writeChunk(response: ServerResponse, chunk: object | string): void {
|
||||
if (response.writableEnded) return;
|
||||
const payload = typeof chunk === "string" ? chunk : JSON.stringify(chunk);
|
||||
response.write(`data: ${payload}\n\n`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the user's text from the request body.
|
||||
* Supports both useChat format ({ messages: UIMessage[] }) and simple gateway format ({ text: string }).
|
||||
*/
|
||||
export function extractUserText(body: Record<string, unknown>): string | null {
|
||||
// Simple gateway format
|
||||
if (typeof body.text === "string" && body.text.trim()) {
|
||||
return body.text;
|
||||
}
|
||||
// Convenience format
|
||||
if (typeof body.prompt === "string" && body.prompt.trim()) {
|
||||
return body.prompt;
|
||||
}
|
||||
// Vercel AI SDK useChat format - extract last user message
|
||||
if (Array.isArray(body.messages)) {
|
||||
for (let i = body.messages.length - 1; i >= 0; i--) {
|
||||
const msg = body.messages[i] as Record<string, unknown>;
|
||||
if (msg.role !== "user") continue;
|
||||
// v5+ format with parts array
|
||||
if (Array.isArray(msg.parts)) {
|
||||
for (const part of msg.parts as Array<Record<string, unknown>>) {
|
||||
if (part.type === "text" && typeof part.text === "string") {
|
||||
return part.text;
|
||||
}
|
||||
}
|
||||
}
|
||||
// v4 format with content string
|
||||
if (typeof msg.content === "string" && msg.content.trim()) {
|
||||
return msg.content;
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an AgentSessionEvent listener that translates events to Vercel AI SDK v5+ SSE
|
||||
* chunks and writes them to the HTTP response.
|
||||
*
|
||||
* Returns the listener function. The caller is responsible for subscribing/unsubscribing.
|
||||
*/
|
||||
export function createVercelStreamListener(
|
||||
response: ServerResponse,
|
||||
messageId?: string,
|
||||
): (event: AgentSessionEvent) => void {
|
||||
// Gate: only forward events within a single prompt's agent_start -> agent_end lifecycle.
|
||||
// handleChat now subscribes this listener immediately before the queued prompt starts,
|
||||
// so these guards only need to bound the stream to that prompt's event span.
|
||||
let active = false;
|
||||
const msgId = messageId ?? randomUUID();
|
||||
|
||||
return (event: AgentSessionEvent) => {
|
||||
if (response.writableEnded) return;
|
||||
|
||||
// Activate on our agent_start, deactivate on agent_end
|
||||
if (event.type === "agent_start") {
|
||||
if (!active) {
|
||||
active = true;
|
||||
writeChunk(response, { type: "start", messageId: msgId });
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (event.type === "agent_end") {
|
||||
active = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// Drop events that don't belong to our message
|
||||
if (!active) return;
|
||||
|
||||
switch (event.type) {
|
||||
case "turn_start":
|
||||
writeChunk(response, { type: "start-step" });
|
||||
return;
|
||||
|
||||
case "message_update": {
|
||||
const inner = event.assistantMessageEvent;
|
||||
switch (inner.type) {
|
||||
case "text_start":
|
||||
writeChunk(response, {
|
||||
type: "text-start",
|
||||
id: `text_${inner.contentIndex}`,
|
||||
});
|
||||
return;
|
||||
case "text_delta":
|
||||
writeChunk(response, {
|
||||
type: "text-delta",
|
||||
id: `text_${inner.contentIndex}`,
|
||||
delta: inner.delta,
|
||||
});
|
||||
return;
|
||||
case "text_end":
|
||||
writeChunk(response, {
|
||||
type: "text-end",
|
||||
id: `text_${inner.contentIndex}`,
|
||||
});
|
||||
return;
|
||||
case "toolcall_start": {
|
||||
const content = inner.partial.content[inner.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
writeChunk(response, {
|
||||
type: "tool-input-start",
|
||||
toolCallId: content.id,
|
||||
toolName: content.name,
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
case "toolcall_delta": {
|
||||
const content = inner.partial.content[inner.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
writeChunk(response, {
|
||||
type: "tool-input-delta",
|
||||
toolCallId: content.id,
|
||||
inputTextDelta: inner.delta,
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
case "toolcall_end":
|
||||
writeChunk(response, {
|
||||
type: "tool-input-available",
|
||||
toolCallId: inner.toolCall.id,
|
||||
toolName: inner.toolCall.name,
|
||||
input: inner.toolCall.arguments,
|
||||
});
|
||||
return;
|
||||
case "thinking_start":
|
||||
writeChunk(response, {
|
||||
type: "reasoning-start",
|
||||
id: `reasoning_${inner.contentIndex}`,
|
||||
});
|
||||
return;
|
||||
case "thinking_delta":
|
||||
writeChunk(response, {
|
||||
type: "reasoning-delta",
|
||||
id: `reasoning_${inner.contentIndex}`,
|
||||
delta: inner.delta,
|
||||
});
|
||||
return;
|
||||
case "thinking_end":
|
||||
writeChunk(response, {
|
||||
type: "reasoning-end",
|
||||
id: `reasoning_${inner.contentIndex}`,
|
||||
});
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
case "turn_end":
|
||||
writeChunk(response, { type: "finish-step" });
|
||||
return;
|
||||
|
||||
case "tool_execution_end":
|
||||
writeChunk(response, {
|
||||
type: "tool-output-available",
|
||||
toolCallId: event.toolCallId,
|
||||
output: event.result,
|
||||
});
|
||||
return;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Write the terminal finish sequence and end the response.
|
||||
*/
|
||||
export function finishVercelStream(
|
||||
response: ServerResponse,
|
||||
finishReason: string = "stop",
|
||||
): void {
|
||||
if (response.writableEnded) return;
|
||||
writeChunk(response, { type: "finish", finishReason });
|
||||
writeChunk(response, "[DONE]");
|
||||
response.end();
|
||||
}
|
||||
|
||||
/**
|
||||
* Write an error chunk and end the response.
|
||||
*/
|
||||
export function errorVercelStream(
|
||||
response: ServerResponse,
|
||||
errorText: string,
|
||||
): void {
|
||||
if (response.writableEnded) return;
|
||||
writeChunk(response, { type: "error", errorText });
|
||||
writeChunk(response, "[DONE]");
|
||||
response.end();
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue