mirror of
https://github.com/harivansh-afk/sandbox-agent.git
synced 2026-04-16 22:03:52 +00:00
feat: enable inspector CORS by default
- Enable CORS for https://inspect.sandboxagent.dev by default - Add --no-inspector-cors flag to opt out - Additional --cors-allow-origin flags are now cumulative with inspector - Inspector now tries current origin first before localhost:2468 fallback
This commit is contained in:
parent
08d299a3ef
commit
8acb2bb078
2 changed files with 81 additions and 33 deletions
|
|
@ -42,13 +42,18 @@ const buildStubItem = (itemId: string, nativeItemId?: string | null): UniversalI
|
||||||
} as UniversalItem;
|
} as UniversalItem;
|
||||||
};
|
};
|
||||||
|
|
||||||
const getDefaultEndpoint = () => {
|
const DEFAULT_ENDPOINT = "http://localhost:2468";
|
||||||
return "http://localhost:2468";
|
|
||||||
|
const getCurrentOriginEndpoint = () => {
|
||||||
|
if (typeof window === "undefined") {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return window.location.origin;
|
||||||
};
|
};
|
||||||
|
|
||||||
const getInitialConnection = () => {
|
const getInitialConnection = () => {
|
||||||
if (typeof window === "undefined") {
|
if (typeof window === "undefined") {
|
||||||
return { endpoint: "http://127.0.0.1:2468", token: "", headers: {} as Record<string, string> };
|
return { endpoint: "http://127.0.0.1:2468", token: "", headers: {} as Record<string, string>, hasUrlParam: false };
|
||||||
}
|
}
|
||||||
const params = new URLSearchParams(window.location.search);
|
const params = new URLSearchParams(window.location.search);
|
||||||
const urlParam = params.get("url")?.trim();
|
const urlParam = params.get("url")?.trim();
|
||||||
|
|
@ -62,10 +67,12 @@ const getInitialConnection = () => {
|
||||||
console.warn("Invalid headers query param, ignoring");
|
console.warn("Invalid headers query param, ignoring");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const hasUrlParam = urlParam != null && urlParam.length > 0;
|
||||||
return {
|
return {
|
||||||
endpoint: urlParam && urlParam.length > 0 ? urlParam : getDefaultEndpoint(),
|
endpoint: hasUrlParam ? urlParam : (getCurrentOriginEndpoint() ?? DEFAULT_ENDPOINT),
|
||||||
token: tokenParam,
|
token: tokenParam,
|
||||||
headers
|
headers,
|
||||||
|
hasUrlParam
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -132,7 +139,8 @@ export default function App() {
|
||||||
});
|
});
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const createClient = useCallback(async () => {
|
const createClient = useCallback(async (overrideEndpoint?: string) => {
|
||||||
|
const targetEndpoint = overrideEndpoint ?? endpoint;
|
||||||
const fetchWithLog: typeof fetch = async (input, init) => {
|
const fetchWithLog: typeof fetch = async (input, init) => {
|
||||||
const method = init?.method ?? "GET";
|
const method = init?.method ?? "GET";
|
||||||
const url =
|
const url =
|
||||||
|
|
@ -175,14 +183,14 @@ export default function App() {
|
||||||
};
|
};
|
||||||
|
|
||||||
const client = await SandboxAgent.connect({
|
const client = await SandboxAgent.connect({
|
||||||
baseUrl: endpoint,
|
baseUrl: targetEndpoint,
|
||||||
token: token || undefined,
|
token: token || undefined,
|
||||||
fetch: fetchWithLog,
|
fetch: fetchWithLog,
|
||||||
headers: Object.keys(extraHeaders).length > 0 ? extraHeaders : undefined
|
headers: Object.keys(extraHeaders).length > 0 ? extraHeaders : undefined
|
||||||
});
|
});
|
||||||
clientRef.current = client;
|
clientRef.current = client;
|
||||||
return client;
|
return client;
|
||||||
}, [endpoint, token, logRequest]);
|
}, [endpoint, token, extraHeaders, logRequest]);
|
||||||
|
|
||||||
const getClient = useCallback((): SandboxAgent => {
|
const getClient = useCallback((): SandboxAgent => {
|
||||||
if (!clientRef.current) {
|
if (!clientRef.current) {
|
||||||
|
|
@ -198,14 +206,17 @@ export default function App() {
|
||||||
return error instanceof Error ? error.message : fallback;
|
return error instanceof Error ? error.message : fallback;
|
||||||
};
|
};
|
||||||
|
|
||||||
const connectToDaemon = async (reportError: boolean) => {
|
const connectToDaemon = async (reportError: boolean, overrideEndpoint?: string) => {
|
||||||
setConnecting(true);
|
setConnecting(true);
|
||||||
if (reportError) {
|
if (reportError) {
|
||||||
setConnectError(null);
|
setConnectError(null);
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
const client = await createClient();
|
const client = await createClient(overrideEndpoint);
|
||||||
await client.getHealth();
|
await client.getHealth();
|
||||||
|
if (overrideEndpoint) {
|
||||||
|
setEndpoint(overrideEndpoint);
|
||||||
|
}
|
||||||
setConnected(true);
|
setConnected(true);
|
||||||
await refreshAgents();
|
await refreshAgents();
|
||||||
await fetchSessions();
|
await fetchSessions();
|
||||||
|
|
@ -219,6 +230,7 @@ export default function App() {
|
||||||
}
|
}
|
||||||
setConnected(false);
|
setConnected(false);
|
||||||
clientRef.current = null;
|
clientRef.current = null;
|
||||||
|
throw error;
|
||||||
} finally {
|
} finally {
|
||||||
setConnecting(false);
|
setConnecting(false);
|
||||||
}
|
}
|
||||||
|
|
@ -735,7 +747,37 @@ export default function App() {
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
let active = true;
|
let active = true;
|
||||||
const attempt = async () => {
|
const attempt = async () => {
|
||||||
await connectToDaemon(false);
|
const { hasUrlParam } = initialConnectionRef.current;
|
||||||
|
|
||||||
|
// If URL param was provided, just try that endpoint (don't fall back)
|
||||||
|
if (hasUrlParam) {
|
||||||
|
try {
|
||||||
|
await connectToDaemon(false);
|
||||||
|
} catch {
|
||||||
|
// Keep the URL param endpoint in the form even if connection failed
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// No URL param: try current origin first
|
||||||
|
const originEndpoint = getCurrentOriginEndpoint();
|
||||||
|
if (originEndpoint) {
|
||||||
|
try {
|
||||||
|
await connectToDaemon(false, originEndpoint);
|
||||||
|
return;
|
||||||
|
} catch {
|
||||||
|
// Origin failed, continue to fallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to localhost:2468
|
||||||
|
if (!active) return;
|
||||||
|
try {
|
||||||
|
await connectToDaemon(false, DEFAULT_ENDPOINT);
|
||||||
|
} catch {
|
||||||
|
// Keep localhost:2468 as the default in the form
|
||||||
|
setEndpoint(DEFAULT_ENDPOINT);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
attempt().catch(() => {
|
attempt().catch(() => {
|
||||||
if (!active) return;
|
if (!active) return;
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,10 @@ struct ServerArgs {
|
||||||
#[arg(long = "cors-allow-credentials", short = 'C')]
|
#[arg(long = "cors-allow-credentials", short = 'C')]
|
||||||
cors_allow_credentials: bool,
|
cors_allow_credentials: bool,
|
||||||
|
|
||||||
|
/// Disable default CORS for the inspector (https://inspect.sandboxagent.dev)
|
||||||
|
#[arg(long = "no-inspector-cors")]
|
||||||
|
no_inspector_cors: bool,
|
||||||
|
|
||||||
#[arg(long = "no-telemetry")]
|
#[arg(long = "no-telemetry")]
|
||||||
no_telemetry: bool,
|
no_telemetry: bool,
|
||||||
}
|
}
|
||||||
|
|
@ -381,9 +385,8 @@ fn run_server(cli: &Cli, server: &ServerArgs) -> Result<(), CliError> {
|
||||||
let state = Arc::new(AppState::new(auth, agent_manager));
|
let state = Arc::new(AppState::new(auth, agent_manager));
|
||||||
let (mut router, state) = build_router_with_state(state);
|
let (mut router, state) = build_router_with_state(state);
|
||||||
|
|
||||||
if let Some(cors) = build_cors_layer(server)? {
|
let cors = build_cors_layer(server)?;
|
||||||
router = router.layer(cors);
|
router = router.layer(cors);
|
||||||
}
|
|
||||||
|
|
||||||
let addr = format!("{}:{}", server.host, server.port);
|
let addr = format!("{}:{}", server.host, server.port);
|
||||||
let display_host = match server.host.as_str() {
|
let display_host = match server.host.as_str() {
|
||||||
|
|
@ -827,31 +830,33 @@ fn available_providers(credentials: &ExtractedCredentials) -> Vec<String> {
|
||||||
providers
|
providers
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_cors_layer(server: &ServerArgs) -> Result<Option<CorsLayer>, CliError> {
|
const INSPECTOR_ORIGIN: &str = "https://inspect.sandboxagent.dev";
|
||||||
let has_config = !server.cors_allow_origin.is_empty()
|
|
||||||
|| !server.cors_allow_method.is_empty()
|
|
||||||
|| !server.cors_allow_header.is_empty()
|
|
||||||
|| server.cors_allow_credentials;
|
|
||||||
|
|
||||||
if !has_config {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
fn build_cors_layer(server: &ServerArgs) -> Result<CorsLayer, CliError> {
|
||||||
let mut cors = CorsLayer::new();
|
let mut cors = CorsLayer::new();
|
||||||
|
|
||||||
if server.cors_allow_origin.is_empty() {
|
// Build origins list: inspector by default + any additional origins
|
||||||
cors = cors.allow_origin(Any);
|
let mut origins = Vec::new();
|
||||||
|
if !server.no_inspector_cors {
|
||||||
|
let inspector_origin = INSPECTOR_ORIGIN
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| CliError::InvalidCorsOrigin(INSPECTOR_ORIGIN.to_string()))?;
|
||||||
|
origins.push(inspector_origin);
|
||||||
|
}
|
||||||
|
for origin in &server.cors_allow_origin {
|
||||||
|
let value = origin
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| CliError::InvalidCorsOrigin(origin.clone()))?;
|
||||||
|
origins.push(value);
|
||||||
|
}
|
||||||
|
if origins.is_empty() {
|
||||||
|
// No origins allowed - use permissive CORS with no origins (effectively disabled)
|
||||||
|
cors = cors.allow_origin(tower_http::cors::AllowOrigin::predicate(|_, _| false));
|
||||||
} else {
|
} else {
|
||||||
let mut origins = Vec::new();
|
|
||||||
for origin in &server.cors_allow_origin {
|
|
||||||
let value = origin
|
|
||||||
.parse()
|
|
||||||
.map_err(|_| CliError::InvalidCorsOrigin(origin.clone()))?;
|
|
||||||
origins.push(value);
|
|
||||||
}
|
|
||||||
cors = cors.allow_origin(origins);
|
cors = cors.allow_origin(origins);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Methods: allow any if not specified, otherwise use provided list
|
||||||
if server.cors_allow_method.is_empty() {
|
if server.cors_allow_method.is_empty() {
|
||||||
cors = cors.allow_methods(Any);
|
cors = cors.allow_methods(Any);
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -865,6 +870,7 @@ fn build_cors_layer(server: &ServerArgs) -> Result<Option<CorsLayer>, CliError>
|
||||||
cors = cors.allow_methods(methods);
|
cors = cors.allow_methods(methods);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Headers: allow any if not specified, otherwise use provided list
|
||||||
if server.cors_allow_header.is_empty() {
|
if server.cors_allow_header.is_empty() {
|
||||||
cors = cors.allow_headers(Any);
|
cors = cors.allow_headers(Any);
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -882,7 +888,7 @@ fn build_cors_layer(server: &ServerArgs) -> Result<Option<CorsLayer>, CliError>
|
||||||
cors = cors.allow_credentials(true);
|
cors = cors.allow_credentials(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Some(cors))
|
Ok(cors)
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ClientContext {
|
struct ClientContext {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue