mirror of
https://github.com/harivansh-afk/sandbox-agent.git
synced 2026-04-20 05:04:49 +00:00
fix: separate claude turns by item
This commit is contained in:
parent
05459a2a2f
commit
53a06becb1
2 changed files with 59 additions and 277 deletions
|
|
@ -266,6 +266,7 @@ struct SessionState {
|
||||||
codex_sender: Option<mpsc::UnboundedSender<String>>,
|
codex_sender: Option<mpsc::UnboundedSender<String>>,
|
||||||
session_started_emitted: bool,
|
session_started_emitted: bool,
|
||||||
last_claude_message_id: Option<String>,
|
last_claude_message_id: Option<String>,
|
||||||
|
claude_message_counter: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
|
@ -320,6 +321,7 @@ impl SessionState {
|
||||||
codex_sender: None,
|
codex_sender: None,
|
||||||
session_started_emitted: false,
|
session_started_emitted: false,
|
||||||
last_claude_message_id: None,
|
last_claude_message_id: None,
|
||||||
|
claude_message_counter: 0,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -360,6 +362,20 @@ impl SessionState {
|
||||||
if let UniversalEventData::Item(ref mut data) = conversion.data {
|
if let UniversalEventData::Item(ref mut data) = conversion.data {
|
||||||
self.ensure_item_id(&mut data.item);
|
self.ensure_item_id(&mut data.item);
|
||||||
self.ensure_parent_id(&mut data.item);
|
self.ensure_parent_id(&mut data.item);
|
||||||
|
if conversion.event_type == UniversalEventType::ItemCompleted
|
||||||
|
&& !self.item_started.contains(&data.item.item_id)
|
||||||
|
{
|
||||||
|
let mut started_item = data.item.clone();
|
||||||
|
started_item.status = ItemStatus::InProgress;
|
||||||
|
conversions.push(
|
||||||
|
EventConversion::new(
|
||||||
|
UniversalEventType::ItemStarted,
|
||||||
|
UniversalEventData::Item(ItemEventData { item: started_item }),
|
||||||
|
)
|
||||||
|
.synthetic()
|
||||||
|
.with_native_session(conversion.native_session_id.clone()),
|
||||||
|
);
|
||||||
|
}
|
||||||
if conversion.event_type == UniversalEventType::ItemCompleted
|
if conversion.event_type == UniversalEventType::ItemCompleted
|
||||||
&& data.item.kind == ItemKind::Message
|
&& data.item.kind == ItemKind::Message
|
||||||
&& !self.item_delta_seen.contains(&data.item.item_id)
|
&& !self.item_delta_seen.contains(&data.item.item_id)
|
||||||
|
|
@ -2200,29 +2216,47 @@ impl SessionManager {
|
||||||
};
|
};
|
||||||
let event_type = value.get("type").and_then(Value::as_str).unwrap_or("");
|
let event_type = value.get("type").and_then(Value::as_str).unwrap_or("");
|
||||||
if event_type == "assistant" {
|
if event_type == "assistant" {
|
||||||
if let Some(id) = value
|
let mut sessions = self.sessions.lock().await;
|
||||||
|
if let Some(session) = Self::session_mut(&mut sessions, session_id) {
|
||||||
|
let id = value
|
||||||
.get("message")
|
.get("message")
|
||||||
.and_then(|message| message.get("id"))
|
.and_then(|message| message.get("id"))
|
||||||
.and_then(Value::as_str)
|
.and_then(Value::as_str)
|
||||||
|
.map(|id| id.to_string())
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
session.claude_message_counter += 1;
|
||||||
|
let generated =
|
||||||
|
format!("{}_message_{}", session.session_id, session.claude_message_counter);
|
||||||
|
if let Some(message) = value.get_mut("message").and_then(Value::as_object_mut)
|
||||||
{
|
{
|
||||||
|
message.insert("id".to_string(), Value::String(generated.clone()));
|
||||||
|
} else if let Some(map) = value.as_object_mut() {
|
||||||
|
map.insert(
|
||||||
|
"message".to_string(),
|
||||||
|
serde_json::json!({
|
||||||
|
"id": generated
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
generated
|
||||||
|
});
|
||||||
|
session.last_claude_message_id = Some(id);
|
||||||
|
}
|
||||||
|
} else if event_type == "result" {
|
||||||
|
let has_message_id = value.get("message_id").is_some() || value.get("messageId").is_some();
|
||||||
let mut sessions = self.sessions.lock().await;
|
let mut sessions = self.sessions.lock().await;
|
||||||
if let Some(session) = Self::session_mut(&mut sessions, session_id) {
|
if let Some(session) = Self::session_mut(&mut sessions, session_id) {
|
||||||
session.last_claude_message_id = Some(id.to_string());
|
if !has_message_id {
|
||||||
}
|
let id = session.last_claude_message_id.take().unwrap_or_else(|| {
|
||||||
}
|
session.claude_message_counter += 1;
|
||||||
} else if event_type == "result"
|
format!("{}_message_{}", session.session_id, session.claude_message_counter)
|
||||||
&& value.get("message_id").is_none()
|
});
|
||||||
&& value.get("messageId").is_none()
|
|
||||||
{
|
|
||||||
let last_id = {
|
|
||||||
let sessions = self.sessions.lock().await;
|
|
||||||
Self::session_ref(&sessions, session_id)
|
|
||||||
.and_then(|session| session.last_claude_message_id.clone())
|
|
||||||
};
|
|
||||||
if let Some(id) = last_id {
|
|
||||||
if let Some(map) = value.as_object_mut() {
|
if let Some(map) = value.as_object_mut() {
|
||||||
map.insert("message_id".to_string(), Value::String(id));
|
map.insert("message_id".to_string(), Value::String(id));
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
session.last_claude_message_id = None;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,6 @@
|
||||||
//! Tests for multi-turn conversations to validate session resumption behavior.
|
//! Tests for session resumption behavior.
|
||||||
//!
|
|
||||||
//! This test validates that:
|
|
||||||
//! 1. Sessions can handle multiple messages (multi-turn conversations)
|
|
||||||
//! 2. Agents that support resumption (Claude, Amp, OpenCode) can continue after process exit
|
|
||||||
//! 3. Codex supports multi-turn via the shared app-server model (single process, multiple threads)
|
|
||||||
//! 4. The mock agent correctly supports multi-turn as the reference implementation
|
|
||||||
|
|
||||||
use std::time::{Duration, Instant};
|
use std::time::Duration;
|
||||||
|
|
||||||
use axum::body::Body;
|
use axum::body::Body;
|
||||||
use axum::http::{Method, Request, StatusCode};
|
use axum::http::{Method, Request, StatusCode};
|
||||||
|
|
@ -17,14 +11,8 @@ use tempfile::TempDir;
|
||||||
|
|
||||||
use sandbox_agent::router::{build_router, AppState, AuthConfig};
|
use sandbox_agent::router::{build_router, AppState, AuthConfig};
|
||||||
use sandbox_agent_agent_management::agents::{AgentId, AgentManager};
|
use sandbox_agent_agent_management::agents::{AgentId, AgentManager};
|
||||||
use sandbox_agent_agent_management::testing::test_agents_from_env;
|
|
||||||
use sandbox_agent_agent_credentials::ExtractedCredentials;
|
|
||||||
use std::collections::BTreeMap;
|
|
||||||
use tower::util::ServiceExt;
|
use tower::util::ServiceExt;
|
||||||
|
|
||||||
const FIRST_PROMPT: &str = "Reply with exactly the word FIRST.";
|
|
||||||
const SECOND_PROMPT: &str = "Reply with exactly the word SECOND.";
|
|
||||||
|
|
||||||
struct TestApp {
|
struct TestApp {
|
||||||
app: Router,
|
app: Router,
|
||||||
_install_dir: TempDir,
|
_install_dir: TempDir,
|
||||||
|
|
@ -43,58 +31,6 @@ impl TestApp {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct EnvGuard {
|
|
||||||
saved: BTreeMap<String, Option<String>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for EnvGuard {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
for (key, value) in &self.saved {
|
|
||||||
match value {
|
|
||||||
Some(value) => std::env::set_var(key, value),
|
|
||||||
None => std::env::remove_var(key),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn apply_credentials(creds: &ExtractedCredentials) -> EnvGuard {
|
|
||||||
let keys = [
|
|
||||||
"ANTHROPIC_API_KEY",
|
|
||||||
"CLAUDE_API_KEY",
|
|
||||||
"OPENAI_API_KEY",
|
|
||||||
"CODEX_API_KEY",
|
|
||||||
];
|
|
||||||
let mut saved = BTreeMap::new();
|
|
||||||
for key in keys {
|
|
||||||
saved.insert(key.to_string(), std::env::var(key).ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
match creds.anthropic.as_ref() {
|
|
||||||
Some(cred) => {
|
|
||||||
std::env::set_var("ANTHROPIC_API_KEY", &cred.api_key);
|
|
||||||
std::env::set_var("CLAUDE_API_KEY", &cred.api_key);
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
|
||||||
std::env::remove_var("CLAUDE_API_KEY");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
match creds.openai.as_ref() {
|
|
||||||
Some(cred) => {
|
|
||||||
std::env::set_var("OPENAI_API_KEY", &cred.api_key);
|
|
||||||
std::env::set_var("CODEX_API_KEY", &cred.api_key);
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
std::env::remove_var("OPENAI_API_KEY");
|
|
||||||
std::env::remove_var("CODEX_API_KEY");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
EnvGuard { saved }
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_json(
|
async fn send_json(
|
||||||
app: &Router,
|
app: &Router,
|
||||||
method: Method,
|
method: Method,
|
||||||
|
|
@ -126,37 +62,14 @@ async fn send_json(
|
||||||
(status, value)
|
(status, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_status(app: &Router, method: Method, path: &str, body: Option<Value>) -> StatusCode {
|
|
||||||
let (status, _) = send_json(app, method, path, body).await;
|
|
||||||
status
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn install_agent(app: &Router, agent: AgentId) {
|
|
||||||
let status = send_status(
|
|
||||||
app,
|
|
||||||
Method::POST,
|
|
||||||
&format!("/v1/agents/{}/install", agent.as_str()),
|
|
||||||
Some(json!({})),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
assert_eq!(status, StatusCode::NO_CONTENT, "install {agent}");
|
|
||||||
}
|
|
||||||
|
|
||||||
fn test_permission_mode(agent: AgentId) -> &'static str {
|
|
||||||
match agent {
|
|
||||||
AgentId::Opencode => "default",
|
|
||||||
_ => "bypass",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn create_session(app: &Router, agent: AgentId, session_id: &str) {
|
async fn create_session(app: &Router, agent: AgentId, session_id: &str) {
|
||||||
let status = send_status(
|
let (status, _) = send_json(
|
||||||
app,
|
app,
|
||||||
Method::POST,
|
Method::POST,
|
||||||
&format!("/v1/sessions/{session_id}"),
|
&format!("/v1/sessions/{session_id}"),
|
||||||
Some(json!({
|
Some(json!({
|
||||||
"agent": agent.as_str(),
|
"agent": agent.as_str(),
|
||||||
"permissionMode": test_permission_mode(agent)
|
"permissionMode": "bypass"
|
||||||
})),
|
})),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
@ -178,64 +91,6 @@ async fn send_message_with_status(
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wait for a specific number of assistant responses (item.completed with role=assistant)
|
|
||||||
async fn wait_for_n_responses(
|
|
||||||
app: &Router,
|
|
||||||
session_id: &str,
|
|
||||||
n: usize,
|
|
||||||
timeout: Duration,
|
|
||||||
) -> bool {
|
|
||||||
let start = Instant::now();
|
|
||||||
while start.elapsed() < timeout {
|
|
||||||
let path = format!("/v1/sessions/{session_id}/events?offset=0&limit=1000");
|
|
||||||
let (status, payload) = send_json(app, Method::GET, &path, None).await;
|
|
||||||
if status != StatusCode::OK {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
let events = payload
|
|
||||||
.get("events")
|
|
||||||
.and_then(Value::as_array)
|
|
||||||
.cloned()
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
let completed_count = events.iter().filter(|e| is_assistant_completed(e)).count();
|
|
||||||
if completed_count >= n {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for errors
|
|
||||||
for event in &events {
|
|
||||||
if is_error_event(event) {
|
|
||||||
eprintln!("Error event: {:?}", event);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tokio::time::sleep(Duration::from_millis(300)).await;
|
|
||||||
}
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Wait for an assistant response (item.completed with role=assistant)
|
|
||||||
async fn wait_for_response(app: &Router, session_id: &str, timeout: Duration) -> bool {
|
|
||||||
wait_for_n_responses(app, session_id, 1, timeout).await
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_assistant_completed(event: &Value) -> bool {
|
|
||||||
event
|
|
||||||
.get("type")
|
|
||||||
.and_then(Value::as_str)
|
|
||||||
.map(|t| t == "item.completed")
|
|
||||||
.unwrap_or(false)
|
|
||||||
&& event
|
|
||||||
.get("data")
|
|
||||||
.and_then(|d| d.get("item"))
|
|
||||||
.and_then(|i| i.get("role"))
|
|
||||||
.and_then(Value::as_str)
|
|
||||||
.map(|r| r == "assistant")
|
|
||||||
.unwrap_or(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_session_ended(event: &Value) -> bool {
|
fn is_session_ended(event: &Value) -> bool {
|
||||||
event
|
event
|
||||||
.get("type")
|
.get("type")
|
||||||
|
|
@ -244,113 +99,6 @@ fn is_session_ended(event: &Value) -> bool {
|
||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_error_event(event: &Value) -> bool {
|
|
||||||
matches!(
|
|
||||||
event.get("type").and_then(Value::as_str),
|
|
||||||
Some("error") | Some("agent.unparsed")
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Count assistant responses in the event stream
|
|
||||||
async fn count_assistant_responses(app: &Router, session_id: &str) -> usize {
|
|
||||||
let path = format!("/v1/sessions/{session_id}/events?offset=0&limit=1000");
|
|
||||||
let (status, payload) = send_json(app, Method::GET, &path, None).await;
|
|
||||||
if status != StatusCode::OK {
|
|
||||||
eprintln!("Failed to get events: status={}", status);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
let events = payload
|
|
||||||
.get("events")
|
|
||||||
.and_then(Value::as_array)
|
|
||||||
.cloned()
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
// Debug: print all event types
|
|
||||||
eprintln!("All events ({}):", events.len());
|
|
||||||
for (i, e) in events.iter().enumerate() {
|
|
||||||
let event_type = e.get("type").and_then(Value::as_str).unwrap_or("?");
|
|
||||||
let role = e
|
|
||||||
.get("data")
|
|
||||||
.and_then(|d| d.get("item"))
|
|
||||||
.and_then(|i| i.get("role"))
|
|
||||||
.and_then(Value::as_str)
|
|
||||||
.unwrap_or("-");
|
|
||||||
eprintln!(" [{}] type={}, role={}", i, event_type, role);
|
|
||||||
}
|
|
||||||
|
|
||||||
let count = events.iter().filter(|e| is_assistant_completed(e)).count();
|
|
||||||
eprintln!("Assistant completed count: {}", count);
|
|
||||||
count
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test multi-turn conversation for a specific agent
|
|
||||||
async fn test_multi_turn_for_agent(app: &Router, agent: AgentId) -> Result<(), String> {
|
|
||||||
let session_id = format!("multi-turn-{}", agent.as_str());
|
|
||||||
eprintln!("\n=== Testing multi-turn for {} ===", agent);
|
|
||||||
|
|
||||||
// Create session
|
|
||||||
create_session(app, agent, &session_id).await;
|
|
||||||
eprintln!("Session created: {}", session_id);
|
|
||||||
|
|
||||||
// Send first message
|
|
||||||
eprintln!("Sending first message...");
|
|
||||||
let (status, body) = send_message_with_status(app, &session_id, FIRST_PROMPT).await;
|
|
||||||
eprintln!("First message status: {}", status);
|
|
||||||
if status != StatusCode::NO_CONTENT {
|
|
||||||
return Err(format!(
|
|
||||||
"First message failed with status {}: {:?}",
|
|
||||||
status, body
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for first response
|
|
||||||
eprintln!("Waiting for first response...");
|
|
||||||
let got_first = wait_for_response(app, &session_id, Duration::from_secs(120)).await;
|
|
||||||
if !got_first {
|
|
||||||
return Err("Timed out waiting for first response".to_string());
|
|
||||||
}
|
|
||||||
eprintln!("Got first response");
|
|
||||||
|
|
||||||
// Small delay to ensure session state is updated
|
|
||||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
|
||||||
|
|
||||||
// Send second message - this is the critical test
|
|
||||||
eprintln!("Sending second message...");
|
|
||||||
let (status, body) = send_message_with_status(app, &session_id, SECOND_PROMPT).await;
|
|
||||||
eprintln!("Second message status: {}, body: {:?}", status, body);
|
|
||||||
if status != StatusCode::NO_CONTENT {
|
|
||||||
return Err(format!(
|
|
||||||
"Second message failed with status {}: {:?}",
|
|
||||||
status, body
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for second response - specifically wait for 2 completed responses
|
|
||||||
eprintln!("Waiting for second response (total 2)...");
|
|
||||||
let got_both = wait_for_n_responses(app, &session_id, 2, Duration::from_secs(120)).await;
|
|
||||||
if !got_both {
|
|
||||||
// Debug: show what we got
|
|
||||||
let response_count = count_assistant_responses(app, &session_id).await;
|
|
||||||
return Err(format!(
|
|
||||||
"Timed out waiting for second response (got {} completed)",
|
|
||||||
response_count
|
|
||||||
));
|
|
||||||
}
|
|
||||||
eprintln!("Got both responses");
|
|
||||||
|
|
||||||
// Verify we got two assistant responses
|
|
||||||
let response_count = count_assistant_responses(app, &session_id).await;
|
|
||||||
eprintln!("Final response count: {}", response_count);
|
|
||||||
if response_count < 2 {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected at least 2 assistant responses, got {}",
|
|
||||||
response_count
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test that verifies the session can be reopened after ending
|
/// Test that verifies the session can be reopened after ending
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn session_reopen_after_end() {
|
async fn session_reopen_after_end() {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue