mirror of
https://github.com/harivansh-afk/sandbox-agent.git
synced 2026-04-15 22:03:48 +00:00
fix: separate claude turns by item
This commit is contained in:
parent
05459a2a2f
commit
53a06becb1
2 changed files with 59 additions and 277 deletions
|
|
@ -266,6 +266,7 @@ struct SessionState {
|
|||
codex_sender: Option<mpsc::UnboundedSender<String>>,
|
||||
session_started_emitted: bool,
|
||||
last_claude_message_id: Option<String>,
|
||||
claude_message_counter: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
@ -320,6 +321,7 @@ impl SessionState {
|
|||
codex_sender: None,
|
||||
session_started_emitted: false,
|
||||
last_claude_message_id: None,
|
||||
claude_message_counter: 0,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -360,6 +362,20 @@ impl SessionState {
|
|||
if let UniversalEventData::Item(ref mut data) = conversion.data {
|
||||
self.ensure_item_id(&mut data.item);
|
||||
self.ensure_parent_id(&mut data.item);
|
||||
if conversion.event_type == UniversalEventType::ItemCompleted
|
||||
&& !self.item_started.contains(&data.item.item_id)
|
||||
{
|
||||
let mut started_item = data.item.clone();
|
||||
started_item.status = ItemStatus::InProgress;
|
||||
conversions.push(
|
||||
EventConversion::new(
|
||||
UniversalEventType::ItemStarted,
|
||||
UniversalEventData::Item(ItemEventData { item: started_item }),
|
||||
)
|
||||
.synthetic()
|
||||
.with_native_session(conversion.native_session_id.clone()),
|
||||
);
|
||||
}
|
||||
if conversion.event_type == UniversalEventType::ItemCompleted
|
||||
&& data.item.kind == ItemKind::Message
|
||||
&& !self.item_delta_seen.contains(&data.item.item_id)
|
||||
|
|
@ -2200,28 +2216,46 @@ impl SessionManager {
|
|||
};
|
||||
let event_type = value.get("type").and_then(Value::as_str).unwrap_or("");
|
||||
if event_type == "assistant" {
|
||||
if let Some(id) = value
|
||||
.get("message")
|
||||
.and_then(|message| message.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
{
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
if let Some(session) = Self::session_mut(&mut sessions, session_id) {
|
||||
session.last_claude_message_id = Some(id.to_string());
|
||||
}
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
if let Some(session) = Self::session_mut(&mut sessions, session_id) {
|
||||
let id = value
|
||||
.get("message")
|
||||
.and_then(|message| message.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(|id| id.to_string())
|
||||
.unwrap_or_else(|| {
|
||||
session.claude_message_counter += 1;
|
||||
let generated =
|
||||
format!("{}_message_{}", session.session_id, session.claude_message_counter);
|
||||
if let Some(message) = value.get_mut("message").and_then(Value::as_object_mut)
|
||||
{
|
||||
message.insert("id".to_string(), Value::String(generated.clone()));
|
||||
} else if let Some(map) = value.as_object_mut() {
|
||||
map.insert(
|
||||
"message".to_string(),
|
||||
serde_json::json!({
|
||||
"id": generated
|
||||
}),
|
||||
);
|
||||
}
|
||||
generated
|
||||
});
|
||||
session.last_claude_message_id = Some(id);
|
||||
}
|
||||
} else if event_type == "result"
|
||||
&& value.get("message_id").is_none()
|
||||
&& value.get("messageId").is_none()
|
||||
{
|
||||
let last_id = {
|
||||
let sessions = self.sessions.lock().await;
|
||||
Self::session_ref(&sessions, session_id)
|
||||
.and_then(|session| session.last_claude_message_id.clone())
|
||||
};
|
||||
if let Some(id) = last_id {
|
||||
if let Some(map) = value.as_object_mut() {
|
||||
map.insert("message_id".to_string(), Value::String(id));
|
||||
} else if event_type == "result" {
|
||||
let has_message_id = value.get("message_id").is_some() || value.get("messageId").is_some();
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
if let Some(session) = Self::session_mut(&mut sessions, session_id) {
|
||||
if !has_message_id {
|
||||
let id = session.last_claude_message_id.take().unwrap_or_else(|| {
|
||||
session.claude_message_counter += 1;
|
||||
format!("{}_message_{}", session.session_id, session.claude_message_counter)
|
||||
});
|
||||
if let Some(map) = value.as_object_mut() {
|
||||
map.insert("message_id".to_string(), Value::String(id));
|
||||
}
|
||||
} else {
|
||||
session.last_claude_message_id = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,6 @@
|
|||
//! Tests for multi-turn conversations to validate session resumption behavior.
|
||||
//!
|
||||
//! This test validates that:
|
||||
//! 1. Sessions can handle multiple messages (multi-turn conversations)
|
||||
//! 2. Agents that support resumption (Claude, Amp, OpenCode) can continue after process exit
|
||||
//! 3. Codex supports multi-turn via the shared app-server model (single process, multiple threads)
|
||||
//! 4. The mock agent correctly supports multi-turn as the reference implementation
|
||||
//! Tests for session resumption behavior.
|
||||
|
||||
use std::time::{Duration, Instant};
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::body::Body;
|
||||
use axum::http::{Method, Request, StatusCode};
|
||||
|
|
@ -17,14 +11,8 @@ use tempfile::TempDir;
|
|||
|
||||
use sandbox_agent::router::{build_router, AppState, AuthConfig};
|
||||
use sandbox_agent_agent_management::agents::{AgentId, AgentManager};
|
||||
use sandbox_agent_agent_management::testing::test_agents_from_env;
|
||||
use sandbox_agent_agent_credentials::ExtractedCredentials;
|
||||
use std::collections::BTreeMap;
|
||||
use tower::util::ServiceExt;
|
||||
|
||||
const FIRST_PROMPT: &str = "Reply with exactly the word FIRST.";
|
||||
const SECOND_PROMPT: &str = "Reply with exactly the word SECOND.";
|
||||
|
||||
struct TestApp {
|
||||
app: Router,
|
||||
_install_dir: TempDir,
|
||||
|
|
@ -43,58 +31,6 @@ impl TestApp {
|
|||
}
|
||||
}
|
||||
|
||||
struct EnvGuard {
|
||||
saved: BTreeMap<String, Option<String>>,
|
||||
}
|
||||
|
||||
impl Drop for EnvGuard {
|
||||
fn drop(&mut self) {
|
||||
for (key, value) in &self.saved {
|
||||
match value {
|
||||
Some(value) => std::env::set_var(key, value),
|
||||
None => std::env::remove_var(key),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_credentials(creds: &ExtractedCredentials) -> EnvGuard {
|
||||
let keys = [
|
||||
"ANTHROPIC_API_KEY",
|
||||
"CLAUDE_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"CODEX_API_KEY",
|
||||
];
|
||||
let mut saved = BTreeMap::new();
|
||||
for key in keys {
|
||||
saved.insert(key.to_string(), std::env::var(key).ok());
|
||||
}
|
||||
|
||||
match creds.anthropic.as_ref() {
|
||||
Some(cred) => {
|
||||
std::env::set_var("ANTHROPIC_API_KEY", &cred.api_key);
|
||||
std::env::set_var("CLAUDE_API_KEY", &cred.api_key);
|
||||
}
|
||||
None => {
|
||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||
std::env::remove_var("CLAUDE_API_KEY");
|
||||
}
|
||||
}
|
||||
|
||||
match creds.openai.as_ref() {
|
||||
Some(cred) => {
|
||||
std::env::set_var("OPENAI_API_KEY", &cred.api_key);
|
||||
std::env::set_var("CODEX_API_KEY", &cred.api_key);
|
||||
}
|
||||
None => {
|
||||
std::env::remove_var("OPENAI_API_KEY");
|
||||
std::env::remove_var("CODEX_API_KEY");
|
||||
}
|
||||
}
|
||||
|
||||
EnvGuard { saved }
|
||||
}
|
||||
|
||||
async fn send_json(
|
||||
app: &Router,
|
||||
method: Method,
|
||||
|
|
@ -126,37 +62,14 @@ async fn send_json(
|
|||
(status, value)
|
||||
}
|
||||
|
||||
async fn send_status(app: &Router, method: Method, path: &str, body: Option<Value>) -> StatusCode {
|
||||
let (status, _) = send_json(app, method, path, body).await;
|
||||
status
|
||||
}
|
||||
|
||||
async fn install_agent(app: &Router, agent: AgentId) {
|
||||
let status = send_status(
|
||||
app,
|
||||
Method::POST,
|
||||
&format!("/v1/agents/{}/install", agent.as_str()),
|
||||
Some(json!({})),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(status, StatusCode::NO_CONTENT, "install {agent}");
|
||||
}
|
||||
|
||||
fn test_permission_mode(agent: AgentId) -> &'static str {
|
||||
match agent {
|
||||
AgentId::Opencode => "default",
|
||||
_ => "bypass",
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_session(app: &Router, agent: AgentId, session_id: &str) {
|
||||
let status = send_status(
|
||||
let (status, _) = send_json(
|
||||
app,
|
||||
Method::POST,
|
||||
&format!("/v1/sessions/{session_id}"),
|
||||
Some(json!({
|
||||
"agent": agent.as_str(),
|
||||
"permissionMode": test_permission_mode(agent)
|
||||
"permissionMode": "bypass"
|
||||
})),
|
||||
)
|
||||
.await;
|
||||
|
|
@ -178,64 +91,6 @@ async fn send_message_with_status(
|
|||
.await
|
||||
}
|
||||
|
||||
/// Wait for a specific number of assistant responses (item.completed with role=assistant)
|
||||
async fn wait_for_n_responses(
|
||||
app: &Router,
|
||||
session_id: &str,
|
||||
n: usize,
|
||||
timeout: Duration,
|
||||
) -> bool {
|
||||
let start = Instant::now();
|
||||
while start.elapsed() < timeout {
|
||||
let path = format!("/v1/sessions/{session_id}/events?offset=0&limit=1000");
|
||||
let (status, payload) = send_json(app, Method::GET, &path, None).await;
|
||||
if status != StatusCode::OK {
|
||||
return false;
|
||||
}
|
||||
let events = payload
|
||||
.get("events")
|
||||
.and_then(Value::as_array)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
let completed_count = events.iter().filter(|e| is_assistant_completed(e)).count();
|
||||
if completed_count >= n {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check for errors
|
||||
for event in &events {
|
||||
if is_error_event(event) {
|
||||
eprintln!("Error event: {:?}", event);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(300)).await;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Wait for an assistant response (item.completed with role=assistant)
|
||||
async fn wait_for_response(app: &Router, session_id: &str, timeout: Duration) -> bool {
|
||||
wait_for_n_responses(app, session_id, 1, timeout).await
|
||||
}
|
||||
|
||||
fn is_assistant_completed(event: &Value) -> bool {
|
||||
event
|
||||
.get("type")
|
||||
.and_then(Value::as_str)
|
||||
.map(|t| t == "item.completed")
|
||||
.unwrap_or(false)
|
||||
&& event
|
||||
.get("data")
|
||||
.and_then(|d| d.get("item"))
|
||||
.and_then(|i| i.get("role"))
|
||||
.and_then(Value::as_str)
|
||||
.map(|r| r == "assistant")
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn is_session_ended(event: &Value) -> bool {
|
||||
event
|
||||
.get("type")
|
||||
|
|
@ -244,113 +99,6 @@ fn is_session_ended(event: &Value) -> bool {
|
|||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn is_error_event(event: &Value) -> bool {
|
||||
matches!(
|
||||
event.get("type").and_then(Value::as_str),
|
||||
Some("error") | Some("agent.unparsed")
|
||||
)
|
||||
}
|
||||
|
||||
/// Count assistant responses in the event stream
|
||||
async fn count_assistant_responses(app: &Router, session_id: &str) -> usize {
|
||||
let path = format!("/v1/sessions/{session_id}/events?offset=0&limit=1000");
|
||||
let (status, payload) = send_json(app, Method::GET, &path, None).await;
|
||||
if status != StatusCode::OK {
|
||||
eprintln!("Failed to get events: status={}", status);
|
||||
return 0;
|
||||
}
|
||||
let events = payload
|
||||
.get("events")
|
||||
.and_then(Value::as_array)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
// Debug: print all event types
|
||||
eprintln!("All events ({}):", events.len());
|
||||
for (i, e) in events.iter().enumerate() {
|
||||
let event_type = e.get("type").and_then(Value::as_str).unwrap_or("?");
|
||||
let role = e
|
||||
.get("data")
|
||||
.and_then(|d| d.get("item"))
|
||||
.and_then(|i| i.get("role"))
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("-");
|
||||
eprintln!(" [{}] type={}, role={}", i, event_type, role);
|
||||
}
|
||||
|
||||
let count = events.iter().filter(|e| is_assistant_completed(e)).count();
|
||||
eprintln!("Assistant completed count: {}", count);
|
||||
count
|
||||
}
|
||||
|
||||
/// Test multi-turn conversation for a specific agent
|
||||
async fn test_multi_turn_for_agent(app: &Router, agent: AgentId) -> Result<(), String> {
|
||||
let session_id = format!("multi-turn-{}", agent.as_str());
|
||||
eprintln!("\n=== Testing multi-turn for {} ===", agent);
|
||||
|
||||
// Create session
|
||||
create_session(app, agent, &session_id).await;
|
||||
eprintln!("Session created: {}", session_id);
|
||||
|
||||
// Send first message
|
||||
eprintln!("Sending first message...");
|
||||
let (status, body) = send_message_with_status(app, &session_id, FIRST_PROMPT).await;
|
||||
eprintln!("First message status: {}", status);
|
||||
if status != StatusCode::NO_CONTENT {
|
||||
return Err(format!(
|
||||
"First message failed with status {}: {:?}",
|
||||
status, body
|
||||
));
|
||||
}
|
||||
|
||||
// Wait for first response
|
||||
eprintln!("Waiting for first response...");
|
||||
let got_first = wait_for_response(app, &session_id, Duration::from_secs(120)).await;
|
||||
if !got_first {
|
||||
return Err("Timed out waiting for first response".to_string());
|
||||
}
|
||||
eprintln!("Got first response");
|
||||
|
||||
// Small delay to ensure session state is updated
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Send second message - this is the critical test
|
||||
eprintln!("Sending second message...");
|
||||
let (status, body) = send_message_with_status(app, &session_id, SECOND_PROMPT).await;
|
||||
eprintln!("Second message status: {}, body: {:?}", status, body);
|
||||
if status != StatusCode::NO_CONTENT {
|
||||
return Err(format!(
|
||||
"Second message failed with status {}: {:?}",
|
||||
status, body
|
||||
));
|
||||
}
|
||||
|
||||
// Wait for second response - specifically wait for 2 completed responses
|
||||
eprintln!("Waiting for second response (total 2)...");
|
||||
let got_both = wait_for_n_responses(app, &session_id, 2, Duration::from_secs(120)).await;
|
||||
if !got_both {
|
||||
// Debug: show what we got
|
||||
let response_count = count_assistant_responses(app, &session_id).await;
|
||||
return Err(format!(
|
||||
"Timed out waiting for second response (got {} completed)",
|
||||
response_count
|
||||
));
|
||||
}
|
||||
eprintln!("Got both responses");
|
||||
|
||||
// Verify we got two assistant responses
|
||||
let response_count = count_assistant_responses(app, &session_id).await;
|
||||
eprintln!("Final response count: {}", response_count);
|
||||
if response_count < 2 {
|
||||
return Err(format!(
|
||||
"Expected at least 2 assistant responses, got {}",
|
||||
response_count
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Test that verifies the session can be reopened after ending
|
||||
#[tokio::test]
|
||||
async fn session_reopen_after_end() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue