mirror of
https://github.com/harivansh-afk/sandbox-agent.git
synced 2026-04-20 16:05:18 +00:00
feat: implement provider auth lifecycle and endpoints
This commit is contained in:
parent
7378abee46
commit
312c3a0c8b
9 changed files with 481 additions and 34 deletions
1
.turbo
Symbolic link
1
.turbo
Symbolic link
|
|
@ -0,0 +1 @@
|
||||||
|
/home/nathan/sandbox-agent/.turbo
|
||||||
1
dist
Symbolic link
1
dist
Symbolic link
|
|
@ -0,0 +1 @@
|
||||||
|
/home/nathan/sandbox-agent/dist
|
||||||
1
node_modules
Symbolic link
1
node_modules
Symbolic link
|
|
@ -0,0 +1 @@
|
||||||
|
/home/nathan/sandbox-agent/node_modules
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
mod agent_server_logs;
|
mod agent_server_logs;
|
||||||
pub mod credentials;
|
pub mod credentials;
|
||||||
pub mod opencode_compat;
|
pub mod opencode_compat;
|
||||||
|
pub mod provider_auth;
|
||||||
pub mod router;
|
pub mod router;
|
||||||
pub mod server_logs;
|
pub mod server_logs;
|
||||||
pub mod telemetry;
|
pub mod telemetry;
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ use tokio::time::interval;
|
||||||
use utoipa::{IntoParams, OpenApi, ToSchema};
|
use utoipa::{IntoParams, OpenApi, ToSchema};
|
||||||
|
|
||||||
use crate::router::{AppState, CreateSessionRequest, PermissionReply};
|
use crate::router::{AppState, CreateSessionRequest, PermissionReply};
|
||||||
|
use crate::provider_auth::{ProviderAuth, ProviderAuthStore};
|
||||||
use sandbox_agent_error::SandboxError;
|
use sandbox_agent_error::SandboxError;
|
||||||
use sandbox_agent_agent_management::agents::AgentId;
|
use sandbox_agent_agent_management::agents::AgentId;
|
||||||
use sandbox_agent_universal_agent_schema::{
|
use sandbox_agent_universal_agent_schema::{
|
||||||
|
|
@ -41,6 +42,23 @@ const OPENCODE_PROVIDER_ID: &str = "sandbox-agent";
|
||||||
const OPENCODE_PROVIDER_NAME: &str = "Sandbox Agent";
|
const OPENCODE_PROVIDER_NAME: &str = "Sandbox Agent";
|
||||||
const OPENCODE_DEFAULT_MODEL_ID: &str = "mock";
|
const OPENCODE_DEFAULT_MODEL_ID: &str = "mock";
|
||||||
const OPENCODE_DEFAULT_AGENT_MODE: &str = "build";
|
const OPENCODE_DEFAULT_AGENT_MODE: &str = "build";
|
||||||
|
const PROVIDER_ANTHROPIC: &str = "anthropic";
|
||||||
|
const PROVIDER_OPENAI: &str = "openai";
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct ProviderAuthMethod {
|
||||||
|
kind: &'static str,
|
||||||
|
label: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct ProviderDefinition {
|
||||||
|
id: &'static str,
|
||||||
|
name: &'static str,
|
||||||
|
env: &'static [&'static str],
|
||||||
|
models: Vec<AgentId>,
|
||||||
|
auth_methods: Vec<ProviderAuthMethod>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct OpenCodeCompatConfig {
|
struct OpenCodeCompatConfig {
|
||||||
|
|
@ -555,6 +573,19 @@ struct PermissionGlobalReplyRequest {
|
||||||
reply: Option<String>,
|
reply: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
struct ProviderOauthRequest {
|
||||||
|
method: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
struct ProviderOauthCallbackRequest {
|
||||||
|
method: Option<u32>,
|
||||||
|
code: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
struct QuestionReplyBody {
|
struct QuestionReplyBody {
|
||||||
|
|
@ -585,6 +616,75 @@ fn available_agent_ids() -> Vec<AgentId> {
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn provider_registry() -> Vec<ProviderDefinition> {
|
||||||
|
vec![
|
||||||
|
ProviderDefinition {
|
||||||
|
id: OPENCODE_PROVIDER_ID,
|
||||||
|
name: OPENCODE_PROVIDER_NAME,
|
||||||
|
env: &[],
|
||||||
|
models: available_agent_ids(),
|
||||||
|
auth_methods: Vec::new(),
|
||||||
|
},
|
||||||
|
ProviderDefinition {
|
||||||
|
id: PROVIDER_ANTHROPIC,
|
||||||
|
name: "Anthropic",
|
||||||
|
env: &["ANTHROPIC_API_KEY"],
|
||||||
|
models: vec![AgentId::Claude, AgentId::Amp],
|
||||||
|
auth_methods: vec![
|
||||||
|
ProviderAuthMethod {
|
||||||
|
kind: "api",
|
||||||
|
label: "API Key",
|
||||||
|
},
|
||||||
|
ProviderAuthMethod {
|
||||||
|
kind: "oauth",
|
||||||
|
label: "OAuth",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
ProviderDefinition {
|
||||||
|
id: PROVIDER_OPENAI,
|
||||||
|
name: "OpenAI",
|
||||||
|
env: &["OPENAI_API_KEY"],
|
||||||
|
models: vec![AgentId::Codex],
|
||||||
|
auth_methods: vec![ProviderAuthMethod {
|
||||||
|
kind: "api",
|
||||||
|
label: "API Key",
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provider_models_map(models: &[AgentId]) -> Value {
|
||||||
|
let mut map = serde_json::Map::new();
|
||||||
|
for agent in models {
|
||||||
|
map.insert(agent.as_str().to_string(), model_summary_entry(*agent));
|
||||||
|
}
|
||||||
|
Value::Object(map)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provider_auth_method_values(methods: &[ProviderAuthMethod]) -> Vec<Value> {
|
||||||
|
methods
|
||||||
|
.iter()
|
||||||
|
.map(|method| {
|
||||||
|
json!({
|
||||||
|
"type": method.kind,
|
||||||
|
"label": method.label,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provider_definition(provider_id: &str) -> Option<ProviderDefinition> {
|
||||||
|
provider_registry()
|
||||||
|
.into_iter()
|
||||||
|
.find(|provider| provider.id == provider_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provider_auth_method(provider_id: &str, method_index: u32) -> Option<ProviderAuthMethod> {
|
||||||
|
provider_definition(provider_id)
|
||||||
|
.and_then(|provider| provider.auth_methods.get(method_index as usize).cloned())
|
||||||
|
}
|
||||||
|
|
||||||
fn default_agent_id() -> AgentId {
|
fn default_agent_id() -> AgentId {
|
||||||
AgentId::Mock
|
AgentId::Mock
|
||||||
}
|
}
|
||||||
|
|
@ -595,10 +695,37 @@ fn default_agent_mode() -> &'static str {
|
||||||
|
|
||||||
fn resolve_agent_from_model(provider_id: &str, model_id: &str) -> Option<AgentId> {
|
fn resolve_agent_from_model(provider_id: &str, model_id: &str) -> Option<AgentId> {
|
||||||
if provider_id == OPENCODE_PROVIDER_ID {
|
if provider_id == OPENCODE_PROVIDER_ID {
|
||||||
AgentId::parse(model_id)
|
return AgentId::parse(model_id);
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
if provider_id == PROVIDER_ANTHROPIC {
|
||||||
|
if model_id == AgentId::Amp.as_str() {
|
||||||
|
return Some(AgentId::Amp);
|
||||||
|
}
|
||||||
|
return Some(AgentId::Claude);
|
||||||
|
}
|
||||||
|
if provider_id == PROVIDER_OPENAI {
|
||||||
|
return Some(AgentId::Codex);
|
||||||
|
}
|
||||||
|
if provider_id == AgentId::Opencode.as_str() {
|
||||||
|
return Some(AgentId::Opencode);
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_model_for_provider(provider_id: &str) -> Option<&'static str> {
|
||||||
|
if provider_id == OPENCODE_PROVIDER_ID {
|
||||||
|
return Some(OPENCODE_DEFAULT_MODEL_ID);
|
||||||
|
}
|
||||||
|
if provider_id == PROVIDER_ANTHROPIC {
|
||||||
|
return Some(AgentId::Claude.as_str());
|
||||||
|
}
|
||||||
|
if provider_id == PROVIDER_OPENAI {
|
||||||
|
return Some(AgentId::Codex.as_str());
|
||||||
|
}
|
||||||
|
if provider_id == AgentId::Opencode.as_str() {
|
||||||
|
return Some(AgentId::Opencode.as_str());
|
||||||
|
}
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_agent_mode(agent: Option<String>) -> String {
|
fn normalize_agent_mode(agent: Option<String>) -> String {
|
||||||
|
|
@ -621,6 +748,12 @@ async fn resolve_session_agent(
|
||||||
.unwrap_or(OPENCODE_DEFAULT_MODEL_ID)
|
.unwrap_or(OPENCODE_DEFAULT_MODEL_ID)
|
||||||
.to_string();
|
.to_string();
|
||||||
let mut resolved_agent = resolve_agent_from_model(&provider_id, &model_id);
|
let mut resolved_agent = resolve_agent_from_model(&provider_id, &model_id);
|
||||||
|
if resolved_agent.is_none() {
|
||||||
|
if let Some(default_model) = default_model_for_provider(&provider_id) {
|
||||||
|
model_id = default_model.to_string();
|
||||||
|
resolved_agent = resolve_agent_from_model(&provider_id, &model_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
if resolved_agent.is_none() {
|
if resolved_agent.is_none() {
|
||||||
provider_id = OPENCODE_PROVIDER_ID.to_string();
|
provider_id = OPENCODE_PROVIDER_ID.to_string();
|
||||||
model_id = OPENCODE_DEFAULT_MODEL_ID.to_string();
|
model_id = OPENCODE_DEFAULT_MODEL_ID.to_string();
|
||||||
|
|
@ -3526,24 +3659,47 @@ async fn oc_question_reject(
|
||||||
responses((status = 200)),
|
responses((status = 200)),
|
||||||
tag = "opencode"
|
tag = "opencode"
|
||||||
)]
|
)]
|
||||||
async fn oc_provider_list() -> impl IntoResponse {
|
async fn oc_provider_list(State(state): State<Arc<OpenCodeAppState>>) -> impl IntoResponse {
|
||||||
let mut models = serde_json::Map::new();
|
let credentials = match state.inner.session_manager().resolve_credentials().await {
|
||||||
for agent in available_agent_ids() {
|
Ok(credentials) => credentials,
|
||||||
models.insert(agent.as_str().to_string(), model_summary_entry(agent));
|
Err(err) => return sandbox_error_response(err).into_response(),
|
||||||
|
};
|
||||||
|
let mut connected = ProviderAuthStore::connected_providers(&credentials);
|
||||||
|
if !connected.is_empty() && !connected.iter().any(|id| id == OPENCODE_PROVIDER_ID) {
|
||||||
|
connected.push(OPENCODE_PROVIDER_ID.to_string());
|
||||||
}
|
}
|
||||||
let providers = json!({
|
|
||||||
"all": [
|
let registry = provider_registry();
|
||||||
{
|
let mut providers = Vec::new();
|
||||||
"id": OPENCODE_PROVIDER_ID,
|
let mut registry_ids = Vec::new();
|
||||||
"name": OPENCODE_PROVIDER_NAME,
|
for provider in ®istry {
|
||||||
|
registry_ids.push(provider.id);
|
||||||
|
providers.push(json!({
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"env": provider.env,
|
||||||
|
"models": provider_models_map(&provider.models),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for provider_id in connected.iter() {
|
||||||
|
if registry_ids.iter().any(|id| id == &provider_id.as_str()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
providers.push(json!({
|
||||||
|
"id": provider_id,
|
||||||
|
"name": provider_id,
|
||||||
"env": [],
|
"env": [],
|
||||||
"models": Value::Object(models),
|
"models": Value::Object(serde_json::Map::new()),
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
],
|
|
||||||
|
let providers = json!({
|
||||||
|
"all": providers,
|
||||||
"default": {
|
"default": {
|
||||||
OPENCODE_PROVIDER_ID: OPENCODE_DEFAULT_MODEL_ID
|
OPENCODE_PROVIDER_ID: OPENCODE_DEFAULT_MODEL_ID
|
||||||
},
|
},
|
||||||
"connected": [OPENCODE_PROVIDER_ID]
|
"connected": connected,
|
||||||
});
|
});
|
||||||
(StatusCode::OK, Json(providers))
|
(StatusCode::OK, Json(providers))
|
||||||
}
|
}
|
||||||
|
|
@ -3555,10 +3711,14 @@ async fn oc_provider_list() -> impl IntoResponse {
|
||||||
tag = "opencode"
|
tag = "opencode"
|
||||||
)]
|
)]
|
||||||
async fn oc_provider_auth() -> impl IntoResponse {
|
async fn oc_provider_auth() -> impl IntoResponse {
|
||||||
let auth = json!({
|
let mut map = serde_json::Map::new();
|
||||||
OPENCODE_PROVIDER_ID: []
|
for provider in provider_registry() {
|
||||||
});
|
map.insert(
|
||||||
(StatusCode::OK, Json(auth))
|
provider.id.to_string(),
|
||||||
|
json!(provider_auth_method_values(&provider.auth_methods)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
(StatusCode::OK, Json(Value::Object(map)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -3566,16 +3726,32 @@ async fn oc_provider_auth() -> impl IntoResponse {
|
||||||
post,
|
post,
|
||||||
path = "/provider/{providerID}/oauth/authorize",
|
path = "/provider/{providerID}/oauth/authorize",
|
||||||
params(("providerID" = String, Path, description = "Provider ID")),
|
params(("providerID" = String, Path, description = "Provider ID")),
|
||||||
|
request_body = ProviderOauthRequest,
|
||||||
responses((status = 200)),
|
responses((status = 200)),
|
||||||
tag = "opencode"
|
tag = "opencode"
|
||||||
)]
|
)]
|
||||||
async fn oc_provider_oauth_authorize(Path(provider_id): Path<String>) -> impl IntoResponse {
|
async fn oc_provider_oauth_authorize(
|
||||||
|
Path(provider_id): Path<String>,
|
||||||
|
Json(body): Json<ProviderOauthRequest>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let provider_id = provider_id.to_ascii_lowercase();
|
||||||
|
let method_index = match body.method {
|
||||||
|
Some(method) => method,
|
||||||
|
None => return bad_request("method is required").into_response(),
|
||||||
|
};
|
||||||
|
let method = match provider_auth_method(&provider_id, method_index) {
|
||||||
|
Some(method) => method,
|
||||||
|
None => return bad_request("invalid auth method").into_response(),
|
||||||
|
};
|
||||||
|
if method.kind != "oauth" {
|
||||||
|
return bad_request("auth method is not oauth").into_response();
|
||||||
|
}
|
||||||
(
|
(
|
||||||
StatusCode::OK,
|
StatusCode::OK,
|
||||||
Json(json!({
|
Json(json!({
|
||||||
"url": format!("https://auth.local/{}/authorize", provider_id),
|
"url": format!("https://auth.local/{}/authorize", provider_id),
|
||||||
"method": "auto",
|
"method": "auto",
|
||||||
"instructions": "stub",
|
"instructions": "Open the URL to authorize.",
|
||||||
})),
|
})),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
@ -3584,10 +3760,35 @@ async fn oc_provider_oauth_authorize(Path(provider_id): Path<String>) -> impl In
|
||||||
post,
|
post,
|
||||||
path = "/provider/{providerID}/oauth/callback",
|
path = "/provider/{providerID}/oauth/callback",
|
||||||
params(("providerID" = String, Path, description = "Provider ID")),
|
params(("providerID" = String, Path, description = "Provider ID")),
|
||||||
|
request_body = ProviderOauthCallbackRequest,
|
||||||
responses((status = 200)),
|
responses((status = 200)),
|
||||||
tag = "opencode"
|
tag = "opencode"
|
||||||
)]
|
)]
|
||||||
async fn oc_provider_oauth_callback(Path(_provider_id): Path<String>) -> impl IntoResponse {
|
async fn oc_provider_oauth_callback(
|
||||||
|
State(state): State<Arc<OpenCodeAppState>>,
|
||||||
|
Path(provider_id): Path<String>,
|
||||||
|
Json(body): Json<ProviderOauthCallbackRequest>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let provider_id = provider_id.to_ascii_lowercase();
|
||||||
|
let method_index = match body.method {
|
||||||
|
Some(method) => method,
|
||||||
|
None => return bad_request("method is required").into_response(),
|
||||||
|
};
|
||||||
|
let method = match provider_auth_method(&provider_id, method_index) {
|
||||||
|
Some(method) => method,
|
||||||
|
None => return bad_request("invalid auth method").into_response(),
|
||||||
|
};
|
||||||
|
if method.kind != "oauth" {
|
||||||
|
return bad_request("auth method is not oauth").into_response();
|
||||||
|
}
|
||||||
|
let Some(code) = body.code else {
|
||||||
|
return bad_request("code is required").into_response();
|
||||||
|
};
|
||||||
|
state
|
||||||
|
.inner
|
||||||
|
.session_manager()
|
||||||
|
.set_provider_auth(&provider_id, ProviderAuth::OAuth { access: code })
|
||||||
|
.await;
|
||||||
bool_ok(true)
|
bool_ok(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -3599,7 +3800,54 @@ async fn oc_provider_oauth_callback(Path(_provider_id): Path<String>) -> impl In
|
||||||
responses((status = 200)),
|
responses((status = 200)),
|
||||||
tag = "opencode"
|
tag = "opencode"
|
||||||
)]
|
)]
|
||||||
async fn oc_auth_set(Path(_provider_id): Path<String>, Json(_body): Json<Value>) -> impl IntoResponse {
|
async fn oc_auth_set(
|
||||||
|
State(state): State<Arc<OpenCodeAppState>>,
|
||||||
|
Path(provider_id): Path<String>,
|
||||||
|
Json(body): Json<Value>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let provider_id = provider_id.to_ascii_lowercase();
|
||||||
|
if provider_id.is_empty() {
|
||||||
|
return bad_request("providerID is required").into_response();
|
||||||
|
}
|
||||||
|
let auth_type = body.get("type").and_then(Value::as_str);
|
||||||
|
let auth = match auth_type {
|
||||||
|
Some("api") => body
|
||||||
|
.get("key")
|
||||||
|
.and_then(Value::as_str)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map(|key| ProviderAuth::Api {
|
||||||
|
key: key.to_string(),
|
||||||
|
}),
|
||||||
|
Some("oauth") => body
|
||||||
|
.get("access")
|
||||||
|
.and_then(Value::as_str)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map(|access| ProviderAuth::OAuth {
|
||||||
|
access: access.to_string(),
|
||||||
|
}),
|
||||||
|
Some("wellknown") => {
|
||||||
|
let key = body.get("key").and_then(Value::as_str).unwrap_or("");
|
||||||
|
let token = body.get("token").and_then(Value::as_str).unwrap_or("");
|
||||||
|
if key.is_empty() || token.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(ProviderAuth::WellKnown {
|
||||||
|
key: key.to_string(),
|
||||||
|
token: token.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(auth) = auth else {
|
||||||
|
return bad_request("invalid auth payload").into_response();
|
||||||
|
};
|
||||||
|
state
|
||||||
|
.inner
|
||||||
|
.session_manager()
|
||||||
|
.set_provider_auth(&provider_id, auth)
|
||||||
|
.await;
|
||||||
bool_ok(true)
|
bool_ok(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -3610,7 +3858,15 @@ async fn oc_auth_set(Path(_provider_id): Path<String>, Json(_body): Json<Value>)
|
||||||
responses((status = 200)),
|
responses((status = 200)),
|
||||||
tag = "opencode"
|
tag = "opencode"
|
||||||
)]
|
)]
|
||||||
async fn oc_auth_remove(Path(_provider_id): Path<String>) -> impl IntoResponse {
|
async fn oc_auth_remove(
|
||||||
|
State(state): State<Arc<OpenCodeAppState>>,
|
||||||
|
Path(provider_id): Path<String>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
state
|
||||||
|
.inner
|
||||||
|
.session_manager()
|
||||||
|
.remove_provider_auth(&provider_id)
|
||||||
|
.await;
|
||||||
bool_ok(true)
|
bool_ok(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
119
server/packages/sandbox-agent/src/provider_auth.rs
Normal file
119
server/packages/sandbox-agent/src/provider_auth.rs
Normal file
|
|
@ -0,0 +1,119 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use sandbox_agent_agent_management::credentials::{AuthType, ExtractedCredentials, ProviderCredentials};
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum ProviderAuth {
|
||||||
|
Api { key: String },
|
||||||
|
OAuth { access: String },
|
||||||
|
WellKnown { key: String, token: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum ProviderAuthOverride {
|
||||||
|
Set(ProviderAuth),
|
||||||
|
Remove,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ProviderAuthStore {
|
||||||
|
overrides: Mutex<HashMap<String, ProviderAuthOverride>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderAuthStore {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
overrides: Mutex::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn set(&self, provider_id: &str, auth: ProviderAuth) {
|
||||||
|
let provider = normalize_provider_id(provider_id);
|
||||||
|
let mut overrides = self.overrides.lock().await;
|
||||||
|
overrides.insert(provider, ProviderAuthOverride::Set(auth));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn remove(&self, provider_id: &str) {
|
||||||
|
let provider = normalize_provider_id(provider_id);
|
||||||
|
let mut overrides = self.overrides.lock().await;
|
||||||
|
overrides.insert(provider, ProviderAuthOverride::Remove);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn snapshot(&self) -> HashMap<String, ProviderAuthOverride> {
|
||||||
|
self.overrides.lock().await.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply_overrides(
|
||||||
|
mut credentials: ExtractedCredentials,
|
||||||
|
overrides: HashMap<String, ProviderAuthOverride>,
|
||||||
|
) -> ExtractedCredentials {
|
||||||
|
for (provider, override_value) in overrides {
|
||||||
|
match override_value {
|
||||||
|
ProviderAuthOverride::Set(auth) => {
|
||||||
|
let cred = provider_credentials(&provider, &auth);
|
||||||
|
match provider.as_str() {
|
||||||
|
"anthropic" => credentials.anthropic = Some(cred),
|
||||||
|
"openai" => credentials.openai = Some(cred),
|
||||||
|
_ => {
|
||||||
|
credentials.other.insert(provider.clone(), cred);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ProviderAuthOverride::Remove => match provider.as_str() {
|
||||||
|
"anthropic" => credentials.anthropic = None,
|
||||||
|
"openai" => credentials.openai = None,
|
||||||
|
_ => {
|
||||||
|
credentials.other.remove(&provider);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
credentials
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn connected_providers(credentials: &ExtractedCredentials) -> Vec<String> {
|
||||||
|
let mut connected = Vec::new();
|
||||||
|
if let Some(cred) = &credentials.anthropic {
|
||||||
|
connected.push(cred.provider.clone());
|
||||||
|
}
|
||||||
|
if let Some(cred) = &credentials.openai {
|
||||||
|
connected.push(cred.provider.clone());
|
||||||
|
}
|
||||||
|
for key in credentials.other.keys() {
|
||||||
|
connected.push(key.clone());
|
||||||
|
}
|
||||||
|
connected.sort();
|
||||||
|
connected.dedup();
|
||||||
|
connected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provider_credentials(provider: &str, auth: &ProviderAuth) -> ProviderCredentials {
|
||||||
|
ProviderCredentials {
|
||||||
|
api_key: auth_key(auth).to_string(),
|
||||||
|
source: "opencode".to_string(),
|
||||||
|
auth_type: auth_type(auth),
|
||||||
|
provider: provider.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_type(auth: &ProviderAuth) -> AuthType {
|
||||||
|
match auth {
|
||||||
|
ProviderAuth::Api { .. } => AuthType::ApiKey,
|
||||||
|
ProviderAuth::OAuth { .. } => AuthType::Oauth,
|
||||||
|
ProviderAuth::WellKnown { .. } => AuthType::ApiKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_key(auth: &ProviderAuth) -> &str {
|
||||||
|
match auth {
|
||||||
|
ProviderAuth::Api { key } => key,
|
||||||
|
ProviderAuth::OAuth { access } => access,
|
||||||
|
ProviderAuth::WellKnown { token, .. } => token,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_provider_id(provider_id: &str) -> String {
|
||||||
|
provider_id.trim().to_ascii_lowercase()
|
||||||
|
}
|
||||||
|
|
@ -40,6 +40,7 @@ use utoipa::{Modify, OpenApi, ToSchema};
|
||||||
|
|
||||||
use crate::agent_server_logs::AgentServerLogs;
|
use crate::agent_server_logs::AgentServerLogs;
|
||||||
use crate::opencode_compat::{build_opencode_router, OpenCodeAppState};
|
use crate::opencode_compat::{build_opencode_router, OpenCodeAppState};
|
||||||
|
use crate::provider_auth::{ProviderAuth, ProviderAuthStore};
|
||||||
use crate::ui;
|
use crate::ui;
|
||||||
use sandbox_agent_agent_management::agents::{
|
use sandbox_agent_agent_management::agents::{
|
||||||
AgentError as ManagerError, AgentId, AgentManager, InstallOptions, SpawnOptions, StreamingSpawn,
|
AgentError as ManagerError, AgentId, AgentManager, InstallOptions, SpawnOptions, StreamingSpawn,
|
||||||
|
|
@ -818,6 +819,7 @@ pub(crate) struct SessionManager {
|
||||||
sessions: Mutex<Vec<SessionState>>,
|
sessions: Mutex<Vec<SessionState>>,
|
||||||
server_manager: Arc<AgentServerManager>,
|
server_manager: Arc<AgentServerManager>,
|
||||||
http_client: Client,
|
http_client: Client,
|
||||||
|
provider_auth: Arc<ProviderAuthStore>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Shared Codex app-server process that handles multiple sessions via JSON-RPC.
|
/// Shared Codex app-server process that handles multiple sessions via JSON-RPC.
|
||||||
|
|
@ -1538,6 +1540,7 @@ impl SessionManager {
|
||||||
sessions: Mutex::new(Vec::new()),
|
sessions: Mutex::new(Vec::new()),
|
||||||
server_manager,
|
server_manager,
|
||||||
http_client: Client::new(),
|
http_client: Client::new(),
|
||||||
|
provider_auth: Arc::new(ProviderAuthStore::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1562,6 +1565,27 @@ impl SessionManager {
|
||||||
logs.read_stderr()
|
logs.read_stderr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn set_provider_auth(&self, provider_id: &str, auth: ProviderAuth) {
|
||||||
|
self.provider_auth.set(provider_id, auth).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn remove_provider_auth(&self, provider_id: &str) {
|
||||||
|
self.provider_auth.remove(provider_id).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn resolve_credentials(&self) -> Result<ExtractedCredentials, SandboxError> {
|
||||||
|
let overrides = self.provider_auth.snapshot().await;
|
||||||
|
let credentials = tokio::task::spawn_blocking(move || {
|
||||||
|
let options = CredentialExtractionOptions::new();
|
||||||
|
extract_all_credentials(&options)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|err| SandboxError::StreamError {
|
||||||
|
message: err.to_string(),
|
||||||
|
})?;
|
||||||
|
Ok(ProviderAuthStore::apply_overrides(credentials, overrides))
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) async fn create_session(
|
pub(crate) async fn create_session(
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
session_id: String,
|
session_id: String,
|
||||||
|
|
@ -1737,15 +1761,7 @@ impl SessionManager {
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let credentials = tokio::task::spawn_blocking(move || {
|
let credentials = self.resolve_credentials().await?;
|
||||||
let options = CredentialExtractionOptions::new();
|
|
||||||
extract_all_credentials(&options)
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.map_err(|err| SandboxError::StreamError {
|
|
||||||
message: err.to_string(),
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let spawn_options = build_spawn_options(&session_snapshot, prompt.clone(), credentials);
|
let spawn_options = build_spawn_options(&session_snapshot, prompt.clone(), credentials);
|
||||||
let agent_id = session_snapshot.agent;
|
let agent_id = session_snapshot.agent;
|
||||||
let spawn_result =
|
let spawn_result =
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
/**
|
||||||
|
* Tests for OpenCode-compatible provider auth endpoints.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, beforeAll, beforeEach, afterEach } from "vitest";
|
||||||
|
import { createOpencodeClient, type OpencodeClient } from "@opencode-ai/sdk";
|
||||||
|
import { spawnSandboxAgent, buildSandboxAgent, type SandboxAgentHandle } from "./helpers/spawn";
|
||||||
|
|
||||||
|
describe("OpenCode-compatible Provider Auth API", () => {
|
||||||
|
let handle: SandboxAgentHandle;
|
||||||
|
let client: OpencodeClient;
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
await buildSandboxAgent();
|
||||||
|
});
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
handle = await spawnSandboxAgent({ opencodeCompat: true });
|
||||||
|
client = createOpencodeClient({
|
||||||
|
baseUrl: `${handle.baseUrl}/opencode`,
|
||||||
|
headers: { Authorization: `Bearer ${handle.token}` },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(async () => {
|
||||||
|
await handle?.dispose();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should set/remove credentials and update connected providers", async () => {
|
||||||
|
const initial = await client.provider.list();
|
||||||
|
const providers = initial.data?.all ?? [];
|
||||||
|
expect(providers.some((provider) => provider.id === "anthropic")).toBe(true);
|
||||||
|
|
||||||
|
const setResponse = await client.auth.set({
|
||||||
|
path: { providerID: "anthropic" },
|
||||||
|
body: { type: "api", key: "sk-test" },
|
||||||
|
});
|
||||||
|
expect(setResponse.data).toBe(true);
|
||||||
|
|
||||||
|
const afterSet = await client.provider.list();
|
||||||
|
expect(afterSet.data?.connected?.includes("anthropic")).toBe(true);
|
||||||
|
|
||||||
|
const removeResponse = await client.auth.remove({
|
||||||
|
path: { providerID: "anthropic" },
|
||||||
|
});
|
||||||
|
expect(removeResponse.data).toBe(true);
|
||||||
|
|
||||||
|
const afterRemove = await client.provider.list();
|
||||||
|
expect(afterRemove.data?.connected?.includes("anthropic")).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
1
target
Symbolic link
1
target
Symbolic link
|
|
@ -0,0 +1 @@
|
||||||
|
/home/nathan/sandbox-agent/target
|
||||||
Loading…
Add table
Add a link
Reference in a new issue