diff --git a/.turbo b/.turbo new file mode 120000 index 0000000..0b7d9ca --- /dev/null +++ b/.turbo @@ -0,0 +1 @@ +/home/nathan/sandbox-agent/.turbo \ No newline at end of file diff --git a/dist b/dist new file mode 120000 index 0000000..f02d77f --- /dev/null +++ b/dist @@ -0,0 +1 @@ +/home/nathan/sandbox-agent/dist \ No newline at end of file diff --git a/node_modules b/node_modules new file mode 120000 index 0000000..501480b --- /dev/null +++ b/node_modules @@ -0,0 +1 @@ +/home/nathan/sandbox-agent/node_modules \ No newline at end of file diff --git a/server/packages/sandbox-agent/src/lib.rs b/server/packages/sandbox-agent/src/lib.rs index 8c11343..07d938d 100644 --- a/server/packages/sandbox-agent/src/lib.rs +++ b/server/packages/sandbox-agent/src/lib.rs @@ -3,6 +3,7 @@ mod agent_server_logs; pub mod credentials; pub mod opencode_compat; +pub mod provider_auth; pub mod router; pub mod server_logs; pub mod telemetry; diff --git a/server/packages/sandbox-agent/src/opencode_compat.rs b/server/packages/sandbox-agent/src/opencode_compat.rs index 55b7050..190ff73 100644 --- a/server/packages/sandbox-agent/src/opencode_compat.rs +++ b/server/packages/sandbox-agent/src/opencode_compat.rs @@ -24,6 +24,7 @@ use tokio::time::interval; use utoipa::{IntoParams, OpenApi, ToSchema}; use crate::router::{AppState, CreateSessionRequest, PermissionReply}; +use crate::provider_auth::{ProviderAuth, ProviderAuthStore}; use sandbox_agent_error::SandboxError; use sandbox_agent_agent_management::agents::AgentId; use sandbox_agent_universal_agent_schema::{ @@ -41,6 +42,23 @@ const OPENCODE_PROVIDER_ID: &str = "sandbox-agent"; const OPENCODE_PROVIDER_NAME: &str = "Sandbox Agent"; const OPENCODE_DEFAULT_MODEL_ID: &str = "mock"; const OPENCODE_DEFAULT_AGENT_MODE: &str = "build"; +const PROVIDER_ANTHROPIC: &str = "anthropic"; +const PROVIDER_OPENAI: &str = "openai"; + +#[derive(Clone, Debug)] +struct ProviderAuthMethod { + kind: &'static str, + label: &'static str, +} + +#[derive(Clone, Debug)] +struct ProviderDefinition { + id: &'static str, + name: &'static str, + env: &'static [&'static str], + models: Vec, + auth_methods: Vec, +} #[derive(Clone, Debug)] struct OpenCodeCompatConfig { @@ -555,6 +573,19 @@ struct PermissionGlobalReplyRequest { reply: Option, } +#[derive(Debug, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +struct ProviderOauthRequest { + method: Option, +} + +#[derive(Debug, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +struct ProviderOauthCallbackRequest { + method: Option, + code: Option, +} + #[derive(Debug, Serialize, Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] struct QuestionReplyBody { @@ -585,6 +616,75 @@ fn available_agent_ids() -> Vec { ] } +fn provider_registry() -> Vec { + vec![ + ProviderDefinition { + id: OPENCODE_PROVIDER_ID, + name: OPENCODE_PROVIDER_NAME, + env: &[], + models: available_agent_ids(), + auth_methods: Vec::new(), + }, + ProviderDefinition { + id: PROVIDER_ANTHROPIC, + name: "Anthropic", + env: &["ANTHROPIC_API_KEY"], + models: vec![AgentId::Claude, AgentId::Amp], + auth_methods: vec![ + ProviderAuthMethod { + kind: "api", + label: "API Key", + }, + ProviderAuthMethod { + kind: "oauth", + label: "OAuth", + }, + ], + }, + ProviderDefinition { + id: PROVIDER_OPENAI, + name: "OpenAI", + env: &["OPENAI_API_KEY"], + models: vec![AgentId::Codex], + auth_methods: vec![ProviderAuthMethod { + kind: "api", + label: "API Key", + }], + }, + ] +} + +fn provider_models_map(models: &[AgentId]) -> Value { + let mut map = serde_json::Map::new(); + for agent in models { + map.insert(agent.as_str().to_string(), model_summary_entry(*agent)); + } + Value::Object(map) +} + +fn provider_auth_method_values(methods: &[ProviderAuthMethod]) -> Vec { + methods + .iter() + .map(|method| { + json!({ + "type": method.kind, + "label": method.label, + }) + }) + .collect() +} + +fn provider_definition(provider_id: &str) -> Option { + provider_registry() + .into_iter() + .find(|provider| provider.id == provider_id) +} + +fn provider_auth_method(provider_id: &str, method_index: u32) -> Option { + provider_definition(provider_id) + .and_then(|provider| provider.auth_methods.get(method_index as usize).cloned()) +} + fn default_agent_id() -> AgentId { AgentId::Mock } @@ -595,10 +695,37 @@ fn default_agent_mode() -> &'static str { fn resolve_agent_from_model(provider_id: &str, model_id: &str) -> Option { if provider_id == OPENCODE_PROVIDER_ID { - AgentId::parse(model_id) - } else { - None + return AgentId::parse(model_id); } + if provider_id == PROVIDER_ANTHROPIC { + if model_id == AgentId::Amp.as_str() { + return Some(AgentId::Amp); + } + return Some(AgentId::Claude); + } + if provider_id == PROVIDER_OPENAI { + return Some(AgentId::Codex); + } + if provider_id == AgentId::Opencode.as_str() { + return Some(AgentId::Opencode); + } + None +} + +fn default_model_for_provider(provider_id: &str) -> Option<&'static str> { + if provider_id == OPENCODE_PROVIDER_ID { + return Some(OPENCODE_DEFAULT_MODEL_ID); + } + if provider_id == PROVIDER_ANTHROPIC { + return Some(AgentId::Claude.as_str()); + } + if provider_id == PROVIDER_OPENAI { + return Some(AgentId::Codex.as_str()); + } + if provider_id == AgentId::Opencode.as_str() { + return Some(AgentId::Opencode.as_str()); + } + None } fn normalize_agent_mode(agent: Option) -> String { @@ -621,6 +748,12 @@ async fn resolve_session_agent( .unwrap_or(OPENCODE_DEFAULT_MODEL_ID) .to_string(); let mut resolved_agent = resolve_agent_from_model(&provider_id, &model_id); + if resolved_agent.is_none() { + if let Some(default_model) = default_model_for_provider(&provider_id) { + model_id = default_model.to_string(); + resolved_agent = resolve_agent_from_model(&provider_id, &model_id); + } + } if resolved_agent.is_none() { provider_id = OPENCODE_PROVIDER_ID.to_string(); model_id = OPENCODE_DEFAULT_MODEL_ID.to_string(); @@ -3526,24 +3659,47 @@ async fn oc_question_reject( responses((status = 200)), tag = "opencode" )] -async fn oc_provider_list() -> impl IntoResponse { - let mut models = serde_json::Map::new(); - for agent in available_agent_ids() { - models.insert(agent.as_str().to_string(), model_summary_entry(agent)); +async fn oc_provider_list(State(state): State>) -> impl IntoResponse { + let credentials = match state.inner.session_manager().resolve_credentials().await { + Ok(credentials) => credentials, + Err(err) => return sandbox_error_response(err).into_response(), + }; + let mut connected = ProviderAuthStore::connected_providers(&credentials); + if !connected.is_empty() && !connected.iter().any(|id| id == OPENCODE_PROVIDER_ID) { + connected.push(OPENCODE_PROVIDER_ID.to_string()); } + + let registry = provider_registry(); + let mut providers = Vec::new(); + let mut registry_ids = Vec::new(); + for provider in ®istry { + registry_ids.push(provider.id); + providers.push(json!({ + "id": provider.id, + "name": provider.name, + "env": provider.env, + "models": provider_models_map(&provider.models), + })); + } + + for provider_id in connected.iter() { + if registry_ids.iter().any(|id| id == &provider_id.as_str()) { + continue; + } + providers.push(json!({ + "id": provider_id, + "name": provider_id, + "env": [], + "models": Value::Object(serde_json::Map::new()), + })); + } + let providers = json!({ - "all": [ - { - "id": OPENCODE_PROVIDER_ID, - "name": OPENCODE_PROVIDER_NAME, - "env": [], - "models": Value::Object(models), - } - ], + "all": providers, "default": { OPENCODE_PROVIDER_ID: OPENCODE_DEFAULT_MODEL_ID }, - "connected": [OPENCODE_PROVIDER_ID] + "connected": connected, }); (StatusCode::OK, Json(providers)) } @@ -3555,10 +3711,14 @@ async fn oc_provider_list() -> impl IntoResponse { tag = "opencode" )] async fn oc_provider_auth() -> impl IntoResponse { - let auth = json!({ - OPENCODE_PROVIDER_ID: [] - }); - (StatusCode::OK, Json(auth)) + let mut map = serde_json::Map::new(); + for provider in provider_registry() { + map.insert( + provider.id.to_string(), + json!(provider_auth_method_values(&provider.auth_methods)), + ); + } + (StatusCode::OK, Json(Value::Object(map))) } @@ -3566,16 +3726,32 @@ async fn oc_provider_auth() -> impl IntoResponse { post, path = "/provider/{providerID}/oauth/authorize", params(("providerID" = String, Path, description = "Provider ID")), + request_body = ProviderOauthRequest, responses((status = 200)), tag = "opencode" )] -async fn oc_provider_oauth_authorize(Path(provider_id): Path) -> impl IntoResponse { +async fn oc_provider_oauth_authorize( + Path(provider_id): Path, + Json(body): Json, +) -> impl IntoResponse { + let provider_id = provider_id.to_ascii_lowercase(); + let method_index = match body.method { + Some(method) => method, + None => return bad_request("method is required").into_response(), + }; + let method = match provider_auth_method(&provider_id, method_index) { + Some(method) => method, + None => return bad_request("invalid auth method").into_response(), + }; + if method.kind != "oauth" { + return bad_request("auth method is not oauth").into_response(); + } ( StatusCode::OK, Json(json!({ "url": format!("https://auth.local/{}/authorize", provider_id), "method": "auto", - "instructions": "stub", + "instructions": "Open the URL to authorize.", })), ) } @@ -3584,10 +3760,35 @@ async fn oc_provider_oauth_authorize(Path(provider_id): Path) -> impl In post, path = "/provider/{providerID}/oauth/callback", params(("providerID" = String, Path, description = "Provider ID")), + request_body = ProviderOauthCallbackRequest, responses((status = 200)), tag = "opencode" )] -async fn oc_provider_oauth_callback(Path(_provider_id): Path) -> impl IntoResponse { +async fn oc_provider_oauth_callback( + State(state): State>, + Path(provider_id): Path, + Json(body): Json, +) -> impl IntoResponse { + let provider_id = provider_id.to_ascii_lowercase(); + let method_index = match body.method { + Some(method) => method, + None => return bad_request("method is required").into_response(), + }; + let method = match provider_auth_method(&provider_id, method_index) { + Some(method) => method, + None => return bad_request("invalid auth method").into_response(), + }; + if method.kind != "oauth" { + return bad_request("auth method is not oauth").into_response(); + } + let Some(code) = body.code else { + return bad_request("code is required").into_response(); + }; + state + .inner + .session_manager() + .set_provider_auth(&provider_id, ProviderAuth::OAuth { access: code }) + .await; bool_ok(true) } @@ -3599,7 +3800,54 @@ async fn oc_provider_oauth_callback(Path(_provider_id): Path) -> impl In responses((status = 200)), tag = "opencode" )] -async fn oc_auth_set(Path(_provider_id): Path, Json(_body): Json) -> impl IntoResponse { +async fn oc_auth_set( + State(state): State>, + Path(provider_id): Path, + Json(body): Json, +) -> impl IntoResponse { + let provider_id = provider_id.to_ascii_lowercase(); + if provider_id.is_empty() { + return bad_request("providerID is required").into_response(); + } + let auth_type = body.get("type").and_then(Value::as_str); + let auth = match auth_type { + Some("api") => body + .get("key") + .and_then(Value::as_str) + .filter(|value| !value.is_empty()) + .map(|key| ProviderAuth::Api { + key: key.to_string(), + }), + Some("oauth") => body + .get("access") + .and_then(Value::as_str) + .filter(|value| !value.is_empty()) + .map(|access| ProviderAuth::OAuth { + access: access.to_string(), + }), + Some("wellknown") => { + let key = body.get("key").and_then(Value::as_str).unwrap_or(""); + let token = body.get("token").and_then(Value::as_str).unwrap_or(""); + if key.is_empty() || token.is_empty() { + None + } else { + Some(ProviderAuth::WellKnown { + key: key.to_string(), + token: token.to_string(), + }) + } + } + _ => None, + }; + + let Some(auth) = auth else { + return bad_request("invalid auth payload").into_response(); + }; + state + .inner + .session_manager() + .set_provider_auth(&provider_id, auth) + .await; bool_ok(true) } @@ -3610,7 +3858,15 @@ async fn oc_auth_set(Path(_provider_id): Path, Json(_body): Json) responses((status = 200)), tag = "opencode" )] -async fn oc_auth_remove(Path(_provider_id): Path) -> impl IntoResponse { +async fn oc_auth_remove( + State(state): State>, + Path(provider_id): Path, +) -> impl IntoResponse { + state + .inner + .session_manager() + .remove_provider_auth(&provider_id) + .await; bool_ok(true) } diff --git a/server/packages/sandbox-agent/src/provider_auth.rs b/server/packages/sandbox-agent/src/provider_auth.rs new file mode 100644 index 0000000..696015b --- /dev/null +++ b/server/packages/sandbox-agent/src/provider_auth.rs @@ -0,0 +1,119 @@ +use std::collections::HashMap; + +use sandbox_agent_agent_management::credentials::{AuthType, ExtractedCredentials, ProviderCredentials}; +use tokio::sync::Mutex; + +#[derive(Clone, Debug)] +pub enum ProviderAuth { + Api { key: String }, + OAuth { access: String }, + WellKnown { key: String, token: String }, +} + +#[derive(Clone, Debug)] +pub enum ProviderAuthOverride { + Set(ProviderAuth), + Remove, +} + +#[derive(Debug)] +pub struct ProviderAuthStore { + overrides: Mutex>, +} + +impl ProviderAuthStore { + pub fn new() -> Self { + Self { + overrides: Mutex::new(HashMap::new()), + } + } + + pub async fn set(&self, provider_id: &str, auth: ProviderAuth) { + let provider = normalize_provider_id(provider_id); + let mut overrides = self.overrides.lock().await; + overrides.insert(provider, ProviderAuthOverride::Set(auth)); + } + + pub async fn remove(&self, provider_id: &str) { + let provider = normalize_provider_id(provider_id); + let mut overrides = self.overrides.lock().await; + overrides.insert(provider, ProviderAuthOverride::Remove); + } + + pub async fn snapshot(&self) -> HashMap { + self.overrides.lock().await.clone() + } + + pub fn apply_overrides( + mut credentials: ExtractedCredentials, + overrides: HashMap, + ) -> ExtractedCredentials { + for (provider, override_value) in overrides { + match override_value { + ProviderAuthOverride::Set(auth) => { + let cred = provider_credentials(&provider, &auth); + match provider.as_str() { + "anthropic" => credentials.anthropic = Some(cred), + "openai" => credentials.openai = Some(cred), + _ => { + credentials.other.insert(provider.clone(), cred); + } + } + } + ProviderAuthOverride::Remove => match provider.as_str() { + "anthropic" => credentials.anthropic = None, + "openai" => credentials.openai = None, + _ => { + credentials.other.remove(&provider); + } + }, + } + } + credentials + } + + pub fn connected_providers(credentials: &ExtractedCredentials) -> Vec { + let mut connected = Vec::new(); + if let Some(cred) = &credentials.anthropic { + connected.push(cred.provider.clone()); + } + if let Some(cred) = &credentials.openai { + connected.push(cred.provider.clone()); + } + for key in credentials.other.keys() { + connected.push(key.clone()); + } + connected.sort(); + connected.dedup(); + connected + } +} + +fn provider_credentials(provider: &str, auth: &ProviderAuth) -> ProviderCredentials { + ProviderCredentials { + api_key: auth_key(auth).to_string(), + source: "opencode".to_string(), + auth_type: auth_type(auth), + provider: provider.to_string(), + } +} + +fn auth_type(auth: &ProviderAuth) -> AuthType { + match auth { + ProviderAuth::Api { .. } => AuthType::ApiKey, + ProviderAuth::OAuth { .. } => AuthType::Oauth, + ProviderAuth::WellKnown { .. } => AuthType::ApiKey, + } +} + +fn auth_key(auth: &ProviderAuth) -> &str { + match auth { + ProviderAuth::Api { key } => key, + ProviderAuth::OAuth { access } => access, + ProviderAuth::WellKnown { token, .. } => token, + } +} + +fn normalize_provider_id(provider_id: &str) -> String { + provider_id.trim().to_ascii_lowercase() +} diff --git a/server/packages/sandbox-agent/src/router.rs b/server/packages/sandbox-agent/src/router.rs index 3ca437a..6a31ba0 100644 --- a/server/packages/sandbox-agent/src/router.rs +++ b/server/packages/sandbox-agent/src/router.rs @@ -40,6 +40,7 @@ use utoipa::{Modify, OpenApi, ToSchema}; use crate::agent_server_logs::AgentServerLogs; use crate::opencode_compat::{build_opencode_router, OpenCodeAppState}; +use crate::provider_auth::{ProviderAuth, ProviderAuthStore}; use crate::ui; use sandbox_agent_agent_management::agents::{ AgentError as ManagerError, AgentId, AgentManager, InstallOptions, SpawnOptions, StreamingSpawn, @@ -818,6 +819,7 @@ pub(crate) struct SessionManager { sessions: Mutex>, server_manager: Arc, http_client: Client, + provider_auth: Arc, } /// Shared Codex app-server process that handles multiple sessions via JSON-RPC. @@ -1538,6 +1540,7 @@ impl SessionManager { sessions: Mutex::new(Vec::new()), server_manager, http_client: Client::new(), + provider_auth: Arc::new(ProviderAuthStore::new()), } } @@ -1562,6 +1565,27 @@ impl SessionManager { logs.read_stderr() } + pub(crate) async fn set_provider_auth(&self, provider_id: &str, auth: ProviderAuth) { + self.provider_auth.set(provider_id, auth).await; + } + + pub(crate) async fn remove_provider_auth(&self, provider_id: &str) { + self.provider_auth.remove(provider_id).await; + } + + pub(crate) async fn resolve_credentials(&self) -> Result { + let overrides = self.provider_auth.snapshot().await; + let credentials = tokio::task::spawn_blocking(move || { + let options = CredentialExtractionOptions::new(); + extract_all_credentials(&options) + }) + .await + .map_err(|err| SandboxError::StreamError { + message: err.to_string(), + })?; + Ok(ProviderAuthStore::apply_overrides(credentials, overrides)) + } + pub(crate) async fn create_session( self: &Arc, session_id: String, @@ -1737,15 +1761,7 @@ impl SessionManager { } else { None }; - let credentials = tokio::task::spawn_blocking(move || { - let options = CredentialExtractionOptions::new(); - extract_all_credentials(&options) - }) - .await - .map_err(|err| SandboxError::StreamError { - message: err.to_string(), - })?; - + let credentials = self.resolve_credentials().await?; let spawn_options = build_spawn_options(&session_snapshot, prompt.clone(), credentials); let agent_id = session_snapshot.agent; let spawn_result = diff --git a/server/packages/sandbox-agent/tests/opencode-compat/provider-auth.test.ts b/server/packages/sandbox-agent/tests/opencode-compat/provider-auth.test.ts new file mode 100644 index 0000000..f0d3726 --- /dev/null +++ b/server/packages/sandbox-agent/tests/opencode-compat/provider-auth.test.ts @@ -0,0 +1,51 @@ +/** + * Tests for OpenCode-compatible provider auth endpoints. + */ + +import { describe, it, expect, beforeAll, beforeEach, afterEach } from "vitest"; +import { createOpencodeClient, type OpencodeClient } from "@opencode-ai/sdk"; +import { spawnSandboxAgent, buildSandboxAgent, type SandboxAgentHandle } from "./helpers/spawn"; + +describe("OpenCode-compatible Provider Auth API", () => { + let handle: SandboxAgentHandle; + let client: OpencodeClient; + + beforeAll(async () => { + await buildSandboxAgent(); + }); + + beforeEach(async () => { + handle = await spawnSandboxAgent({ opencodeCompat: true }); + client = createOpencodeClient({ + baseUrl: `${handle.baseUrl}/opencode`, + headers: { Authorization: `Bearer ${handle.token}` }, + }); + }); + + afterEach(async () => { + await handle?.dispose(); + }); + + it("should set/remove credentials and update connected providers", async () => { + const initial = await client.provider.list(); + const providers = initial.data?.all ?? []; + expect(providers.some((provider) => provider.id === "anthropic")).toBe(true); + + const setResponse = await client.auth.set({ + path: { providerID: "anthropic" }, + body: { type: "api", key: "sk-test" }, + }); + expect(setResponse.data).toBe(true); + + const afterSet = await client.provider.list(); + expect(afterSet.data?.connected?.includes("anthropic")).toBe(true); + + const removeResponse = await client.auth.remove({ + path: { providerID: "anthropic" }, + }); + expect(removeResponse.data).toBe(true); + + const afterRemove = await client.provider.list(); + expect(afterRemove.data?.connected?.includes("anthropic")).toBe(false); + }); +}); diff --git a/target b/target new file mode 120000 index 0000000..3d6ad8c --- /dev/null +++ b/target @@ -0,0 +1 @@ +/home/nathan/sandbox-agent/target \ No newline at end of file