diff --git a/frontend/packages/inspector/src/App.tsx b/frontend/packages/inspector/src/App.tsx index d687281..3dfe368 100644 --- a/frontend/packages/inspector/src/App.tsx +++ b/frontend/packages/inspector/src/App.tsx @@ -42,13 +42,18 @@ const buildStubItem = (itemId: string, nativeItemId?: string | null): UniversalI } as UniversalItem; }; -const getDefaultEndpoint = () => { - return "http://localhost:2468"; +const DEFAULT_ENDPOINT = "http://localhost:2468"; + +const getCurrentOriginEndpoint = () => { + if (typeof window === "undefined") { + return null; + } + return window.location.origin; }; const getInitialConnection = () => { if (typeof window === "undefined") { - return { endpoint: "http://127.0.0.1:2468", token: "", headers: {} as Record }; + return { endpoint: "http://127.0.0.1:2468", token: "", headers: {} as Record, hasUrlParam: false }; } const params = new URLSearchParams(window.location.search); const urlParam = params.get("url")?.trim(); @@ -62,10 +67,12 @@ const getInitialConnection = () => { console.warn("Invalid headers query param, ignoring"); } } + const hasUrlParam = urlParam != null && urlParam.length > 0; return { - endpoint: urlParam && urlParam.length > 0 ? urlParam : getDefaultEndpoint(), + endpoint: hasUrlParam ? urlParam : (getCurrentOriginEndpoint() ?? DEFAULT_ENDPOINT), token: tokenParam, - headers + headers, + hasUrlParam }; }; @@ -132,7 +139,8 @@ export default function App() { }); }, []); - const createClient = useCallback(async () => { + const createClient = useCallback(async (overrideEndpoint?: string) => { + const targetEndpoint = overrideEndpoint ?? endpoint; const fetchWithLog: typeof fetch = async (input, init) => { const method = init?.method ?? "GET"; const url = @@ -175,14 +183,14 @@ export default function App() { }; const client = await SandboxAgent.connect({ - baseUrl: endpoint, + baseUrl: targetEndpoint, token: token || undefined, fetch: fetchWithLog, headers: Object.keys(extraHeaders).length > 0 ? extraHeaders : undefined }); clientRef.current = client; return client; - }, [endpoint, token, logRequest]); + }, [endpoint, token, extraHeaders, logRequest]); const getClient = useCallback((): SandboxAgent => { if (!clientRef.current) { @@ -198,14 +206,17 @@ export default function App() { return error instanceof Error ? error.message : fallback; }; - const connectToDaemon = async (reportError: boolean) => { + const connectToDaemon = async (reportError: boolean, overrideEndpoint?: string) => { setConnecting(true); if (reportError) { setConnectError(null); } try { - const client = await createClient(); + const client = await createClient(overrideEndpoint); await client.getHealth(); + if (overrideEndpoint) { + setEndpoint(overrideEndpoint); + } setConnected(true); await refreshAgents(); await fetchSessions(); @@ -219,6 +230,7 @@ export default function App() { } setConnected(false); clientRef.current = null; + throw error; } finally { setConnecting(false); } @@ -735,7 +747,37 @@ export default function App() { useEffect(() => { let active = true; const attempt = async () => { - await connectToDaemon(false); + const { hasUrlParam } = initialConnectionRef.current; + + // If URL param was provided, just try that endpoint (don't fall back) + if (hasUrlParam) { + try { + await connectToDaemon(false); + } catch { + // Keep the URL param endpoint in the form even if connection failed + } + return; + } + + // No URL param: try current origin first + const originEndpoint = getCurrentOriginEndpoint(); + if (originEndpoint) { + try { + await connectToDaemon(false, originEndpoint); + return; + } catch { + // Origin failed, continue to fallback + } + } + + // Fall back to localhost:2468 + if (!active) return; + try { + await connectToDaemon(false, DEFAULT_ENDPOINT); + } catch { + // Keep localhost:2468 as the default in the form + setEndpoint(DEFAULT_ENDPOINT); + } }; attempt().catch(() => { if (!active) return; diff --git a/server/packages/sandbox-agent/src/main.rs b/server/packages/sandbox-agent/src/main.rs index 072063f..d7fb338 100644 --- a/server/packages/sandbox-agent/src/main.rs +++ b/server/packages/sandbox-agent/src/main.rs @@ -77,6 +77,10 @@ struct ServerArgs { #[arg(long = "cors-allow-credentials", short = 'C')] cors_allow_credentials: bool, + /// Disable default CORS for the inspector (https://inspect.sandboxagent.dev) + #[arg(long = "no-inspector-cors")] + no_inspector_cors: bool, + #[arg(long = "no-telemetry")] no_telemetry: bool, } @@ -381,9 +385,8 @@ fn run_server(cli: &Cli, server: &ServerArgs) -> Result<(), CliError> { let state = Arc::new(AppState::new(auth, agent_manager)); let (mut router, state) = build_router_with_state(state); - if let Some(cors) = build_cors_layer(server)? { - router = router.layer(cors); - } + let cors = build_cors_layer(server)?; + router = router.layer(cors); let addr = format!("{}:{}", server.host, server.port); let display_host = match server.host.as_str() { @@ -827,31 +830,33 @@ fn available_providers(credentials: &ExtractedCredentials) -> Vec { providers } -fn build_cors_layer(server: &ServerArgs) -> Result, CliError> { - let has_config = !server.cors_allow_origin.is_empty() - || !server.cors_allow_method.is_empty() - || !server.cors_allow_header.is_empty() - || server.cors_allow_credentials; - - if !has_config { - return Ok(None); - } +const INSPECTOR_ORIGIN: &str = "https://inspect.sandboxagent.dev"; +fn build_cors_layer(server: &ServerArgs) -> Result { let mut cors = CorsLayer::new(); - if server.cors_allow_origin.is_empty() { - cors = cors.allow_origin(Any); + // Build origins list: inspector by default + any additional origins + let mut origins = Vec::new(); + if !server.no_inspector_cors { + let inspector_origin = INSPECTOR_ORIGIN + .parse() + .map_err(|_| CliError::InvalidCorsOrigin(INSPECTOR_ORIGIN.to_string()))?; + origins.push(inspector_origin); + } + for origin in &server.cors_allow_origin { + let value = origin + .parse() + .map_err(|_| CliError::InvalidCorsOrigin(origin.clone()))?; + origins.push(value); + } + if origins.is_empty() { + // No origins allowed - use permissive CORS with no origins (effectively disabled) + cors = cors.allow_origin(tower_http::cors::AllowOrigin::predicate(|_, _| false)); } else { - let mut origins = Vec::new(); - for origin in &server.cors_allow_origin { - let value = origin - .parse() - .map_err(|_| CliError::InvalidCorsOrigin(origin.clone()))?; - origins.push(value); - } cors = cors.allow_origin(origins); } + // Methods: allow any if not specified, otherwise use provided list if server.cors_allow_method.is_empty() { cors = cors.allow_methods(Any); } else { @@ -865,6 +870,7 @@ fn build_cors_layer(server: &ServerArgs) -> Result, CliError> cors = cors.allow_methods(methods); } + // Headers: allow any if not specified, otherwise use provided list if server.cors_allow_header.is_empty() { cors = cors.allow_headers(Any); } else { @@ -882,7 +888,7 @@ fn build_cors_layer(server: &ServerArgs) -> Result, CliError> cors = cors.allow_credentials(true); } - Ok(Some(cors)) + Ok(cors) } struct ClientContext {