feat: refine process API — WebSocket binary protocol, SDK terminal session, updated tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Nathan Flurry 2026-03-06 12:12:24 -08:00
parent 6c91323ca6
commit 636eefb553
11 changed files with 700 additions and 512 deletions

View file

@ -8,7 +8,7 @@ use base64::Engine;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tokio::process::{Child, ChildStdin, Command};
use tokio::sync::{broadcast, Mutex, RwLock};
use tokio::sync::{broadcast, Mutex, RwLock, Semaphore};
use sandbox_agent_error::SandboxError;
@ -119,6 +119,7 @@ pub struct ProcessRuntime {
struct ProcessRuntimeInner {
next_id: AtomicU64,
processes: RwLock<HashMap<String, Arc<ManagedProcess>>>,
run_once_semaphore: Semaphore,
}
#[derive(Debug)]
@ -182,6 +183,9 @@ impl ProcessRuntime {
inner: Arc::new(ProcessRuntimeInner {
next_id: AtomicU64::new(1),
processes: RwLock::new(HashMap::new()),
run_once_semaphore: Semaphore::new(
ProcessRuntimeConfig::default().max_concurrent_processes,
),
}),
}
}
@ -324,6 +328,14 @@ impl ProcessRuntime {
});
}
let _permit =
self.inner
.run_once_semaphore
.try_acquire()
.map_err(|_| SandboxError::Conflict {
message: "too many concurrent run_once operations".to_string(),
})?;
let config = self.get_config().await;
let mut timeout_ms = spec.timeout_ms.unwrap_or(config.default_run_timeout_ms);
if timeout_ms == 0 {
@ -331,7 +343,10 @@ impl ProcessRuntime {
}
timeout_ms = timeout_ms.min(config.max_run_timeout_ms);
let max_output_bytes = spec.max_output_bytes.unwrap_or(config.max_output_bytes);
let max_output_bytes = spec
.max_output_bytes
.unwrap_or(config.max_output_bytes)
.min(config.max_output_bytes);
let mut cmd = Command::new(&spec.command);
cmd.args(&spec.args)

View file

@ -50,6 +50,12 @@ pub use self::types::*;
// Media type literals (their usage sites are outside this excerpt).
const APPLICATION_JSON: &str = "application/json";
const TEXT_EVENT_STREAM: &str = "text/event-stream";
/// WebSocket subprotocol name passed to `ws.protocols(...)` by the terminal endpoint.
const CHANNEL_K8S_IO_PROTOCOL: &str = "channel.k8s.io";
// Every binary terminal frame starts with one of the channel bytes below,
// followed by that channel's payload.
/// Client -> server: raw bytes forwarded to the process via `write_input`.
const CH_STDIN: u8 = 0;
/// Server -> client: process output bytes.
const CH_STDOUT: u8 = 1;
/// Server -> client: JSON status messages (`ready`, `exit`, `error`).
const CH_STATUS: u8 = 3;
/// Client -> server: JSON resize payload carrying `cols` and `rows`.
const CH_RESIZE: u8 = 4;
/// Close signal; accepted from the client and sent by the server with an empty payload.
const CH_CLOSE: u8 = 255;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BrandingMode {
@ -196,10 +202,6 @@ pub fn build_router_with_state(shared: Arc<AppState>) -> (Router, Arc<AppState>)
.route("/processes/:id/kill", post(post_v1_process_kill))
.route("/processes/:id/logs", get(get_v1_process_logs))
.route("/processes/:id/input", post(post_v1_process_input))
.route(
"/processes/:id/terminal/resize",
post(post_v1_process_terminal_resize),
)
.route(
"/processes/:id/terminal/ws",
get(get_v1_process_terminal_ws),
@ -344,7 +346,6 @@ pub async fn shutdown_servers(state: &Arc<AppState>) {
delete_v1_process,
get_v1_process_logs,
post_v1_process_input,
post_v1_process_terminal_resize,
get_v1_process_terminal_ws,
get_v1_config_mcp,
put_v1_config_mcp,
@ -394,8 +395,6 @@ pub async fn shutdown_servers(state: &Arc<AppState>) {
ProcessInputRequest,
ProcessInputResponse,
ProcessSignalQuery,
ProcessTerminalResizeRequest,
ProcessTerminalResizeResponse,
AcpPostQuery,
AcpServerInfo,
AcpServerListResponse,
@ -1602,51 +1601,13 @@ async fn post_v1_process_input(
Ok(Json(ProcessInputResponse { bytes_written }))
}
/// Resize a process terminal.
///
/// Sets the PTY window size (columns and rows) for a tty-mode process and
/// sends SIGWINCH so the child process can adapt.
#[utoipa::path(
post,
path = "/v1/processes/{id}/terminal/resize",
tag = "v1",
params(
("id" = String, Path, description = "Process ID")
),
request_body = ProcessTerminalResizeRequest,
responses(
(status = 200, description = "Resize accepted", body = ProcessTerminalResizeResponse),
(status = 400, description = "Invalid request", body = ProblemDetails),
(status = 404, description = "Unknown process", body = ProblemDetails),
(status = 409, description = "Not a terminal process", body = ProblemDetails),
(status = 501, description = "Process API unsupported on this platform", body = ProblemDetails)
)
)]
async fn post_v1_process_terminal_resize(
State(state): State<Arc<AppState>>,
Path(id): Path<String>,
Json(body): Json<ProcessTerminalResizeRequest>,
) -> Result<Json<ProcessTerminalResizeResponse>, ApiError> {
if !process_api_supported() {
return Err(process_api_not_supported().into());
}
state
.process_runtime()
.resize_terminal(&id, body.cols, body.rows)
.await?;
Ok(Json(ProcessTerminalResizeResponse {
cols: body.cols,
rows: body.rows,
}))
}
/// Open an interactive WebSocket terminal session.
///
/// Upgrades the connection to a WebSocket for bidirectional PTY I/O. Accepts
/// `access_token` query param for browser-based auth (WebSocket API cannot
/// send custom headers). Streams raw PTY output as binary frames and accepts
/// JSON control frames for input, resize, and close.
/// send custom headers). Uses the `channel.k8s.io` binary subprotocol:
/// channel 0 stdin, channel 1 stdout, channel 3 status JSON, channel 4 resize,
/// and channel 255 close.
#[utoipa::path(
get,
path = "/v1/processes/{id}/terminal/ws",
@ -1682,23 +1643,16 @@ async fn get_v1_process_terminal_ws(
}
Ok(ws
.protocols([CHANNEL_K8S_IO_PROTOCOL])
.on_upgrade(move |socket| process_terminal_ws_session(socket, runtime, id))
.into_response())
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
enum TerminalClientFrame {
Input {
data: String,
#[serde(default)]
encoding: Option<String>,
},
Resize {
cols: u16,
rows: u16,
},
Close,
#[serde(rename_all = "camelCase")]
struct TerminalResizePayload {
cols: u16,
rows: u16,
}
async fn process_terminal_ws_session(
@ -1706,7 +1660,7 @@ async fn process_terminal_ws_session(
runtime: Arc<ProcessRuntime>,
id: String,
) {
let _ = send_ws_json(
let _ = send_status_json(
&mut socket,
json!({
"type": "ready",
@ -1718,7 +1672,8 @@ async fn process_terminal_ws_session(
let mut log_rx = match runtime.subscribe_logs(&id).await {
Ok(rx) => rx,
Err(err) => {
let _ = send_ws_error(&mut socket, &err.to_string()).await;
let _ = send_status_error(&mut socket, &err.to_string()).await;
let _ = send_close_signal(&mut socket).await;
let _ = socket.close().await;
return;
}
@ -1729,43 +1684,57 @@ async fn process_terminal_ws_session(
tokio::select! {
ws_in = socket.recv() => {
match ws_in {
Some(Ok(Message::Binary(_))) => {
let _ = send_ws_error(&mut socket, "binary input is not supported; use text JSON frames").await;
}
Some(Ok(Message::Text(text))) => {
let parsed = serde_json::from_str::<TerminalClientFrame>(&text);
match parsed {
Ok(TerminalClientFrame::Input { data, encoding }) => {
let input = match decode_input_bytes(&data, encoding.as_deref().unwrap_or("utf8")) {
Ok(input) => input,
Err(err) => {
let _ = send_ws_error(&mut socket, &err.to_string()).await;
continue;
}
};
Some(Ok(Message::Binary(bytes))) => {
let Some((&channel, payload)) = bytes.split_first() else {
let _ = send_status_error(&mut socket, "invalid terminal frame: missing channel byte").await;
continue;
};
match channel {
CH_STDIN => {
let input = payload.to_vec();
let max_input = runtime.max_input_bytes().await;
if input.len() > max_input {
let _ = send_ws_error(&mut socket, &format!("input payload exceeds maxInputBytesPerRequest ({max_input})")).await;
let _ = send_status_error(&mut socket, &format!("input payload exceeds maxInputBytesPerRequest ({max_input})")).await;
continue;
}
if let Err(err) = runtime.write_input(&id, &input).await {
let _ = send_ws_error(&mut socket, &err.to_string()).await;
let _ = send_status_error(&mut socket, &err.to_string()).await;
}
}
Ok(TerminalClientFrame::Resize { cols, rows }) => {
if let Err(err) = runtime.resize_terminal(&id, cols, rows).await {
let _ = send_ws_error(&mut socket, &err.to_string()).await;
CH_RESIZE => {
let resize = match serde_json::from_slice::<TerminalResizePayload>(payload) {
Ok(resize) => resize,
Err(err) => {
let _ = send_status_error(&mut socket, &format!("invalid resize payload: {err}")).await;
continue;
}
};
if let Err(err) = runtime
.resize_terminal(&id, resize.cols, resize.rows)
.await
{
let _ = send_status_error(&mut socket, &err.to_string()).await;
}
}
Ok(TerminalClientFrame::Close) => {
CH_CLOSE => {
let _ = send_close_signal(&mut socket).await;
let _ = socket.close().await;
break;
}
Err(err) => {
let _ = send_ws_error(&mut socket, &format!("invalid terminal frame: {err}")).await;
_ => {
let _ = send_status_error(&mut socket, &format!("unsupported terminal channel: {channel}")).await;
}
}
}
Some(Ok(Message::Text(_))) => {
let _ = send_status_error(
&mut socket,
"text frames are not supported; use channel.k8s.io binary frames",
)
.await;
}
Some(Ok(Message::Ping(payload))) => {
let _ = socket.send(Message::Pong(payload)).await;
}
@ -1785,7 +1754,7 @@ async fn process_terminal_ws_session(
use base64::Engine;
BASE64_ENGINE.decode(&line.data).unwrap_or_default()
};
if socket.send(Message::Binary(bytes)).await.is_err() {
if send_channel_frame(&mut socket, CH_STDOUT, bytes).await.is_err() {
break;
}
}
@ -1796,7 +1765,7 @@ async fn process_terminal_ws_session(
_ = exit_poll.tick() => {
if let Ok(snapshot) = runtime.snapshot(&id).await {
if snapshot.status == ProcessStatus::Exited {
let _ = send_ws_json(
let _ = send_status_json(
&mut socket,
json!({
"type": "exit",
@ -1804,6 +1773,7 @@ async fn process_terminal_ws_session(
}),
)
.await;
let _ = send_close_signal(&mut socket).await;
let _ = socket.close().await;
break;
}
@ -1813,17 +1783,30 @@ async fn process_terminal_ws_session(
}
}
async fn send_ws_json(socket: &mut WebSocket, payload: Value) -> Result<(), ()> {
async fn send_channel_frame(
socket: &mut WebSocket,
channel: u8,
payload: impl Into<Vec<u8>>,
) -> Result<(), ()> {
let mut frame = vec![channel];
frame.extend(payload.into());
socket
.send(Message::Text(
serde_json::to_string(&payload).map_err(|_| ())?,
))
.send(Message::Binary(frame.into()))
.await
.map_err(|_| ())
}
async fn send_ws_error(socket: &mut WebSocket, message: &str) -> Result<(), ()> {
send_ws_json(
async fn send_status_json(socket: &mut WebSocket, payload: Value) -> Result<(), ()> {
send_channel_frame(
socket,
CH_STATUS,
serde_json::to_vec(&payload).map_err(|_| ())?,
)
.await
}
async fn send_status_error(socket: &mut WebSocket, message: &str) -> Result<(), ()> {
send_status_json(
socket,
json!({
"type": "error",
@ -1833,6 +1816,10 @@ async fn send_ws_error(socket: &mut WebSocket, message: &str) -> Result<(), ()>
.await
}
/// Send the end-of-session marker: a `CH_CLOSE` frame with no payload bytes.
///
/// The result of `send_channel_frame` is returned unchanged.
async fn send_close_signal(socket: &mut WebSocket) -> Result<(), ()> {
    let empty_payload: Vec<u8> = Vec::new();
    send_channel_frame(socket, CH_CLOSE, empty_payload).await
}
#[utoipa::path(
get,
path = "/v1/config/mcp",

View file

@ -512,20 +512,6 @@ pub struct ProcessSignalQuery {
pub wait_ms: Option<u64>,
}
// Request body of the terminal resize endpoint: the desired terminal size.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ProcessTerminalResizeRequest {
// Terminal width in columns.
pub cols: u16,
// Terminal height in rows.
pub rows: u16,
}
// Response body of the terminal resize endpoint; the handler fills it with
// the `cols` and `rows` values taken from the request.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ProcessTerminalResizeResponse {
// Terminal width in columns.
pub cols: u16,
// Terminal height in rows.
pub rows: u16,
}
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct ProcessWsQuery {

View file

@ -3,8 +3,17 @@ use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
use futures::{SinkExt, StreamExt};
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::HeaderValue;
use tokio_tungstenite::tungstenite::Message;
/// Subprotocol name the tests send in `Sec-WebSocket-Protocol` and expect to be echoed back.
const CHANNEL_K8S_IO_PROTOCOL: &str = "channel.k8s.io";
// Leading channel byte of each binary terminal frame (see `make_channel_frame`).
/// Input bytes sent by the test client.
const CH_STDIN: u8 = 0;
/// Process output frames received from the server.
const CH_STDOUT: u8 = 1;
/// JSON status frames (`ready`, `exit`, `error`).
const CH_STATUS: u8 = 3;
/// JSON resize request (`cols`, `rows`) sent by the test client.
const CH_RESIZE: u8 = 4;
/// Close frame; the tests assert that its payload is empty.
const CH_CLOSE: u8 = 255;
async fn wait_for_exited(test_app: &TestApp, process_id: &str) {
for _ in 0..30 {
let (status, _, body) = send_request(
@ -48,6 +57,19 @@ async fn recv_ws_message(
.expect("websocket frame")
}
/// Build a binary terminal frame: the channel byte first, then the payload bytes unchanged.
fn make_channel_frame(channel: u8, payload: impl AsRef<[u8]>) -> Vec<u8> {
    let body = payload.as_ref();
    // `once` + `chain` reports an exact size hint, so the vector is sized in one allocation.
    std::iter::once(channel)
        .chain(body.iter().copied())
        .collect()
}
/// Split a binary terminal frame into its leading channel byte and the remaining payload.
///
/// # Panics
///
/// Panics with the message `channel frame` when `bytes` is empty.
fn parse_channel_frame(bytes: &[u8]) -> (u8, &[u8]) {
    match bytes {
        [channel, payload @ ..] => (*channel, payload),
        [] => panic!("channel frame"),
    }
}
#[tokio::test]
async fn v1_processes_config_round_trip() {
let test_app = TestApp::new(AuthConfig::disabled());
@ -519,59 +541,84 @@ async fn v1_process_terminal_ws_e2e_is_deterministic() {
.expect("create process response");
assert_eq!(create_response.status(), reqwest::StatusCode::OK);
let create_body: Value = create_response.json().await.expect("create process json");
let process_id = create_body["id"]
.as_str()
.expect("process id")
.to_string();
let process_id = create_body["id"].as_str().expect("process id").to_string();
let ws_url = live_server.ws_url(&format!("/v1/processes/{process_id}/terminal/ws"));
let (mut ws, _) = connect_async(&ws_url)
.await
.expect("connect websocket");
let mut ws_request = ws_url.into_client_request().expect("ws request");
ws_request.headers_mut().insert(
"Sec-WebSocket-Protocol",
HeaderValue::from_static(CHANNEL_K8S_IO_PROTOCOL),
);
let (mut ws, response) = connect_async(ws_request).await.expect("connect websocket");
assert_eq!(
response
.headers()
.get("Sec-WebSocket-Protocol")
.and_then(|value| value.to_str().ok()),
Some(CHANNEL_K8S_IO_PROTOCOL)
);
let ready = recv_ws_message(&mut ws).await;
let ready_payload: Value = serde_json::from_str(ready.to_text().expect("ready text frame"))
.expect("ready json");
let ready_bytes = ready.into_data();
let (ready_channel, ready_payload) = parse_channel_frame(&ready_bytes);
assert_eq!(ready_channel, CH_STATUS);
let ready_payload: Value = serde_json::from_slice(ready_payload).expect("ready json");
assert_eq!(ready_payload["type"], "ready");
assert_eq!(ready_payload["processId"], process_id);
ws.send(Message::Text(
json!({
"type": "input",
"data": "hello from ws\n"
})
.to_string(),
ws.send(Message::Binary(
make_channel_frame(CH_STDIN, b"hello from ws\n").into(),
))
.await
.expect("send input frame");
let mut saw_binary_output = false;
ws.send(Message::Binary(
make_channel_frame(CH_RESIZE, br#"{"cols":120,"rows":40}"#).into(),
))
.await
.expect("send resize frame");
let mut saw_stdout = false;
let mut saw_exit = false;
let mut saw_close = false;
for _ in 0..10 {
let frame = recv_ws_message(&mut ws).await;
match frame {
Message::Binary(bytes) => {
let text = String::from_utf8_lossy(&bytes);
if text.contains("got:hello from ws") {
saw_binary_output = true;
let (channel, payload) = parse_channel_frame(&bytes);
match channel {
CH_STDOUT => {
let text = String::from_utf8_lossy(payload);
if text.contains("got:hello from ws") {
saw_stdout = true;
}
}
CH_STATUS => {
let payload: Value =
serde_json::from_slice(payload).expect("ws status json");
if payload["type"] == "exit" {
saw_exit = true;
} else {
assert_ne!(payload["type"], "error");
}
}
CH_CLOSE => {
assert!(payload.is_empty(), "close channel payload must be empty");
saw_close = true;
break;
}
other => panic!("unexpected websocket channel: {other}"),
}
}
Message::Text(text) => {
let payload: Value = serde_json::from_str(&text).expect("ws json");
if payload["type"] == "exit" {
saw_exit = true;
break;
}
assert_ne!(payload["type"], "error");
}
Message::Close(_) => break,
Message::Ping(_) | Message::Pong(_) => {}
_ => {}
}
}
assert!(saw_binary_output, "expected pty binary output over websocket");
assert!(saw_exit, "expected exit control frame over websocket");
assert!(saw_stdout, "expected pty stdout over websocket");
assert!(saw_exit, "expected exit status frame over websocket");
assert!(saw_close, "expected close channel frame over websocket");
let _ = ws.close(None).await;
@ -605,10 +652,7 @@ async fn v1_process_terminal_ws_auth_e2e() {
.expect("create process response");
assert_eq!(create_response.status(), reqwest::StatusCode::OK);
let create_body: Value = create_response.json().await.expect("create process json");
let process_id = create_body["id"]
.as_str()
.expect("process id")
.to_string();
let process_id = create_body["id"].as_str().expect("process id").to_string();
let unauth_ws_url = live_server.ws_url(&format!("/v1/processes/{process_id}/terminal/ws"));
let unauth_err = connect_async(&unauth_ws_url)
@ -624,25 +668,42 @@ async fn v1_process_terminal_ws_auth_e2e() {
let auth_ws_url = live_server.ws_url(&format!(
"/v1/processes/{process_id}/terminal/ws?access_token={token}"
));
let (mut ws, _) = connect_async(&auth_ws_url)
let mut ws_request = auth_ws_url.into_client_request().expect("ws request");
ws_request.headers_mut().insert(
"Sec-WebSocket-Protocol",
HeaderValue::from_static(CHANNEL_K8S_IO_PROTOCOL),
);
let (mut ws, response) = connect_async(ws_request)
.await
.expect("authenticated websocket handshake");
assert_eq!(
response
.headers()
.get("Sec-WebSocket-Protocol")
.and_then(|value| value.to_str().ok()),
Some(CHANNEL_K8S_IO_PROTOCOL)
);
let ready = recv_ws_message(&mut ws).await;
let ready_payload: Value = serde_json::from_str(ready.to_text().expect("ready text frame"))
.expect("ready json");
let ready_bytes = ready.into_data();
let (ready_channel, ready_payload) = parse_channel_frame(&ready_bytes);
assert_eq!(ready_channel, CH_STATUS);
let ready_payload: Value = serde_json::from_slice(ready_payload).expect("ready json");
assert_eq!(ready_payload["type"], "ready");
assert_eq!(ready_payload["processId"], process_id);
let _ = ws
.send(Message::Text(json!({ "type": "close" }).to_string()))
.send(Message::Binary(make_channel_frame(CH_CLOSE, []).into()))
.await;
let close = recv_ws_message(&mut ws).await;
let close_bytes = close.into_data();
let (close_channel, close_payload) = parse_channel_frame(&close_bytes);
assert_eq!(close_channel, CH_CLOSE);
assert!(close_payload.is_empty());
let _ = ws.close(None).await;
let kill_response = http
.post(live_server.http_url(&format!(
"/v1/processes/{process_id}/kill?waitMs=1000"
)))
.post(live_server.http_url(&format!("/v1/processes/{process_id}/kill?waitMs=1000")))
.bearer_auth(token)
.send()
.await