mirror of
https://github.com/harivansh-afk/sandbox-agent.git
synced 2026-04-16 04:02:01 +00:00
feat: add universal hooks support for sandbox-agent
Implements sandbox-agent-managed hooks that work universally across all agents (Claude, Codex, OpenCode, Amp, Mock). Hooks are shell commands executed at specific lifecycle points in a session. ## Hook Types - on_session_start - Executed when a session is created - on_session_end - Executed when a session terminates - on_message_start - Executed before processing each message - on_message_end - Executed after each message is fully processed ## Features - Hooks are configured via the API when creating sessions - Environment variables provide context (SANDBOX_SESSION_ID, SANDBOX_AGENT, etc.) - Configurable timeout per hook - continue_on_failure option to control execution flow - Working directory support for hook execution ## API Changes - CreateSessionRequest now accepts optional 'hooks' and 'workingDir' fields - HooksConfig and HookDefinition schemas added to OpenAPI spec ## Testing - 8 unit tests for hook execution - 9 integration tests using mock agent with snapshot testing - Tests cover all lifecycle hooks, multiple hooks, failure handling, and env vars
This commit is contained in:
parent
cacb63ef17
commit
e84967f916
15 changed files with 1454 additions and 63 deletions
560
server/packages/sandbox-agent/src/hooks.rs
Normal file
560
server/packages/sandbox-agent/src/hooks.rs
Normal file
|
|
@ -0,0 +1,560 @@
|
|||
//! Universal hooks support for sandbox-agent.
|
||||
//!
|
||||
//! Hooks are shell commands executed at specific lifecycle points in a session.
|
||||
//! They are managed by sandbox-agent itself (not the underlying agent) and work
|
||||
//! universally across all agents (Claude, Codex, OpenCode, Amp, Mock).
|
||||
//!
|
||||
//! # Hook Types
|
||||
//!
|
||||
//! - `on_session_start` - Executed when a session is created
|
||||
//! - `on_session_end` - Executed when a session terminates
|
||||
//! - `on_message_start` - Executed before processing each message
|
||||
//! - `on_message_end` - Executed after each message is fully processed
|
||||
//!
|
||||
//! # Environment Variables
|
||||
//!
|
||||
//! Hooks receive context via environment variables:
|
||||
//! - `SANDBOX_SESSION_ID` - The session identifier
|
||||
//! - `SANDBOX_AGENT` - The agent type (e.g., "claude", "codex", "mock")
|
||||
//! - `SANDBOX_AGENT_MODE` - The agent mode
|
||||
//! - `SANDBOX_HOOK_TYPE` - The hook type being executed
|
||||
//! - `SANDBOX_MESSAGE` - The message content (for message hooks)
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::time::Duration;
|
||||
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, error, info, warn};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
/// Default timeout for hook execution in seconds.
|
||||
const DEFAULT_HOOK_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
/// Maximum output size to capture from hooks (64KB).
|
||||
const MAX_OUTPUT_SIZE: usize = 64 * 1024;
|
||||
|
||||
/// Configuration for hooks in a session.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema, ToSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct HooksConfig {
|
||||
/// Hooks to run when a session starts.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub on_session_start: Vec<HookDefinition>,
|
||||
|
||||
/// Hooks to run when a session ends.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub on_session_end: Vec<HookDefinition>,
|
||||
|
||||
/// Hooks to run before processing each message.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub on_message_start: Vec<HookDefinition>,
|
||||
|
||||
/// Hooks to run after each message is fully processed.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub on_message_end: Vec<HookDefinition>,
|
||||
}
|
||||
|
||||
impl HooksConfig {
|
||||
/// Returns true if no hooks are configured.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.on_session_start.is_empty()
|
||||
&& self.on_session_end.is_empty()
|
||||
&& self.on_message_start.is_empty()
|
||||
&& self.on_message_end.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Definition of a single hook.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, ToSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct HookDefinition {
|
||||
/// Shell command to execute.
|
||||
pub command: String,
|
||||
|
||||
/// Timeout in seconds. Defaults to 30 seconds.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub timeout_secs: Option<u64>,
|
||||
|
||||
/// Working directory for the command. Defaults to the session's working directory.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub working_dir: Option<String>,
|
||||
|
||||
/// Whether to continue if the hook fails. Defaults to true.
|
||||
#[serde(default = "default_continue_on_failure")]
|
||||
pub continue_on_failure: bool,
|
||||
}
|
||||
|
||||
fn default_continue_on_failure() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Type of hook being executed.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum HookType {
|
||||
SessionStart,
|
||||
SessionEnd,
|
||||
MessageStart,
|
||||
MessageEnd,
|
||||
}
|
||||
|
||||
impl HookType {
|
||||
/// Returns the string representation used in environment variables.
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
HookType::SessionStart => "session_start",
|
||||
HookType::SessionEnd => "session_end",
|
||||
HookType::MessageStart => "message_start",
|
||||
HookType::MessageEnd => "message_end",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context passed to hooks via environment variables.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookContext {
|
||||
pub session_id: String,
|
||||
pub agent: String,
|
||||
pub agent_mode: String,
|
||||
pub hook_type: HookType,
|
||||
pub message: Option<String>,
|
||||
pub working_dir: Option<String>,
|
||||
}
|
||||
|
||||
/// Result of executing a single hook.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HookResult {
|
||||
pub command: String,
|
||||
pub success: bool,
|
||||
pub exit_code: Option<i32>,
|
||||
pub stdout: String,
|
||||
pub stderr: String,
|
||||
pub duration_ms: u64,
|
||||
pub timed_out: bool,
|
||||
}
|
||||
|
||||
/// Result of executing all hooks for a lifecycle event.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HooksExecutionResult {
|
||||
pub hook_type: String,
|
||||
pub results: Vec<HookResult>,
|
||||
pub all_succeeded: bool,
|
||||
pub should_continue: bool,
|
||||
}
|
||||
|
||||
/// Executes hooks for a given lifecycle event.
|
||||
pub async fn execute_hooks(
|
||||
hooks: &[HookDefinition],
|
||||
context: &HookContext,
|
||||
) -> HooksExecutionResult {
|
||||
let mut results = Vec::new();
|
||||
let mut all_succeeded = true;
|
||||
let mut should_continue = true;
|
||||
|
||||
for hook in hooks {
|
||||
let result = execute_single_hook(hook, context).await;
|
||||
|
||||
if !result.success {
|
||||
all_succeeded = false;
|
||||
if !hook.continue_on_failure {
|
||||
should_continue = false;
|
||||
}
|
||||
}
|
||||
|
||||
results.push(result);
|
||||
|
||||
if !should_continue {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
HooksExecutionResult {
|
||||
hook_type: context.hook_type.as_str().to_string(),
|
||||
results,
|
||||
all_succeeded,
|
||||
should_continue,
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes a single hook command.
|
||||
async fn execute_single_hook(hook: &HookDefinition, context: &HookContext) -> HookResult {
|
||||
let start = std::time::Instant::now();
|
||||
let timeout_duration = Duration::from_secs(
|
||||
hook.timeout_secs.unwrap_or(DEFAULT_HOOK_TIMEOUT_SECS),
|
||||
);
|
||||
|
||||
// Determine working directory
|
||||
let working_dir = hook
|
||||
.working_dir
|
||||
.clone()
|
||||
.or_else(|| context.working_dir.clone());
|
||||
|
||||
info!(
|
||||
command = %hook.command,
|
||||
hook_type = %context.hook_type.as_str(),
|
||||
session_id = %context.session_id,
|
||||
"Executing hook"
|
||||
);
|
||||
|
||||
// Build environment variables
|
||||
let mut env: HashMap<String, String> = std::env::vars().collect();
|
||||
env.insert("SANDBOX_SESSION_ID".to_string(), context.session_id.clone());
|
||||
env.insert("SANDBOX_AGENT".to_string(), context.agent.clone());
|
||||
env.insert("SANDBOX_AGENT_MODE".to_string(), context.agent_mode.clone());
|
||||
env.insert("SANDBOX_HOOK_TYPE".to_string(), context.hook_type.as_str().to_string());
|
||||
if let Some(message) = &context.message {
|
||||
env.insert("SANDBOX_MESSAGE".to_string(), message.clone());
|
||||
}
|
||||
|
||||
// Clone values for the blocking task
|
||||
let command = hook.command.clone();
|
||||
let command_for_result = hook.command.clone();
|
||||
|
||||
// Execute in a blocking task with timeout
|
||||
let execution = tokio::task::spawn_blocking(move || {
|
||||
let mut cmd = Command::new("sh");
|
||||
cmd.arg("-c").arg(&command);
|
||||
cmd.envs(&env);
|
||||
cmd.stdin(Stdio::null());
|
||||
cmd.stdout(Stdio::piped());
|
||||
cmd.stderr(Stdio::piped());
|
||||
|
||||
if let Some(dir) = working_dir.as_ref() {
|
||||
if Path::new(dir).exists() {
|
||||
cmd.current_dir(dir);
|
||||
}
|
||||
}
|
||||
|
||||
let mut child = cmd.spawn()?;
|
||||
|
||||
let mut stdout = String::new();
|
||||
let mut stderr = String::new();
|
||||
|
||||
if let Some(ref mut handle) = child.stdout {
|
||||
let mut buf = vec![0u8; MAX_OUTPUT_SIZE];
|
||||
let n = handle.read(&mut buf).unwrap_or(0);
|
||||
stdout = String::from_utf8_lossy(&buf[..n]).to_string();
|
||||
}
|
||||
|
||||
if let Some(ref mut handle) = child.stderr {
|
||||
let mut buf = vec![0u8; MAX_OUTPUT_SIZE];
|
||||
let n = handle.read(&mut buf).unwrap_or(0);
|
||||
stderr = String::from_utf8_lossy(&buf[..n]).to_string();
|
||||
}
|
||||
|
||||
let status = child.wait()?;
|
||||
Ok::<_, std::io::Error>((status, stdout, stderr))
|
||||
});
|
||||
|
||||
match timeout(timeout_duration, execution).await {
|
||||
Ok(Ok(Ok((status, stdout, stderr)))) => {
|
||||
let exit_code = status.code();
|
||||
let success = status.success();
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
debug!(
|
||||
command = %command_for_result,
|
||||
success = %success,
|
||||
exit_code = ?exit_code,
|
||||
duration_ms = %duration_ms,
|
||||
"Hook completed"
|
||||
);
|
||||
|
||||
if !success {
|
||||
warn!(
|
||||
command = %command_for_result,
|
||||
exit_code = ?exit_code,
|
||||
stderr = %stderr,
|
||||
"Hook failed"
|
||||
);
|
||||
}
|
||||
|
||||
HookResult {
|
||||
command: command_for_result,
|
||||
success,
|
||||
exit_code,
|
||||
stdout,
|
||||
stderr,
|
||||
duration_ms,
|
||||
timed_out: false,
|
||||
}
|
||||
}
|
||||
Ok(Ok(Err(err))) => {
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
error!(
|
||||
command = %command_for_result,
|
||||
error = %err,
|
||||
"Hook execution error"
|
||||
);
|
||||
HookResult {
|
||||
command: command_for_result,
|
||||
success: false,
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: err.to_string(),
|
||||
duration_ms,
|
||||
timed_out: false,
|
||||
}
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
error!(
|
||||
command = %command_for_result,
|
||||
error = %err,
|
||||
"Hook task join error"
|
||||
);
|
||||
HookResult {
|
||||
command: command_for_result,
|
||||
success: false,
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: err.to_string(),
|
||||
duration_ms,
|
||||
timed_out: false,
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
warn!(
|
||||
command = %command_for_result,
|
||||
timeout_secs = %timeout_duration.as_secs(),
|
||||
"Hook timed out"
|
||||
);
|
||||
HookResult {
|
||||
command: command_for_result,
|
||||
success: false,
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: format!("Hook timed out after {} seconds", timeout_duration.as_secs()),
|
||||
duration_ms,
|
||||
timed_out: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_simple_hook_execution() {
|
||||
let hook = HookDefinition {
|
||||
command: "echo 'hello world'".to_string(),
|
||||
timeout_secs: None,
|
||||
working_dir: None,
|
||||
continue_on_failure: true,
|
||||
};
|
||||
|
||||
let context = HookContext {
|
||||
session_id: "test-session".to_string(),
|
||||
agent: "mock".to_string(),
|
||||
agent_mode: "default".to_string(),
|
||||
hook_type: HookType::SessionStart,
|
||||
message: None,
|
||||
working_dir: None,
|
||||
};
|
||||
|
||||
let result = execute_single_hook(&hook, &context).await;
|
||||
assert!(result.success);
|
||||
assert_eq!(result.exit_code, Some(0));
|
||||
assert!(result.stdout.contains("hello world"));
|
||||
assert!(!result.timed_out);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hook_with_env_vars() {
|
||||
let hook = HookDefinition {
|
||||
command: "echo $SANDBOX_SESSION_ID $SANDBOX_AGENT".to_string(),
|
||||
timeout_secs: None,
|
||||
working_dir: None,
|
||||
continue_on_failure: true,
|
||||
};
|
||||
|
||||
let context = HookContext {
|
||||
session_id: "my-session-123".to_string(),
|
||||
agent: "codex".to_string(),
|
||||
agent_mode: "auto".to_string(),
|
||||
hook_type: HookType::MessageStart,
|
||||
message: Some("test message".to_string()),
|
||||
working_dir: None,
|
||||
};
|
||||
|
||||
let result = execute_single_hook(&hook, &context).await;
|
||||
assert!(result.success);
|
||||
assert!(result.stdout.contains("my-session-123"));
|
||||
assert!(result.stdout.contains("codex"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hook_failure() {
|
||||
let hook = HookDefinition {
|
||||
command: "exit 1".to_string(),
|
||||
timeout_secs: None,
|
||||
working_dir: None,
|
||||
continue_on_failure: true,
|
||||
};
|
||||
|
||||
let context = HookContext {
|
||||
session_id: "test".to_string(),
|
||||
agent: "mock".to_string(),
|
||||
agent_mode: "default".to_string(),
|
||||
hook_type: HookType::SessionEnd,
|
||||
message: None,
|
||||
working_dir: None,
|
||||
};
|
||||
|
||||
let result = execute_single_hook(&hook, &context).await;
|
||||
assert!(!result.success);
|
||||
assert_eq!(result.exit_code, Some(1));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hook_timeout() {
|
||||
let hook = HookDefinition {
|
||||
command: "sleep 10".to_string(),
|
||||
timeout_secs: Some(1),
|
||||
working_dir: None,
|
||||
continue_on_failure: true,
|
||||
};
|
||||
|
||||
let context = HookContext {
|
||||
session_id: "test".to_string(),
|
||||
agent: "mock".to_string(),
|
||||
agent_mode: "default".to_string(),
|
||||
hook_type: HookType::MessageEnd,
|
||||
message: None,
|
||||
working_dir: None,
|
||||
};
|
||||
|
||||
let result = execute_single_hook(&hook, &context).await;
|
||||
assert!(!result.success);
|
||||
assert!(result.timed_out);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_hooks_all_succeed() {
|
||||
let hooks = vec![
|
||||
HookDefinition {
|
||||
command: "echo 'first'".to_string(),
|
||||
timeout_secs: None,
|
||||
working_dir: None,
|
||||
continue_on_failure: true,
|
||||
},
|
||||
HookDefinition {
|
||||
command: "echo 'second'".to_string(),
|
||||
timeout_secs: None,
|
||||
working_dir: None,
|
||||
continue_on_failure: true,
|
||||
},
|
||||
];
|
||||
|
||||
let context = HookContext {
|
||||
session_id: "test".to_string(),
|
||||
agent: "mock".to_string(),
|
||||
agent_mode: "default".to_string(),
|
||||
hook_type: HookType::SessionStart,
|
||||
message: None,
|
||||
working_dir: None,
|
||||
};
|
||||
|
||||
let result = execute_hooks(&hooks, &context).await;
|
||||
assert!(result.all_succeeded);
|
||||
assert!(result.should_continue);
|
||||
assert_eq!(result.results.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hooks_stop_on_failure_when_configured() {
|
||||
let hooks = vec![
|
||||
HookDefinition {
|
||||
command: "echo 'first'".to_string(),
|
||||
timeout_secs: None,
|
||||
working_dir: None,
|
||||
continue_on_failure: true,
|
||||
},
|
||||
HookDefinition {
|
||||
command: "exit 1".to_string(),
|
||||
timeout_secs: None,
|
||||
working_dir: None,
|
||||
continue_on_failure: false, // Don't continue on failure
|
||||
},
|
||||
HookDefinition {
|
||||
command: "echo 'third'".to_string(),
|
||||
timeout_secs: None,
|
||||
working_dir: None,
|
||||
continue_on_failure: true,
|
||||
},
|
||||
];
|
||||
|
||||
let context = HookContext {
|
||||
session_id: "test".to_string(),
|
||||
agent: "mock".to_string(),
|
||||
agent_mode: "default".to_string(),
|
||||
hook_type: HookType::MessageStart,
|
||||
message: None,
|
||||
working_dir: None,
|
||||
};
|
||||
|
||||
let result = execute_hooks(&hooks, &context).await;
|
||||
assert!(!result.all_succeeded);
|
||||
assert!(!result.should_continue);
|
||||
// Third hook should not have been executed
|
||||
assert_eq!(result.results.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hooks_continue_on_failure_when_configured() {
|
||||
let hooks = vec![
|
||||
HookDefinition {
|
||||
command: "exit 1".to_string(),
|
||||
timeout_secs: None,
|
||||
working_dir: None,
|
||||
continue_on_failure: true, // Continue despite failure
|
||||
},
|
||||
HookDefinition {
|
||||
command: "echo 'second'".to_string(),
|
||||
timeout_secs: None,
|
||||
working_dir: None,
|
||||
continue_on_failure: true,
|
||||
},
|
||||
];
|
||||
|
||||
let context = HookContext {
|
||||
session_id: "test".to_string(),
|
||||
agent: "mock".to_string(),
|
||||
agent_mode: "default".to_string(),
|
||||
hook_type: HookType::SessionEnd,
|
||||
message: None,
|
||||
working_dir: None,
|
||||
};
|
||||
|
||||
let result = execute_hooks(&hooks, &context).await;
|
||||
assert!(!result.all_succeeded);
|
||||
assert!(result.should_continue);
|
||||
// Both hooks should have been executed
|
||||
assert_eq!(result.results.len(), 2);
|
||||
assert!(result.results[1].success);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hooks_config_is_empty() {
|
||||
let config = HooksConfig::default();
|
||||
assert!(config.is_empty());
|
||||
|
||||
let config = HooksConfig {
|
||||
on_session_start: vec![HookDefinition {
|
||||
command: "echo test".to_string(),
|
||||
timeout_secs: None,
|
||||
working_dir: None,
|
||||
continue_on_failure: true,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
assert!(!config.is_empty());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
//! Sandbox agent core utilities.
|
||||
|
||||
pub mod credentials;
|
||||
pub mod hooks;
|
||||
mod agent_server_logs;
|
||||
pub mod router;
|
||||
pub mod telemetry;
|
||||
|
|
|
|||
|
|
@ -492,6 +492,8 @@ fn run_sessions(command: &SessionsCommand, cli: &Cli) -> Result<(), CliError> {
|
|||
model: args.model.clone(),
|
||||
variant: args.variant.clone(),
|
||||
agent_version: args.agent_version.clone(),
|
||||
hooks: None,
|
||||
working_dir: None,
|
||||
};
|
||||
let path = format!("{API_PREFIX}/sessions/{}", args.session_id);
|
||||
let response = ctx.post(&path, &body)?;
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ use tokio_stream::wrappers::BroadcastStream;
|
|||
use tower_http::trace::TraceLayer;
|
||||
use utoipa::{Modify, OpenApi, ToSchema};
|
||||
|
||||
use crate::hooks::{execute_hooks, HookContext, HookDefinition, HooksConfig, HookType};
|
||||
use crate::ui;
|
||||
use sandbox_agent_agent_management::agents::{
|
||||
AgentError as ManagerError, AgentId, AgentManager, InstallOptions, SpawnOptions, StreamingSpawn,
|
||||
|
|
@ -205,7 +206,9 @@ pub async fn shutdown_servers(state: &Arc<AppState>) {
|
|||
PermissionReply,
|
||||
ProblemDetails,
|
||||
ErrorType,
|
||||
AgentError
|
||||
AgentError,
|
||||
HooksConfig,
|
||||
HookDefinition
|
||||
)
|
||||
),
|
||||
tags(
|
||||
|
|
@ -274,6 +277,10 @@ struct SessionState {
|
|||
claude_message_counter: u64,
|
||||
pending_assistant_native_ids: VecDeque<String>,
|
||||
pending_assistant_counter: u64,
|
||||
/// Hooks configuration for this session.
|
||||
hooks: HooksConfig,
|
||||
/// Working directory for hooks execution.
|
||||
working_dir: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
@ -332,9 +339,23 @@ impl SessionState {
|
|||
claude_message_counter: 0,
|
||||
pending_assistant_native_ids: VecDeque::new(),
|
||||
pending_assistant_counter: 0,
|
||||
hooks: request.hooks.clone().unwrap_or_default(),
|
||||
working_dir: request.working_dir.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates a hook context for this session.
|
||||
fn hook_context(&self, hook_type: HookType, message: Option<String>) -> HookContext {
|
||||
HookContext {
|
||||
session_id: self.session_id.clone(),
|
||||
agent: self.agent.as_str().to_string(),
|
||||
agent_mode: self.agent_mode.clone(),
|
||||
hook_type,
|
||||
message,
|
||||
working_dir: self.working_dir.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn next_pending_assistant_native_id(&mut self) -> String {
|
||||
self.pending_assistant_counter += 1;
|
||||
format!(
|
||||
|
|
@ -1468,6 +1489,18 @@ impl SessionManager {
|
|||
logs.read_stderr()
|
||||
}
|
||||
|
||||
/// Gets a snapshot of the hooks configuration for a session.
|
||||
async fn get_hooks_snapshot(&self, session_id: &str) -> Option<HooksConfig> {
|
||||
let sessions = self.sessions.lock().await;
|
||||
Self::session_ref(&sessions, session_id).map(|s| s.hooks.clone())
|
||||
}
|
||||
|
||||
/// Gets the working directory for a session.
|
||||
async fn get_working_dir(&self, session_id: &str) -> Option<String> {
|
||||
let sessions = self.sessions.lock().await;
|
||||
Self::session_ref(&sessions, session_id).and_then(|s| s.working_dir.clone())
|
||||
}
|
||||
|
||||
async fn create_session(
|
||||
self: &Arc<Self>,
|
||||
session_id: String,
|
||||
|
|
@ -1553,6 +1586,18 @@ impl SessionManager {
|
|||
session.record_conversions(vec![native_started]);
|
||||
}
|
||||
|
||||
// Execute on_session_start hooks
|
||||
if !session.hooks.on_session_start.is_empty() {
|
||||
let context = session.hook_context(HookType::SessionStart, None);
|
||||
let hooks_result = execute_hooks(&session.hooks.on_session_start, &context).await;
|
||||
tracing::debug!(
|
||||
session_id = %session_id,
|
||||
all_succeeded = %hooks_result.all_succeeded,
|
||||
results = ?hooks_result.results.len(),
|
||||
"Executed on_session_start hooks"
|
||||
);
|
||||
}
|
||||
|
||||
let native_session_id = session.native_session_id.clone();
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
sessions.push(session);
|
||||
|
|
@ -1599,8 +1644,31 @@ impl SessionManager {
|
|||
) -> Result<(), SandboxError> {
|
||||
// Use allow_ended=true and do explicit check to allow resumable agents
|
||||
let session_snapshot = self.session_snapshot_for_message(&session_id).await?;
|
||||
|
||||
// Execute on_message_start hooks
|
||||
let hooks_snapshot = self.get_hooks_snapshot(&session_id).await;
|
||||
if let Some(ref hooks) = hooks_snapshot {
|
||||
if !hooks.on_message_start.is_empty() {
|
||||
let context = HookContext {
|
||||
session_id: session_id.clone(),
|
||||
agent: session_snapshot.agent.as_str().to_string(),
|
||||
agent_mode: session_snapshot.agent_mode.clone(),
|
||||
hook_type: HookType::MessageStart,
|
||||
message: Some(message.clone()),
|
||||
working_dir: self.get_working_dir(&session_id).await,
|
||||
};
|
||||
let hooks_result = execute_hooks(&hooks.on_message_start, &context).await;
|
||||
tracing::debug!(
|
||||
session_id = %session_id,
|
||||
all_succeeded = %hooks_result.all_succeeded,
|
||||
results = ?hooks_result.results.len(),
|
||||
"Executed on_message_start hooks"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if session_snapshot.agent == AgentId::Mock {
|
||||
self.send_mock_message(session_id, message).await?;
|
||||
self.send_mock_message(session_id, message, hooks_snapshot).await?;
|
||||
return Ok(());
|
||||
}
|
||||
if matches!(session_snapshot.agent, AgentId::Claude | AgentId::Amp) {
|
||||
|
|
@ -1666,9 +1734,19 @@ impl SessionManager {
|
|||
}
|
||||
|
||||
let manager = Arc::clone(self);
|
||||
let agent_mode = session_snapshot.agent_mode.clone();
|
||||
let working_dir = self.get_working_dir(&session_id).await;
|
||||
tokio::spawn(async move {
|
||||
manager
|
||||
.consume_spawn(session_id, agent_id, spawn_result, initial_input)
|
||||
.consume_spawn(
|
||||
session_id,
|
||||
agent_id,
|
||||
spawn_result,
|
||||
initial_input,
|
||||
hooks_snapshot,
|
||||
agent_mode,
|
||||
working_dir,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
|
|
@ -1708,42 +1786,78 @@ impl SessionManager {
|
|||
}
|
||||
|
||||
async fn terminate_session(&self, session_id: String) -> Result<(), SandboxError> {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
let session = Self::session_mut(&mut sessions, &session_id).ok_or_else(|| {
|
||||
SandboxError::SessionNotFound {
|
||||
session_id: session_id.clone(),
|
||||
let hooks_to_run: Option<(HooksConfig, AgentId, String, Option<String>)>;
|
||||
let agent: AgentId;
|
||||
let native_session_id: Option<String>;
|
||||
{
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
let session = Self::session_mut(&mut sessions, &session_id).ok_or_else(|| {
|
||||
SandboxError::SessionNotFound {
|
||||
session_id: session_id.clone(),
|
||||
}
|
||||
})?;
|
||||
if session.ended {
|
||||
return Ok(());
|
||||
}
|
||||
})?;
|
||||
if session.ended {
|
||||
return Ok(());
|
||||
// Capture hooks before marking ended
|
||||
hooks_to_run = if !session.hooks.on_session_end.is_empty() {
|
||||
Some((
|
||||
session.hooks.clone(),
|
||||
session.agent,
|
||||
session.agent_mode.clone(),
|
||||
session.working_dir.clone(),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
session.mark_ended(
|
||||
None,
|
||||
"terminated by daemon".to_string(),
|
||||
SessionEndReason::Terminated,
|
||||
TerminatedBy::Daemon,
|
||||
);
|
||||
let ended = EventConversion::new(
|
||||
UniversalEventType::SessionEnded,
|
||||
UniversalEventData::SessionEnded(SessionEndedData {
|
||||
reason: SessionEndReason::Terminated,
|
||||
terminated_by: TerminatedBy::Daemon,
|
||||
message: None,
|
||||
exit_code: None,
|
||||
stderr: None,
|
||||
}),
|
||||
)
|
||||
.synthetic()
|
||||
.with_native_session(session.native_session_id.clone());
|
||||
session.record_conversions(vec![ended]);
|
||||
agent = session.agent;
|
||||
native_session_id = session.native_session_id.clone();
|
||||
}
|
||||
session.mark_ended(
|
||||
None,
|
||||
"terminated by daemon".to_string(),
|
||||
SessionEndReason::Terminated,
|
||||
TerminatedBy::Daemon,
|
||||
);
|
||||
let ended = EventConversion::new(
|
||||
UniversalEventType::SessionEnded,
|
||||
UniversalEventData::SessionEnded(SessionEndedData {
|
||||
reason: SessionEndReason::Terminated,
|
||||
terminated_by: TerminatedBy::Daemon,
|
||||
message: None,
|
||||
exit_code: None,
|
||||
stderr: None,
|
||||
}),
|
||||
)
|
||||
.synthetic()
|
||||
.with_native_session(session.native_session_id.clone());
|
||||
session.record_conversions(vec![ended]);
|
||||
let agent = session.agent;
|
||||
let native_session_id = session.native_session_id.clone();
|
||||
drop(sessions);
|
||||
|
||||
if agent == AgentId::Opencode || agent == AgentId::Codex {
|
||||
self.server_manager
|
||||
.unregister_session(agent, &session_id, native_session_id.as_deref())
|
||||
.await;
|
||||
}
|
||||
|
||||
// Execute on_session_end hooks (outside the lock)
|
||||
if let Some((hooks, agent, agent_mode, working_dir)) = hooks_to_run {
|
||||
let context = HookContext {
|
||||
session_id: session_id.clone(),
|
||||
agent: agent.as_str().to_string(),
|
||||
agent_mode,
|
||||
hook_type: HookType::SessionEnd,
|
||||
message: None,
|
||||
working_dir,
|
||||
};
|
||||
let hooks_result = execute_hooks(&hooks.on_session_end, &context).await;
|
||||
tracing::debug!(
|
||||
session_id = %session_id,
|
||||
all_succeeded = %hooks_result.all_succeeded,
|
||||
results = ?hooks_result.results.len(),
|
||||
"Executed on_session_end hooks"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -2191,8 +2305,9 @@ impl SessionManager {
|
|||
self: &Arc<Self>,
|
||||
session_id: String,
|
||||
message: String,
|
||||
hooks: Option<HooksConfig>,
|
||||
) -> Result<(), SandboxError> {
|
||||
let prefix = {
|
||||
let (prefix, agent_mode, working_dir) = {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
let session = Self::session_mut(&mut sessions, &session_id).ok_or_else(|| {
|
||||
SandboxError::SessionNotFound {
|
||||
|
|
@ -2203,7 +2318,11 @@ impl SessionManager {
|
|||
return Err(err);
|
||||
}
|
||||
session.mock_sequence = session.mock_sequence.saturating_add(1);
|
||||
format!("mock_{}", session.mock_sequence)
|
||||
(
|
||||
format!("mock_{}", session.mock_sequence),
|
||||
session.agent_mode.clone(),
|
||||
session.working_dir.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
let mut conversions = Vec::new();
|
||||
|
|
@ -2215,7 +2334,9 @@ impl SessionManager {
|
|||
|
||||
let manager = Arc::clone(self);
|
||||
tokio::spawn(async move {
|
||||
manager.emit_mock_events(session_id, conversions).await;
|
||||
manager
|
||||
.emit_mock_events(session_id, conversions, hooks, agent_mode, working_dir)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(())
|
||||
|
|
@ -2225,6 +2346,9 @@ impl SessionManager {
|
|||
self: Arc<Self>,
|
||||
session_id: String,
|
||||
conversions: Vec<EventConversion>,
|
||||
hooks: Option<HooksConfig>,
|
||||
agent_mode: String,
|
||||
working_dir: Option<String>,
|
||||
) {
|
||||
for conversion in conversions {
|
||||
if self
|
||||
|
|
@ -2236,6 +2360,27 @@ impl SessionManager {
|
|||
}
|
||||
sleep(Duration::from_millis(MOCK_EVENT_DELAY_MS)).await;
|
||||
}
|
||||
|
||||
// Execute on_message_end hooks
|
||||
if let Some(ref hooks) = hooks {
|
||||
if !hooks.on_message_end.is_empty() {
|
||||
let context = HookContext {
|
||||
session_id: session_id.clone(),
|
||||
agent: "mock".to_string(),
|
||||
agent_mode,
|
||||
hook_type: HookType::MessageEnd,
|
||||
message: None,
|
||||
working_dir,
|
||||
};
|
||||
let hooks_result = execute_hooks(&hooks.on_message_end, &context).await;
|
||||
tracing::debug!(
|
||||
session_id = %session_id,
|
||||
all_succeeded = %hooks_result.all_succeeded,
|
||||
results = ?hooks_result.results.len(),
|
||||
"Executed on_message_end hooks"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn consume_spawn(
|
||||
|
|
@ -2244,6 +2389,9 @@ impl SessionManager {
|
|||
agent: AgentId,
|
||||
spawn: StreamingSpawn,
|
||||
initial_input: Option<String>,
|
||||
hooks: Option<HooksConfig>,
|
||||
agent_mode: String,
|
||||
working_dir: Option<String>,
|
||||
) {
|
||||
let StreamingSpawn {
|
||||
mut child,
|
||||
|
|
@ -2441,6 +2589,27 @@ impl SessionManager {
|
|||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// Execute on_message_end hooks
|
||||
if let Some(ref hooks) = hooks {
|
||||
if !hooks.on_message_end.is_empty() {
|
||||
let context = HookContext {
|
||||
session_id: session_id.clone(),
|
||||
agent: agent.as_str().to_string(),
|
||||
agent_mode,
|
||||
hook_type: HookType::MessageEnd,
|
||||
message: None,
|
||||
working_dir,
|
||||
};
|
||||
let hooks_result = execute_hooks(&hooks.on_message_end, &context).await;
|
||||
tracing::debug!(
|
||||
session_id = %session_id,
|
||||
all_succeeded = %hooks_result.all_succeeded,
|
||||
results = ?hooks_result.results.len(),
|
||||
"Executed on_message_end hooks"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn record_conversions(
|
||||
|
|
@ -2565,37 +2734,81 @@ impl SessionManager {
|
|||
terminated_by: TerminatedBy,
|
||||
stderr: Option<StderrOutput>,
|
||||
) {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
if let Some(session) = Self::session_mut(&mut sessions, session_id) {
|
||||
if session.ended {
|
||||
return;
|
||||
}
|
||||
session.mark_ended(
|
||||
exit_code,
|
||||
message.to_string(),
|
||||
reason.clone(),
|
||||
terminated_by.clone(),
|
||||
);
|
||||
let (error_message, error_exit_code, error_stderr) =
|
||||
if reason == SessionEndReason::Error {
|
||||
(Some(message.to_string()), exit_code, stderr)
|
||||
let hooks_to_run: Option<(HooksConfig, String, Option<String>)>;
|
||||
{
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
if let Some(session) = Self::session_mut(&mut sessions, session_id) {
|
||||
if session.ended {
|
||||
return;
|
||||
}
|
||||
// Capture hooks info before marking as ended
|
||||
hooks_to_run = if !session.hooks.on_session_end.is_empty() {
|
||||
Some((
|
||||
session.hooks.clone(),
|
||||
session.agent_mode.clone(),
|
||||
session.working_dir.clone(),
|
||||
))
|
||||
} else {
|
||||
(None, None, None)
|
||||
None
|
||||
};
|
||||
let ended = EventConversion::new(
|
||||
UniversalEventType::SessionEnded,
|
||||
UniversalEventData::SessionEnded(SessionEndedData {
|
||||
reason,
|
||||
terminated_by,
|
||||
message: error_message,
|
||||
exit_code: error_exit_code,
|
||||
stderr: error_stderr,
|
||||
}),
|
||||
)
|
||||
.synthetic()
|
||||
.with_native_session(session.native_session_id.clone());
|
||||
session.record_conversions(vec![ended]);
|
||||
session.mark_ended(
|
||||
exit_code,
|
||||
message.to_string(),
|
||||
reason.clone(),
|
||||
terminated_by.clone(),
|
||||
);
|
||||
let (error_message, error_exit_code, error_stderr) =
|
||||
if reason == SessionEndReason::Error {
|
||||
(Some(message.to_string()), exit_code, stderr)
|
||||
} else {
|
||||
(None, None, None)
|
||||
};
|
||||
let ended = EventConversion::new(
|
||||
UniversalEventType::SessionEnded,
|
||||
UniversalEventData::SessionEnded(SessionEndedData {
|
||||
reason,
|
||||
terminated_by,
|
||||
message: error_message,
|
||||
exit_code: error_exit_code,
|
||||
stderr: error_stderr,
|
||||
}),
|
||||
)
|
||||
.synthetic()
|
||||
.with_native_session(session.native_session_id.clone());
|
||||
session.record_conversions(vec![ended]);
|
||||
} else {
|
||||
hooks_to_run = None;
|
||||
}
|
||||
}
|
||||
|
||||
// Execute on_session_end hooks (outside the lock)
|
||||
if let Some((hooks, agent_mode, working_dir)) = hooks_to_run {
|
||||
let context = HookContext {
|
||||
session_id: session_id.to_string(),
|
||||
agent: self
|
||||
.get_session_agent(session_id)
|
||||
.await
|
||||
.map(|a| a.as_str().to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string()),
|
||||
agent_mode,
|
||||
hook_type: HookType::SessionEnd,
|
||||
message: None,
|
||||
working_dir,
|
||||
};
|
||||
let hooks_result = execute_hooks(&hooks.on_session_end, &context).await;
|
||||
tracing::debug!(
|
||||
session_id = %session_id,
|
||||
all_succeeded = %hooks_result.all_succeeded,
|
||||
results = ?hooks_result.results.len(),
|
||||
"Executed on_session_end hooks"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the agent type for a session.
|
||||
async fn get_session_agent(&self, session_id: &str) -> Option<AgentId> {
|
||||
let sessions = self.sessions.lock().await;
|
||||
Self::session_ref(&sessions, session_id).map(|s| s.agent)
|
||||
}
|
||||
|
||||
async fn ensure_opencode_stream(
|
||||
|
|
@ -3405,6 +3618,12 @@ pub struct CreateSessionRequest {
|
|||
pub variant: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub agent_version: Option<String>,
|
||||
/// Hooks configuration for lifecycle events.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub hooks: Option<HooksConfig>,
|
||||
/// Working directory for hook execution.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub working_dir: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, JsonSchema)]
|
||||
|
|
|
|||
548
server/packages/sandbox-agent/tests/sessions/hooks.rs
Normal file
548
server/packages/sandbox-agent/tests/sessions/hooks.rs
Normal file
|
|
@ -0,0 +1,548 @@
|
|||
// Hooks integration tests using the mock agent as the source of truth.
|
||||
include!("../common/http.rs");
|
||||
|
||||
use std::fs;
|
||||
|
||||
fn hooks_snapshot_suffix(prefix: &str) -> String {
|
||||
snapshot_name(prefix, Some(AgentId::Mock))
|
||||
}
|
||||
|
||||
fn assert_hooks_snapshot(prefix: &str, value: Value) {
|
||||
insta::with_settings!({
|
||||
snapshot_suffix => hooks_snapshot_suffix(prefix),
|
||||
}, {
|
||||
insta::assert_yaml_snapshot!(value);
|
||||
});
|
||||
}
|
||||
|
||||
/// Test that on_session_start hooks are executed when a session is created.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn hooks_session_start() {
|
||||
let work_dir = TempDir::new().expect("create work dir");
|
||||
let marker_file = work_dir.path().join("session_started.txt");
|
||||
|
||||
let app = TestApp::new();
|
||||
install_agent(&app.app, AgentId::Mock).await;
|
||||
|
||||
let session_id = "hooks-session-start";
|
||||
let hooks = json!({
|
||||
"onSessionStart": [
|
||||
{
|
||||
"command": format!("echo 'session started' > {}", marker_file.display()),
|
||||
"timeoutSecs": 5
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let (status, _response) = send_json(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}"),
|
||||
Some(json!({
|
||||
"agent": "mock",
|
||||
"permissionMode": "bypass",
|
||||
"hooks": hooks,
|
||||
"workingDir": work_dir.path().to_str().unwrap()
|
||||
})),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::OK, "create session with hooks");
|
||||
|
||||
// Give time for hook to execute
|
||||
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||
|
||||
// Verify the hook created the marker file
|
||||
assert!(marker_file.exists(), "session start hook should have created marker file");
|
||||
let content = fs::read_to_string(&marker_file).expect("read marker file");
|
||||
assert!(content.contains("session started"), "marker file should contain expected content");
|
||||
|
||||
assert_hooks_snapshot("session_start", json!({
|
||||
"hook_executed": marker_file.exists(),
|
||||
"content_valid": content.contains("session started")
|
||||
}));
|
||||
}
|
||||
|
||||
/// Test that on_session_end hooks are executed when a session is terminated.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn hooks_session_end() {
|
||||
let work_dir = TempDir::new().expect("create work dir");
|
||||
let marker_file = work_dir.path().join("session_ended.txt");
|
||||
|
||||
let app = TestApp::new();
|
||||
install_agent(&app.app, AgentId::Mock).await;
|
||||
|
||||
let session_id = "hooks-session-end";
|
||||
let hooks = json!({
|
||||
"onSessionEnd": [
|
||||
{
|
||||
"command": format!("echo 'session ended' > {}", marker_file.display()),
|
||||
"timeoutSecs": 5
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let (status, _) = send_json(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}"),
|
||||
Some(json!({
|
||||
"agent": "mock",
|
||||
"permissionMode": "bypass",
|
||||
"hooks": hooks,
|
||||
"workingDir": work_dir.path().to_str().unwrap()
|
||||
})),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::OK, "create session");
|
||||
|
||||
// Terminate the session
|
||||
let status = send_status(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}/terminate"),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::NO_CONTENT, "terminate session");
|
||||
|
||||
// Give time for hook to execute
|
||||
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||
|
||||
// Verify the hook created the marker file
|
||||
assert!(marker_file.exists(), "session end hook should have created marker file");
|
||||
let content = fs::read_to_string(&marker_file).expect("read marker file");
|
||||
assert!(content.contains("session ended"), "marker file should contain expected content");
|
||||
|
||||
assert_hooks_snapshot("session_end", json!({
|
||||
"hook_executed": marker_file.exists(),
|
||||
"content_valid": content.contains("session ended")
|
||||
}));
|
||||
}
|
||||
|
||||
/// Test that on_message_start hooks are executed before processing a message.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn hooks_message_start() {
|
||||
let work_dir = TempDir::new().expect("create work dir");
|
||||
let marker_file = work_dir.path().join("message_started.txt");
|
||||
|
||||
let app = TestApp::new();
|
||||
install_agent(&app.app, AgentId::Mock).await;
|
||||
|
||||
let session_id = "hooks-message-start";
|
||||
let hooks = json!({
|
||||
"onMessageStart": [
|
||||
{
|
||||
"command": format!("echo \"$SANDBOX_MESSAGE\" > {}", marker_file.display()),
|
||||
"timeoutSecs": 5
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let (status, _) = send_json(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}"),
|
||||
Some(json!({
|
||||
"agent": "mock",
|
||||
"permissionMode": "bypass",
|
||||
"hooks": hooks,
|
||||
"workingDir": work_dir.path().to_str().unwrap()
|
||||
})),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::OK, "create session");
|
||||
|
||||
// Send a message
|
||||
let status = send_status(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}/messages"),
|
||||
Some(json!({ "message": "test message content" })),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::NO_CONTENT, "send message");
|
||||
|
||||
// Give time for hook to execute
|
||||
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||
|
||||
// Verify the hook created the marker file with the message
|
||||
assert!(marker_file.exists(), "message start hook should have created marker file");
|
||||
let content = fs::read_to_string(&marker_file).expect("read marker file");
|
||||
assert!(content.contains("test message content"), "marker file should contain message");
|
||||
|
||||
assert_hooks_snapshot("message_start", json!({
|
||||
"hook_executed": marker_file.exists(),
|
||||
"content_valid": content.contains("test message content")
|
||||
}));
|
||||
}
|
||||
|
||||
/// Test that on_message_end hooks are executed after a message is processed.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn hooks_message_end() {
|
||||
let work_dir = TempDir::new().expect("create work dir");
|
||||
let marker_file = work_dir.path().join("message_ended.txt");
|
||||
|
||||
let app = TestApp::new();
|
||||
install_agent(&app.app, AgentId::Mock).await;
|
||||
|
||||
let session_id = "hooks-message-end";
|
||||
let hooks = json!({
|
||||
"onMessageEnd": [
|
||||
{
|
||||
"command": format!("echo 'message processed' > {}", marker_file.display()),
|
||||
"timeoutSecs": 5
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let (status, _) = send_json(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}"),
|
||||
Some(json!({
|
||||
"agent": "mock",
|
||||
"permissionMode": "bypass",
|
||||
"hooks": hooks,
|
||||
"workingDir": work_dir.path().to_str().unwrap()
|
||||
})),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::OK, "create session");
|
||||
|
||||
// Send a message and wait for completion
|
||||
let status = send_status(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}/messages"),
|
||||
Some(json!({ "message": "Reply with OK." })),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::NO_CONTENT, "send message");
|
||||
|
||||
// Wait for the mock agent to complete and hooks to run
|
||||
let events = poll_events_until(&app.app, session_id, std::time::Duration::from_secs(10)).await;
|
||||
|
||||
// Give extra time for hook to complete
|
||||
tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
|
||||
|
||||
// Verify the hook created the marker file
|
||||
assert!(marker_file.exists(), "message end hook should have created marker file");
|
||||
let content = fs::read_to_string(&marker_file).expect("read marker file");
|
||||
assert!(content.contains("message processed"), "marker file should contain expected content");
|
||||
|
||||
assert_hooks_snapshot("message_end", json!({
|
||||
"hook_executed": marker_file.exists(),
|
||||
"content_valid": content.contains("message processed"),
|
||||
"event_count": events.len()
|
||||
}));
|
||||
}
|
||||
|
||||
/// Test multiple hooks in sequence.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn hooks_multiple_in_sequence() {
|
||||
let work_dir = TempDir::new().expect("create work dir");
|
||||
let file1 = work_dir.path().join("hook1.txt");
|
||||
let file2 = work_dir.path().join("hook2.txt");
|
||||
let file3 = work_dir.path().join("hook3.txt");
|
||||
|
||||
let app = TestApp::new();
|
||||
install_agent(&app.app, AgentId::Mock).await;
|
||||
|
||||
let session_id = "hooks-multiple";
|
||||
let hooks = json!({
|
||||
"onSessionStart": [
|
||||
{
|
||||
"command": format!("echo '1' > {}", file1.display()),
|
||||
"timeoutSecs": 5
|
||||
},
|
||||
{
|
||||
"command": format!("echo '2' > {}", file2.display()),
|
||||
"timeoutSecs": 5
|
||||
},
|
||||
{
|
||||
"command": format!("echo '3' > {}", file3.display()),
|
||||
"timeoutSecs": 5
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let (status, _) = send_json(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}"),
|
||||
Some(json!({
|
||||
"agent": "mock",
|
||||
"permissionMode": "bypass",
|
||||
"hooks": hooks,
|
||||
"workingDir": work_dir.path().to_str().unwrap()
|
||||
})),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::OK, "create session");
|
||||
|
||||
// Give time for hooks to execute
|
||||
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||
|
||||
// Verify all hooks ran
|
||||
assert!(file1.exists(), "hook 1 should have run");
|
||||
assert!(file2.exists(), "hook 2 should have run");
|
||||
assert!(file3.exists(), "hook 3 should have run");
|
||||
|
||||
assert_hooks_snapshot("multiple_hooks", json!({
|
||||
"hook1_executed": file1.exists(),
|
||||
"hook2_executed": file2.exists(),
|
||||
"hook3_executed": file3.exists()
|
||||
}));
|
||||
}
|
||||
|
||||
/// Test that hook failures with continue_on_failure=false stop execution.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn hooks_stop_on_failure() {
|
||||
let work_dir = TempDir::new().expect("create work dir");
|
||||
let file1 = work_dir.path().join("before_fail.txt");
|
||||
let file3 = work_dir.path().join("after_fail.txt");
|
||||
|
||||
let app = TestApp::new();
|
||||
install_agent(&app.app, AgentId::Mock).await;
|
||||
|
||||
let session_id = "hooks-stop-on-failure";
|
||||
let hooks = json!({
|
||||
"onSessionStart": [
|
||||
{
|
||||
"command": format!("echo 'first' > {}", file1.display()),
|
||||
"timeoutSecs": 5
|
||||
},
|
||||
{
|
||||
"command": "exit 1",
|
||||
"continueOnFailure": false,
|
||||
"timeoutSecs": 5
|
||||
},
|
||||
{
|
||||
"command": format!("echo 'third' > {}", file3.display()),
|
||||
"timeoutSecs": 5
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let (status, _) = send_json(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}"),
|
||||
Some(json!({
|
||||
"agent": "mock",
|
||||
"permissionMode": "bypass",
|
||||
"hooks": hooks,
|
||||
"workingDir": work_dir.path().to_str().unwrap()
|
||||
})),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::OK, "create session");
|
||||
|
||||
// Give time for hooks to execute
|
||||
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||
|
||||
// Verify first hook ran but third didn't (stopped at failure)
|
||||
assert!(file1.exists(), "first hook should have run");
|
||||
assert!(!file3.exists(), "third hook should NOT have run (stopped at failure)");
|
||||
|
||||
assert_hooks_snapshot("stop_on_failure", json!({
|
||||
"first_executed": file1.exists(),
|
||||
"third_executed": file3.exists()
|
||||
}));
|
||||
}
|
||||
|
||||
/// Test that hook failures with continue_on_failure=true continue execution.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn hooks_continue_on_failure() {
|
||||
let work_dir = TempDir::new().expect("create work dir");
|
||||
let file1 = work_dir.path().join("before_fail.txt");
|
||||
let file3 = work_dir.path().join("after_fail.txt");
|
||||
|
||||
let app = TestApp::new();
|
||||
install_agent(&app.app, AgentId::Mock).await;
|
||||
|
||||
let session_id = "hooks-continue-on-failure";
|
||||
let hooks = json!({
|
||||
"onSessionStart": [
|
||||
{
|
||||
"command": format!("echo 'first' > {}", file1.display()),
|
||||
"timeoutSecs": 5
|
||||
},
|
||||
{
|
||||
"command": "exit 1",
|
||||
"continueOnFailure": true,
|
||||
"timeoutSecs": 5
|
||||
},
|
||||
{
|
||||
"command": format!("echo 'third' > {}", file3.display()),
|
||||
"timeoutSecs": 5
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let (status, _) = send_json(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}"),
|
||||
Some(json!({
|
||||
"agent": "mock",
|
||||
"permissionMode": "bypass",
|
||||
"hooks": hooks,
|
||||
"workingDir": work_dir.path().to_str().unwrap()
|
||||
})),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::OK, "create session");
|
||||
|
||||
// Give time for hooks to execute
|
||||
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||
|
||||
// Verify all hooks ran (continued past failure)
|
||||
assert!(file1.exists(), "first hook should have run");
|
||||
assert!(file3.exists(), "third hook should have run (continued past failure)");
|
||||
|
||||
assert_hooks_snapshot("continue_on_failure", json!({
|
||||
"first_executed": file1.exists(),
|
||||
"third_executed": file3.exists()
|
||||
}));
|
||||
}
|
||||
|
||||
/// Test hooks with environment variables.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn hooks_environment_variables() {
|
||||
let work_dir = TempDir::new().expect("create work dir");
|
||||
let env_file = work_dir.path().join("env_vars.txt");
|
||||
|
||||
let app = TestApp::new();
|
||||
install_agent(&app.app, AgentId::Mock).await;
|
||||
|
||||
let session_id = "hooks-env-vars";
|
||||
let hooks = json!({
|
||||
"onSessionStart": [
|
||||
{
|
||||
"command": format!(
|
||||
"echo \"session=$SANDBOX_SESSION_ID agent=$SANDBOX_AGENT mode=$SANDBOX_AGENT_MODE hook=$SANDBOX_HOOK_TYPE\" > {}",
|
||||
env_file.display()
|
||||
),
|
||||
"timeoutSecs": 5
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let (status, _) = send_json(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}"),
|
||||
Some(json!({
|
||||
"agent": "mock",
|
||||
"permissionMode": "bypass",
|
||||
"hooks": hooks,
|
||||
"workingDir": work_dir.path().to_str().unwrap()
|
||||
})),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::OK, "create session");
|
||||
|
||||
// Give time for hook to execute
|
||||
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||
|
||||
// Verify the environment variables were available
|
||||
assert!(env_file.exists(), "env file should exist");
|
||||
let content = fs::read_to_string(&env_file).expect("read env file");
|
||||
|
||||
assert!(content.contains(&format!("session={session_id}")), "should have session id");
|
||||
assert!(content.contains("agent=mock"), "should have agent");
|
||||
assert!(content.contains("hook=session_start"), "should have hook type");
|
||||
|
||||
assert_hooks_snapshot("env_vars", json!({
|
||||
"file_exists": env_file.exists(),
|
||||
"has_session_id": content.contains(&format!("session={session_id}")),
|
||||
"has_agent": content.contains("agent=mock"),
|
||||
"has_hook_type": content.contains("hook=session_start")
|
||||
}));
|
||||
}
|
||||
|
||||
/// Test full lifecycle with all hook types.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn hooks_full_lifecycle() {
|
||||
let work_dir = TempDir::new().expect("create work dir");
|
||||
let session_start = work_dir.path().join("session_start.txt");
|
||||
let message_start = work_dir.path().join("message_start.txt");
|
||||
let message_end = work_dir.path().join("message_end.txt");
|
||||
let session_end = work_dir.path().join("session_end.txt");
|
||||
|
||||
let app = TestApp::new();
|
||||
install_agent(&app.app, AgentId::Mock).await;
|
||||
|
||||
let session_id = "hooks-full-lifecycle";
|
||||
let hooks = json!({
|
||||
"onSessionStart": [{
|
||||
"command": format!("echo 'started' > {}", session_start.display()),
|
||||
"timeoutSecs": 5
|
||||
}],
|
||||
"onMessageStart": [{
|
||||
"command": format!("echo 'msg start' > {}", message_start.display()),
|
||||
"timeoutSecs": 5
|
||||
}],
|
||||
"onMessageEnd": [{
|
||||
"command": format!("echo 'msg end' > {}", message_end.display()),
|
||||
"timeoutSecs": 5
|
||||
}],
|
||||
"onSessionEnd": [{
|
||||
"command": format!("echo 'ended' > {}", session_end.display()),
|
||||
"timeoutSecs": 5
|
||||
}]
|
||||
});
|
||||
|
||||
// Create session (triggers onSessionStart)
|
||||
let (status, _) = send_json(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}"),
|
||||
Some(json!({
|
||||
"agent": "mock",
|
||||
"permissionMode": "bypass",
|
||||
"hooks": hooks,
|
||||
"workingDir": work_dir.path().to_str().unwrap()
|
||||
})),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::OK, "create session");
|
||||
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
|
||||
assert!(session_start.exists(), "session start hook should run");
|
||||
|
||||
// Send message (triggers onMessageStart and onMessageEnd)
|
||||
let status = send_status(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}/messages"),
|
||||
Some(json!({ "message": "Reply with OK." })),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::NO_CONTENT, "send message");
|
||||
|
||||
// Wait for message processing
|
||||
let _ = poll_events_until(&app.app, session_id, std::time::Duration::from_secs(10)).await;
|
||||
tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
|
||||
assert!(message_start.exists(), "message start hook should run");
|
||||
assert!(message_end.exists(), "message end hook should run");
|
||||
|
||||
// Terminate session (triggers onSessionEnd)
|
||||
let status = send_status(
|
||||
&app.app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}/terminate"),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::NO_CONTENT, "terminate session");
|
||||
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||
assert!(session_end.exists(), "session end hook should run");
|
||||
|
||||
assert_hooks_snapshot("full_lifecycle", json!({
|
||||
"session_start_executed": session_start.exists(),
|
||||
"message_start_executed": message_start.exists(),
|
||||
"message_end_executed": message_end.exists(),
|
||||
"session_end_executed": session_end.exists()
|
||||
}));
|
||||
}
|
||||
|
|
@ -4,3 +4,4 @@ mod permissions;
|
|||
mod questions;
|
||||
mod reasoning;
|
||||
mod status;
|
||||
mod hooks;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
---
|
||||
source: server/packages/sandbox-agent/tests/sessions/hooks.rs
|
||||
expression: value
|
||||
---
|
||||
first_executed: true
|
||||
third_executed: true
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
---
|
||||
source: server/packages/sandbox-agent/tests/sessions/hooks.rs
|
||||
expression: value
|
||||
---
|
||||
file_exists: true
|
||||
has_agent: true
|
||||
has_hook_type: true
|
||||
has_session_id: true
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
---
|
||||
source: server/packages/sandbox-agent/tests/sessions/hooks.rs
|
||||
expression: value
|
||||
---
|
||||
message_end_executed: true
|
||||
message_start_executed: true
|
||||
session_end_executed: true
|
||||
session_start_executed: true
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
---
|
||||
source: server/packages/sandbox-agent/tests/sessions/hooks.rs
|
||||
expression: value
|
||||
---
|
||||
content_valid: true
|
||||
event_count: 7
|
||||
hook_executed: true
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
---
|
||||
source: server/packages/sandbox-agent/tests/sessions/hooks.rs
|
||||
expression: value
|
||||
---
|
||||
content_valid: true
|
||||
hook_executed: true
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
---
|
||||
source: server/packages/sandbox-agent/tests/sessions/hooks.rs
|
||||
expression: value
|
||||
---
|
||||
hook1_executed: true
|
||||
hook2_executed: true
|
||||
hook3_executed: true
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
---
|
||||
source: server/packages/sandbox-agent/tests/sessions/hooks.rs
|
||||
expression: value
|
||||
---
|
||||
content_valid: true
|
||||
hook_executed: true
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
---
|
||||
source: server/packages/sandbox-agent/tests/sessions/hooks.rs
|
||||
expression: value
|
||||
---
|
||||
content_valid: true
|
||||
hook_executed: true
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
---
|
||||
source: server/packages/sandbox-agent/tests/sessions/hooks.rs
|
||||
expression: value
|
||||
---
|
||||
first_executed: true
|
||||
third_executed: false
|
||||
Loading…
Add table
Add a link
Reference in a new issue