user-owned DAVs

This commit is contained in:
Harivansh Rathi 2026-04-01 20:26:15 -04:00
parent ca5014750b
commit 0a3234d617
35 changed files with 732 additions and 777 deletions

View file

@ -1,8 +1,6 @@
package main
import (
"errors"
"strings"
"time"
)
@ -11,10 +9,6 @@ type appConfig struct {
nextcloudBaseURL string
statePath string
dbPath string
clientToken string
nodeBootstrapToken string
davAuthSecret string
davCredentialTTL time.Duration
sessionTTL time.Duration
registrationEnabled bool
corsOrigin string
@ -28,21 +22,6 @@ type app struct {
}
func newApp(config appConfig, startedAt time.Time) (*app, error) {
config.clientToken = strings.TrimSpace(config.clientToken)
config.nodeBootstrapToken = strings.TrimSpace(config.nodeBootstrapToken)
if config.nodeBootstrapToken == "" {
return nil, errors.New("node bootstrap token is required")
}
config.davAuthSecret = strings.TrimSpace(config.davAuthSecret)
if config.davAuthSecret == "" {
return nil, errors.New("dav auth secret is required")
}
if config.davCredentialTTL <= 0 {
return nil, errors.New("dav credential ttl must be greater than 0")
}
var s store
var err error
if config.dbPath != "" {
@ -91,6 +70,7 @@ type nasNode struct {
LastSeenAt string `json:"lastSeenAt"`
DirectAddress *string `json:"directAddress"`
RelayAddress *string `json:"relayAddress"`
OwnerID string `json:"-"`
}
type storageExport struct {
@ -102,6 +82,7 @@ type storageExport struct {
Protocols []string `json:"protocols"`
CapacityBytes *int64 `json:"capacityBytes"`
Tags []string `json:"tags"`
OwnerID string `json:"-"`
}
type mountProfile struct {

View file

@ -196,21 +196,25 @@ func TestAuthSessionUsedForClientEndpoints(t *testing.T) {
}
}
func TestAuthStaticTokenFallback(t *testing.T) {
func TestAuthSessionIsTheOnlyClientAuthPath(t *testing.T) {
t.Parallel()
_, server := newTestSQLiteApp(t, appConfig{
version: "test-version",
clientToken: "static-fallback-token",
version: "test-version",
registrationEnabled: true,
})
defer server.Close()
// Static token should work for client endpoints.
exports := getJSONAuth[[]storageExport](t, server.Client(), "static-fallback-token", server.URL+"/api/v1/exports")
reg := postJSONAuthCreated[authLoginResponse](t, server.Client(), "", server.URL+"/api/v1/auth/register", authRegisterRequest{
Username: "sessiononly",
Password: "password123",
})
exports := getJSONAuth[[]storageExport](t, server.Client(), reg.Token, server.URL+"/api/v1/exports")
if exports == nil {
t.Fatal("expected exports list, got nil")
}
// Wrong token should fail.
getStatusWithAuth(t, server.Client(), "static-fallback-token", server.URL+"/api/v1/exports", http.StatusUnauthorized)
getStatusWithAuth(t, server.Client(), "wrong", server.URL+"/api/v1/exports", http.StatusUnauthorized)
}

View file

@ -1,6 +1,7 @@
package main
import (
"errors"
"log"
"net/http"
"strings"
@ -25,44 +26,52 @@ func main() {
}
func newAppFromEnv(startedAt time.Time) (*app, error) {
nodeBootstrapToken, err := requiredEnv("BETTERNAS_CONTROL_PLANE_NODE_BOOTSTRAP_TOKEN")
if err != nil {
return nil, err
}
davAuthSecret, err := requiredEnv("BETTERNAS_DAV_AUTH_SECRET")
if err != nil {
return nil, err
}
davCredentialTTL, err := parseRequiredDurationEnv("BETTERNAS_DAV_CREDENTIAL_TTL")
if err != nil {
return nil, err
}
var sessionTTL time.Duration
rawSessionTTL := strings.TrimSpace(env("BETTERNAS_SESSION_TTL", "720h"))
if rawSessionTTL != "" {
sessionTTL, err = time.ParseDuration(rawSessionTTL)
parsedSessionTTL, err := time.ParseDuration(rawSessionTTL)
if err != nil {
return nil, err
}
sessionTTL = parsedSessionTTL
}
return newApp(
app, err := newApp(
appConfig{
version: env("BETTERNAS_VERSION", "0.1.0-dev"),
nextcloudBaseURL: env("NEXTCLOUD_BASE_URL", ""),
statePath: env("BETTERNAS_CONTROL_PLANE_STATE_PATH", ".state/control-plane/state.json"),
dbPath: env("BETTERNAS_CONTROL_PLANE_DB_PATH", ""),
clientToken: env("BETTERNAS_CONTROL_PLANE_CLIENT_TOKEN", ""),
nodeBootstrapToken: nodeBootstrapToken,
davAuthSecret: davAuthSecret,
davCredentialTTL: davCredentialTTL,
dbPath: env("BETTERNAS_CONTROL_PLANE_DB_PATH", ".state/control-plane/betternas.db"),
sessionTTL: sessionTTL,
registrationEnabled: env("BETTERNAS_REGISTRATION_ENABLED", "true") == "true",
corsOrigin: env("BETTERNAS_CORS_ORIGIN", ""),
},
startedAt,
)
if err != nil {
return nil, err
}
if err := seedDefaultUserFromEnv(app); err != nil {
return nil, err
}
return app, nil
}
func seedDefaultUserFromEnv(app *app) error {
username := strings.TrimSpace(env("BETTERNAS_USERNAME", ""))
password := env("BETTERNAS_PASSWORD", "")
if username == "" || password == "" {
return nil
}
if _, err := app.store.createUser(username, password); err != nil {
if errors.Is(err, errUsernameTaken) {
_, authErr := app.store.authenticateUser(username, password)
return authErr
}
return err
}
return nil
}

View file

@ -1,89 +1,12 @@
package main
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"time"
)
const mountCredentialModeBasicAuth = "basic-auth"
// mountCredentialUsernameTokenBytes controls the random token size in mount
// credential usernames (e.g. "mount-<token>"). The username is also embedded
// inside the signed password payload, so longer tokens produce longer
// passwords. macOS WebDAVFS truncates Basic Auth passwords at 255 bytes,
// which corrupts the HMAC signature and causes auth failures. 24 bytes
// (32 base64url chars) keeps the total password under 250 characters with
// margin for longer node IDs and mount paths.
const mountCredentialUsernameTokenBytes = 24
type signedMountCredentialClaims struct {
Version int `json:"v"`
NodeID string `json:"nodeId"`
MountPath string `json:"mountPath"`
Username string `json:"username"`
Readonly bool `json:"readonly"`
ExpiresAt string `json:"expiresAt"`
}
func issueMountCredential(secret string, nodeID string, mountPath string, readonly bool, issuedAt time.Time, ttl time.Duration) (string, mountCredential, error) {
credentialID, err := newOpaqueToken()
if err != nil {
return "", mountCredential{}, err
}
usernameToken, err := newMountCredentialUsernameToken()
if err != nil {
return "", mountCredential{}, err
}
claims := signedMountCredentialClaims{
Version: 1,
NodeID: nodeID,
MountPath: mountPath,
Username: "mount-" + usernameToken,
Readonly: readonly,
ExpiresAt: issuedAt.UTC().Add(ttl).Format(time.RFC3339),
}
password, err := signMountCredentialClaims(secret, claims)
if err != nil {
return "", mountCredential{}, err
}
return "mount-" + credentialID, mountCredential{
func buildAccountMountCredential(username string) mountCredential {
return mountCredential{
Mode: mountCredentialModeBasicAuth,
Username: claims.Username,
Password: password,
ExpiresAt: claims.ExpiresAt,
}, nil
}
func newMountCredentialUsernameToken() (string, error) {
raw := make([]byte, mountCredentialUsernameTokenBytes)
if _, err := rand.Read(raw); err != nil {
return "", fmt.Errorf("generate mount credential username token: %w", err)
Username: username,
Password: "",
ExpiresAt: "",
}
return base64.RawURLEncoding.EncodeToString(raw), nil
}
func signMountCredentialClaims(secret string, claims signedMountCredentialClaims) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("encode mount credential claims: %w", err)
}
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
signature := signMountCredentialPayload(secret, encodedPayload)
return encodedPayload + "." + signature, nil
}
func signMountCredentialPayload(secret string, encodedPayload string) string {
mac := hmac.New(sha256.New, []byte(secret))
_, _ = mac.Write([]byte(encodedPayload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}

View file

@ -31,10 +31,7 @@ var (
nodeAgentBinaryErr error
)
const (
runtimeDAVAuthSecret = "runtime-dav-auth-secret"
runtimeDAVCredentialTTL = "1h"
)
const runtimeUsername = "runtime-user"
func TestControlPlaneBinaryMountLoopIntegration(t *testing.T) {
exportDir := t.TempDir()
@ -47,7 +44,7 @@ func TestControlPlaneBinaryMountLoopIntegration(t *testing.T) {
nodeAgent := startNodeAgentBinaryWithExports(t, controlPlane.baseURL, []string{exportDir}, "machine-runtime-1")
client := &http.Client{Timeout: 2 * time.Second}
exports := waitForExportsByPath(t, client, controlPlane.baseURL+"/api/v1/exports", []string{exportDir})
exports := waitForExportsByPath(t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/exports", []string{exportDir})
export := exports[exportDir]
if export.ID != "dev-export" {
t.Fatalf("expected export ID %q, got %q", "dev-export", export.ID)
@ -56,7 +53,7 @@ func TestControlPlaneBinaryMountLoopIntegration(t *testing.T) {
t.Fatalf("expected mountPath %q, got %q", defaultWebDAVPath, export.MountPath)
}
mount := postJSONAuth[mountProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{
mount := postJSONAuth[mountProfile](t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{
ExportID: export.ID,
})
if mount.MountURL != nodeAgent.baseURL+defaultWebDAVPath {
@ -66,11 +63,11 @@ func TestControlPlaneBinaryMountLoopIntegration(t *testing.T) {
t.Fatalf("expected mount credential mode %q, got %q", mountCredentialModeBasicAuth, mount.Credential.Mode)
}
assertHTTPStatusWithBasicAuth(t, client, "PROPFIND", mount.MountURL, mount.Credential.Username, mount.Credential.Password, http.StatusMultiStatus)
assertMountedFileContentsWithBasicAuth(t, client, mount.MountURL+"README.txt", mount.Credential.Username, mount.Credential.Password, "betterNAS export\n")
assertHTTPStatusWithBasicAuth(t, client, "PROPFIND", mount.MountURL, controlPlane.username, controlPlane.password, http.StatusMultiStatus)
assertMountedFileContentsWithBasicAuth(t, client, mount.MountURL+"README.txt", controlPlane.username, controlPlane.password, "betterNAS export\n")
cloud := postJSONAuth[cloudProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: "runtime-user",
cloud := postJSONAuth[cloudProfile](t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: controlPlane.userID,
ExportID: export.ID,
Provider: "nextcloud",
})
@ -97,12 +94,12 @@ func TestControlPlaneBinaryMultiExportProfilesStayDistinct(t *testing.T) {
firstMountPath := nodeAgentMountPathForExport(firstExportDir, 2)
secondMountPath := nodeAgentMountPathForExport(secondExportDir, 2)
exports := waitForExportsByPath(t, client, controlPlane.baseURL+"/api/v1/exports", []string{firstExportDir, secondExportDir})
exports := waitForExportsByPath(t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/exports", []string{firstExportDir, secondExportDir})
firstExport := exports[firstExportDir]
secondExport := exports[secondExportDir]
firstMount := postJSONAuth[mountProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{ExportID: firstExport.ID})
secondMount := postJSONAuth[mountProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{ExportID: secondExport.ID})
firstMount := postJSONAuth[mountProfile](t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{ExportID: firstExport.ID})
secondMount := postJSONAuth[mountProfile](t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{ExportID: secondExport.ID})
if firstMount.MountURL == secondMount.MountURL {
t.Fatalf("expected distinct runtime mount URLs, got %q", firstMount.MountURL)
}
@ -113,18 +110,18 @@ func TestControlPlaneBinaryMultiExportProfilesStayDistinct(t *testing.T) {
t.Fatalf("expected second runtime mount URL %q, got %q", nodeAgent.baseURL+secondMountPath, secondMount.MountURL)
}
assertHTTPStatusWithBasicAuth(t, client, "PROPFIND", firstMount.MountURL, firstMount.Credential.Username, firstMount.Credential.Password, http.StatusMultiStatus)
assertHTTPStatusWithBasicAuth(t, client, "PROPFIND", secondMount.MountURL, secondMount.Credential.Username, secondMount.Credential.Password, http.StatusMultiStatus)
assertMountedFileContentsWithBasicAuth(t, client, firstMount.MountURL+"README.txt", firstMount.Credential.Username, firstMount.Credential.Password, "first runtime export\n")
assertMountedFileContentsWithBasicAuth(t, client, secondMount.MountURL+"README.txt", secondMount.Credential.Username, secondMount.Credential.Password, "second runtime export\n")
assertHTTPStatusWithBasicAuth(t, client, "PROPFIND", firstMount.MountURL, controlPlane.username, controlPlane.password, http.StatusMultiStatus)
assertHTTPStatusWithBasicAuth(t, client, "PROPFIND", secondMount.MountURL, controlPlane.username, controlPlane.password, http.StatusMultiStatus)
assertMountedFileContentsWithBasicAuth(t, client, firstMount.MountURL+"README.txt", controlPlane.username, controlPlane.password, "first runtime export\n")
assertMountedFileContentsWithBasicAuth(t, client, secondMount.MountURL+"README.txt", controlPlane.username, controlPlane.password, "second runtime export\n")
firstCloud := postJSONAuth[cloudProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: "runtime-user",
firstCloud := postJSONAuth[cloudProfile](t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: controlPlane.userID,
ExportID: firstExport.ID,
Provider: "nextcloud",
})
secondCloud := postJSONAuth[cloudProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: "runtime-user",
secondCloud := postJSONAuth[cloudProfile](t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: controlPlane.userID,
ExportID: secondExport.ID,
Provider: "nextcloud",
})
@ -140,8 +137,12 @@ func TestControlPlaneBinaryMultiExportProfilesStayDistinct(t *testing.T) {
}
type runningBinary struct {
baseURL string
logPath string
baseURL string
logPath string
sessionToken string
username string
password string
userID string
}
func startControlPlaneBinary(t *testing.T, version string, nextcloudBaseURL string) runningBinary {
@ -149,7 +150,7 @@ func startControlPlaneBinary(t *testing.T, version string, nextcloudBaseURL stri
port := reserveTCPPort(t)
logPath := filepath.Join(t.TempDir(), "control-plane.log")
statePath := filepath.Join(t.TempDir(), "control-plane-state.json")
dbPath := filepath.Join(t.TempDir(), "control-plane.db")
logFile, err := os.Create(logPath)
if err != nil {
t.Fatalf("create control-plane log file: %v", err)
@ -162,11 +163,8 @@ func startControlPlaneBinary(t *testing.T, version string, nextcloudBaseURL stri
"PORT="+port,
"BETTERNAS_VERSION="+version,
"NEXTCLOUD_BASE_URL="+nextcloudBaseURL,
"BETTERNAS_CONTROL_PLANE_STATE_PATH="+statePath,
"BETTERNAS_CONTROL_PLANE_CLIENT_TOKEN="+testClientToken,
"BETTERNAS_CONTROL_PLANE_NODE_BOOTSTRAP_TOKEN="+testNodeBootstrapToken,
"BETTERNAS_DAV_AUTH_SECRET="+runtimeDAVAuthSecret,
"BETTERNAS_DAV_CREDENTIAL_TTL="+runtimeDAVCredentialTTL,
"BETTERNAS_CONTROL_PLANE_DB_PATH="+dbPath,
"BETTERNAS_REGISTRATION_ENABLED=true",
)
cmd.Stdout = logFile
cmd.Stderr = logFile
@ -183,11 +181,16 @@ func startControlPlaneBinary(t *testing.T, version string, nextcloudBaseURL stri
baseURL := fmt.Sprintf("http://127.0.0.1:%s", port)
waitForHTTPStatus(t, baseURL+"/health", waitDone, logPath, http.StatusOK)
session := registerRuntimeUser(t, &http.Client{Timeout: 2 * time.Second}, baseURL)
registerProcessCleanup(t, ctx, cancel, cmd, waitDone, logFile, logPath, "control-plane")
return runningBinary{
baseURL: baseURL,
logPath: logPath,
baseURL: baseURL,
logPath: logPath,
sessionToken: session.Token,
username: runtimeUsername,
password: testPassword,
userID: session.User.ID,
}
}
@ -197,7 +200,6 @@ func startNodeAgentBinaryWithExports(t *testing.T, controlPlaneBaseURL string, e
port := reserveTCPPort(t)
baseURL := fmt.Sprintf("http://127.0.0.1:%s", port)
logPath := filepath.Join(t.TempDir(), "node-agent.log")
nodeTokenPath := filepath.Join(t.TempDir(), "node-token")
logFile, err := os.Create(logPath)
if err != nil {
t.Fatalf("create node-agent log file: %v", err)
@ -215,12 +217,11 @@ func startNodeAgentBinaryWithExports(t *testing.T, controlPlaneBaseURL string, e
"PORT="+port,
"BETTERNAS_EXPORT_PATHS_JSON="+string(rawExportPaths),
"BETTERNAS_CONTROL_PLANE_URL="+controlPlaneBaseURL,
"BETTERNAS_CONTROL_PLANE_NODE_BOOTSTRAP_TOKEN="+testNodeBootstrapToken,
"BETTERNAS_NODE_TOKEN_PATH="+nodeTokenPath,
"BETTERNAS_USERNAME="+runtimeUsername,
"BETTERNAS_PASSWORD="+testPassword,
"BETTERNAS_NODE_MACHINE_ID="+machineID,
"BETTERNAS_NODE_DISPLAY_NAME="+machineID,
"BETTERNAS_NODE_DIRECT_ADDRESS="+baseURL,
"BETTERNAS_DAV_AUTH_SECRET="+runtimeDAVAuthSecret,
"BETTERNAS_VERSION=runtime-test-version",
)
cmd.Stdout = logFile
@ -245,12 +246,12 @@ func startNodeAgentBinaryWithExports(t *testing.T, controlPlaneBaseURL string, e
}
}
func waitForExportsByPath(t *testing.T, client *http.Client, endpoint string, expectedPaths []string) map[string]storageExport {
func waitForExportsByPath(t *testing.T, client *http.Client, token string, endpoint string, expectedPaths []string) map[string]storageExport {
t.Helper()
deadline := time.Now().Add(10 * time.Second)
for time.Now().Before(deadline) {
exports := getJSONAuth[[]storageExport](t, client, testClientToken, endpoint)
exports := getJSONAuth[[]storageExport](t, client, token, endpoint)
exportsByPath := exportsByPath(exports)
allPresent := true
for _, expectedPath := range expectedPaths {
@ -269,6 +270,15 @@ func waitForExportsByPath(t *testing.T, client *http.Client, endpoint string, ex
return nil
}
func registerRuntimeUser(t *testing.T, client *http.Client, baseURL string) authLoginResponse {
t.Helper()
return postJSONAuthCreated[authLoginResponse](t, client, "", baseURL+"/api/v1/auth/register", authRegisterRequest{
Username: runtimeUsername,
Password: testPassword,
})
}
func buildControlPlaneBinary(t *testing.T) string {
t.Helper()

View file

@ -2,7 +2,6 @@ package main
import (
"bytes"
"crypto/subtle"
"encoding/json"
"errors"
"fmt"
@ -21,12 +20,12 @@ var (
errMountTargetUnavailable = errors.New("mount target is not available")
errNodeIDMismatch = errors.New("node id path and body must match")
errNodeNotFound = errors.New("node not found")
errNodeOwnedByAnotherUser = errors.New("node is already owned by another user")
)
const (
authorizationHeader = "Authorization"
controlPlaneNodeTokenKey = "X-BetterNAS-Node-Token"
bearerScheme = "Bearer"
authorizationHeader = "Authorization"
bearerScheme = "Bearer"
)
func (a *app) handler() http.Handler {
@ -76,6 +75,11 @@ func (a *app) handleVersion(w http.ResponseWriter, _ *http.Request) {
}
func (a *app) handleNodeRegister(w http.ResponseWriter, r *http.Request) {
currentUser, ok := a.requireSessionUser(w, r)
if !ok {
return
}
request, err := decodeNodeRegistrationRequest(w, r)
if err != nil {
writeDecodeError(w, err)
@ -87,23 +91,25 @@ func (a *app) handleNodeRegister(w http.ResponseWriter, r *http.Request) {
return
}
if !a.authorizeNodeRegistration(w, r, request.MachineID) {
return
}
result, err := a.store.registerNode(request, a.now())
result, err := a.store.registerNode(currentUser.ID, request, a.now())
if err != nil {
if errors.Is(err, errNodeOwnedByAnotherUser) {
http.Error(w, err.Error(), http.StatusConflict)
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if result.IssuedNodeToken != "" {
w.Header().Set(controlPlaneNodeTokenKey, result.IssuedNodeToken)
}
writeJSON(w, http.StatusOK, result.Node)
}
func (a *app) handleNodeHeartbeat(w http.ResponseWriter, r *http.Request) {
currentUser, ok := a.requireSessionUser(w, r)
if !ok {
return
}
nodeID := r.PathValue("nodeId")
var request nodeHeartbeatRequest
@ -121,11 +127,7 @@ func (a *app) handleNodeHeartbeat(w http.ResponseWriter, r *http.Request) {
return
}
if !a.authorizeNode(w, r, nodeID) {
return
}
if err := a.store.recordHeartbeat(nodeID, request); err != nil {
if err := a.store.recordHeartbeat(nodeID, currentUser.ID, request); err != nil {
statusCode := http.StatusInternalServerError
if errors.Is(err, errNodeNotFound) {
statusCode = http.StatusNotFound
@ -138,6 +140,11 @@ func (a *app) handleNodeHeartbeat(w http.ResponseWriter, r *http.Request) {
}
func (a *app) handleNodeExports(w http.ResponseWriter, r *http.Request) {
currentUser, ok := a.requireSessionUser(w, r)
if !ok {
return
}
nodeID := r.PathValue("nodeId")
request, err := decodeNodeExportsRequest(w, r)
@ -151,11 +158,7 @@ func (a *app) handleNodeExports(w http.ResponseWriter, r *http.Request) {
return
}
if !a.authorizeNode(w, r, nodeID) {
return
}
exports, err := a.store.upsertExports(nodeID, request)
exports, err := a.store.upsertExports(nodeID, currentUser.ID, request)
if err != nil {
statusCode := http.StatusInternalServerError
if errors.Is(err, errNodeNotFound) {
@ -169,15 +172,17 @@ func (a *app) handleNodeExports(w http.ResponseWriter, r *http.Request) {
}
func (a *app) handleExportsList(w http.ResponseWriter, r *http.Request) {
if !a.requireClientAuth(w, r) {
currentUser, ok := a.requireSessionUser(w, r)
if !ok {
return
}
writeJSON(w, http.StatusOK, a.store.listExports())
writeJSON(w, http.StatusOK, a.store.listExports(currentUser.ID))
}
func (a *app) handleMountProfileIssue(w http.ResponseWriter, r *http.Request) {
if !a.requireClientAuth(w, r) {
currentUser, ok := a.requireSessionUser(w, r)
if !ok {
return
}
@ -192,8 +197,8 @@ func (a *app) handleMountProfileIssue(w http.ResponseWriter, r *http.Request) {
return
}
context, ok := a.store.exportContext(request.ExportID)
if !ok {
context, found := a.store.exportContext(request.ExportID, currentUser.ID)
if !found {
http.Error(w, errExportNotFound.Error(), http.StatusNotFound)
return
}
@ -204,32 +209,20 @@ func (a *app) handleMountProfileIssue(w http.ResponseWriter, r *http.Request) {
return
}
credentialID, credential, err := issueMountCredential(
a.config.davAuthSecret,
context.node.ID,
mountProfilePathForExport(context.export.MountPath),
false,
a.now(),
a.config.davCredentialTTL,
)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, mountProfile{
ID: credentialID,
ID: context.export.ID,
ExportID: context.export.ID,
Protocol: "webdav",
DisplayName: context.export.Label,
MountURL: mountURL,
Readonly: false,
Credential: credential,
Credential: buildAccountMountCredential(currentUser.Username),
})
}
func (a *app) handleCloudProfileIssue(w http.ResponseWriter, r *http.Request) {
if !a.requireClientAuth(w, r) {
currentUser, ok := a.requireSessionUser(w, r)
if !ok {
return
}
@ -244,8 +237,8 @@ func (a *app) handleCloudProfileIssue(w http.ResponseWriter, r *http.Request) {
return
}
context, ok := a.store.exportContext(request.ExportID)
if !ok {
context, found := a.store.exportContext(request.ExportID, currentUser.ID)
if !found {
http.Error(w, errExportNotFound.Error(), http.StatusNotFound)
return
}
@ -257,7 +250,7 @@ func (a *app) handleCloudProfileIssue(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusOK, cloudProfile{
ID: fmt.Sprintf("cloud-%s-%s", request.UserID, context.export.ID),
ID: fmt.Sprintf("cloud-%s-%s", currentUser.ID, context.export.ID),
ExportID: context.export.ID,
Provider: "nextcloud",
BaseURL: baseURL,
@ -1034,71 +1027,22 @@ func corsMiddleware(allowedOrigin string, next http.Handler) http.Handler {
})
}
// --- client auth ---
// --- session auth ---
func (a *app) requireClientAuth(w http.ResponseWriter, r *http.Request) bool {
func (a *app) requireSessionUser(w http.ResponseWriter, r *http.Request) (user, bool) {
presentedToken, ok := bearerToken(r)
if !ok {
writeUnauthorized(w)
return false
return user{}, false
}
// Session-based auth (SQLite).
if _, err := a.store.validateSession(presentedToken); err == nil {
return true
}
// Fall back to static client token for backwards compatibility.
if a.config.clientToken != "" && secureStringEquals(a.config.clientToken, presentedToken) {
return true
}
writeUnauthorized(w)
return false
}
func (a *app) authorizeNodeRegistration(w http.ResponseWriter, r *http.Request, machineID string) bool {
presentedToken, ok := bearerToken(r)
if !ok {
currentUser, err := a.store.validateSession(presentedToken)
if err != nil {
writeUnauthorized(w)
return false
return user{}, false
}
authState, exists := a.store.nodeAuthByMachineID(machineID)
if !exists || strings.TrimSpace(authState.TokenHash) == "" {
if !secureStringEquals(a.config.nodeBootstrapToken, presentedToken) {
writeUnauthorized(w)
return false
}
return true
}
if !tokenHashMatches(authState.TokenHash, presentedToken) {
writeUnauthorized(w)
return false
}
return true
}
func (a *app) authorizeNode(w http.ResponseWriter, r *http.Request, nodeID string) bool {
presentedToken, ok := bearerToken(r)
if !ok {
writeUnauthorized(w)
return false
}
authState, exists := a.store.nodeAuthByID(nodeID)
if !exists {
http.Error(w, errNodeNotFound.Error(), http.StatusNotFound)
return false
}
if strings.TrimSpace(authState.TokenHash) == "" || !tokenHashMatches(authState.TokenHash, presentedToken) {
writeUnauthorized(w)
return false
}
return true
return currentUser, true
}
func bearerToken(r *http.Request) (string, bool) {
@ -1124,11 +1068,3 @@ func writeUnauthorized(w http.ResponseWriter) {
w.Header().Set("WWW-Authenticate", bearerScheme)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
func secureStringEquals(expected string, actual string) bool {
return subtle.ConstantTimeCompare([]byte(expected), []byte(actual)) == 1
}
func tokenHashMatches(expectedHash string, token string) bool {
return secureStringEquals(expectedHash, hashOpaqueToken(token))
}

View file

@ -3,6 +3,7 @@ package main
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
@ -15,8 +16,9 @@ import (
var testControlPlaneNow = time.Date(2025, time.January, 1, 12, 0, 0, 0, time.UTC)
const (
testPassword = "password123"
testClientToken = "test-client-token"
testNodeBootstrapToken = "test-node-bootstrap-token"
testNodeBootstrapToken = "test-node-session-token"
)
type registeredNode struct {
@ -94,7 +96,7 @@ func TestControlPlaneRegistrationProfilesAndHeartbeat(t *testing.T) {
RelayAddress: &relayAddress,
})
if registration.NodeToken == "" {
t.Fatal("expected node registration to return a node token")
t.Fatal("expected node registration to preserve the session token")
}
syncedExports := syncNodeExports(t, server.Client(), registration.NodeToken, server.URL+"/api/v1/nodes/"+registration.Node.ID+"/exports", nodeExportsRequest{
@ -169,14 +171,14 @@ func TestControlPlaneRegistrationProfilesAndHeartbeat(t *testing.T) {
if mount.Credential.Mode != mountCredentialModeBasicAuth {
t.Fatalf("expected credential mode %q, got %q", mountCredentialModeBasicAuth, mount.Credential.Mode)
}
if mount.Credential.Username == "" {
t.Fatal("expected mount credential username to be set")
if mount.Credential.Username != "fixture" {
t.Fatalf("expected mount credential username %q, got %q", "fixture", mount.Credential.Username)
}
if mount.Credential.Password == "" {
t.Fatal("expected mount credential password to be set")
if mount.Credential.Password != "" {
t.Fatalf("expected mount credential password to be blank, got %q", mount.Credential.Password)
}
if mount.Credential.ExpiresAt == "" {
t.Fatal("expected mount credential expiry to be set")
if mount.Credential.ExpiresAt != "" {
t.Fatalf("expected mount credential expiry to be blank, got %q", mount.Credential.ExpiresAt)
}
cloud := postJSONAuth[cloudProfile](t, server.Client(), testClientToken, server.URL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
@ -231,7 +233,7 @@ func TestControlPlaneExportSyncReconcilesExportsAndKeepsStableIDs(t *testing.T)
RelayAddress: nil,
})
putJSONAuthStatus(t, server.Client(), testNodeBootstrapToken, server.URL+"/api/v1/nodes/"+firstRegistration.Node.ID+"/exports", nodeExportsRequest{
putJSONAuthStatus(t, server.Client(), "wrong-session-token", server.URL+"/api/v1/nodes/"+firstRegistration.Node.ID+"/exports", nodeExportsRequest{
Exports: []storageExportInput{
{
Label: "Docs",
@ -285,7 +287,7 @@ func TestControlPlaneExportSyncReconcilesExportsAndKeepsStableIDs(t *testing.T)
RelayAddress: nil,
})
putJSONAuthStatus(t, server.Client(), testClientToken, server.URL+"/api/v1/nodes/"+firstRegistration.Node.ID+"/exports", nodeExportsRequest{
putJSONAuthStatus(t, server.Client(), "wrong-session-token", server.URL+"/api/v1/nodes/"+firstRegistration.Node.ID+"/exports", nodeExportsRequest{
Exports: []storageExportInput{
{
Label: "Docs v2",
@ -330,8 +332,8 @@ func TestControlPlaneExportSyncReconcilesExportsAndKeepsStableIDs(t *testing.T)
if secondRegistration.Node.ID != firstRegistration.Node.ID {
t.Fatalf("expected re-registration to keep node ID %q, got %q", firstRegistration.Node.ID, secondRegistration.Node.ID)
}
if secondRegistration.NodeToken != "" {
t.Fatalf("expected re-registration to keep the existing node token, got %q", secondRegistration.NodeToken)
if secondRegistration.NodeToken != firstRegistration.NodeToken {
t.Fatalf("expected re-registration to keep the existing session token %q, got %q", firstRegistration.NodeToken, secondRegistration.NodeToken)
}
updatedExports := exportsByPath(getJSONAuth[[]storageExport](t, server.Client(), testClientToken, server.URL+"/api/v1/exports"))
@ -539,12 +541,12 @@ func TestControlPlaneCloudProfilesRequireConfiguredBaseURLAndExistingExport(t *t
func TestControlPlanePersistsRegistryAcrossAppRestart(t *testing.T) {
t.Parallel()
statePath := filepath.Join(t.TempDir(), "control-plane-state.json")
dbPath := filepath.Join(t.TempDir(), "control-plane.db")
directAddress := "http://nas.local:8090"
_, firstServer := newTestControlPlaneServer(t, appConfig{
version: "test-version",
statePath: statePath,
version: "test-version",
dbPath: dbPath,
})
registration := registerNode(t, firstServer.Client(), firstServer.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
MachineID: "machine-persisted",
@ -566,8 +568,8 @@ func TestControlPlanePersistsRegistryAcrossAppRestart(t *testing.T) {
firstServer.Close()
_, secondServer := newTestControlPlaneServer(t, appConfig{
version: "test-version",
statePath: statePath,
version: "test-version",
dbPath: dbPath,
})
defer secondServer.Close()
@ -656,15 +658,12 @@ func TestControlPlaneRejectsInvalidRequestsAndEnforcesAuth(t *testing.T) {
if err := json.NewDecoder(response.Body).Decode(&node); err != nil {
t.Fatalf("decode registration response: %v", err)
}
nodeToken := strings.TrimSpace(response.Header.Get(controlPlaneNodeTokenKey))
if nodeToken == "" {
t.Fatal("expected node registration to return a node token")
}
nodeToken := testNodeBootstrapToken
if node.ID != "dev-node" {
t.Fatalf("expected node ID %q, got %q", "dev-node", node.ID)
}
putJSONAuthStatus(t, server.Client(), testClientToken, server.URL+"/api/v1/nodes/"+node.ID+"/exports", nodeExportsRequest{
putJSONAuthStatus(t, server.Client(), "wrong-session-token", server.URL+"/api/v1/nodes/"+node.ID+"/exports", nodeExportsRequest{
Exports: []storageExportInput{{
Label: "Docs",
Path: "/srv/docs",
@ -716,7 +715,7 @@ func TestControlPlaneRejectsInvalidRequestsAndEnforcesAuth(t *testing.T) {
},
}, http.StatusBadRequest)
postJSONAuthStatus(t, server.Client(), testClientToken, server.URL+"/api/v1/nodes/"+node.ID+"/heartbeat", nodeHeartbeatRequest{
postJSONAuthStatus(t, server.Client(), "wrong-session-token", server.URL+"/api/v1/nodes/"+node.ID+"/heartbeat", nodeHeartbeatRequest{
NodeID: node.ID,
Status: "online",
LastSeenAt: "2025-01-02T03:04:05Z",
@ -765,21 +764,12 @@ func TestControlPlaneRejectsInvalidRequestsAndEnforcesAuth(t *testing.T) {
func newTestControlPlaneServer(t *testing.T, config appConfig) (*app, *httptest.Server) {
t.Helper()
if config.dbPath == "" {
config.dbPath = filepath.Join(t.TempDir(), "test.db")
}
if config.version == "" {
config.version = "test-version"
}
if config.clientToken == "" {
config.clientToken = testClientToken
}
if config.nodeBootstrapToken == "" {
config.nodeBootstrapToken = testNodeBootstrapToken
}
if config.davAuthSecret == "" {
config.davAuthSecret = "test-dav-auth-secret"
}
if config.davCredentialTTL == 0 {
config.davCredentialTTL = time.Hour
}
app, err := newApp(config, testControlPlaneNow)
if err != nil {
@ -788,11 +778,46 @@ func newTestControlPlaneServer(t *testing.T, config appConfig) (*app, *httptest.
app.now = func() time.Time {
return testControlPlaneNow
}
seedDefaultSessionUser(t, app)
server := httptest.NewServer(app.handler())
return app, server
}
func seedDefaultSessionUser(t *testing.T, app *app) {
t.Helper()
u, err := app.store.createUser("fixture", testPassword)
if err != nil && !errors.Is(err, errUsernameTaken) {
t.Fatalf("seed default test user: %v", err)
}
if errors.Is(err, errUsernameTaken) {
u, err = app.store.authenticateUser("fixture", testPassword)
if err != nil {
t.Fatalf("authenticate seeded test user: %v", err)
}
}
sqliteStore, ok := app.store.(*sqliteStore)
if !ok {
return
}
createdAt := time.Now().UTC().Format(time.RFC3339)
expiresAt := time.Now().UTC().Add(24 * time.Hour).Format(time.RFC3339)
for _, token := range []string{testClientToken, testNodeBootstrapToken} {
if _, err := sqliteStore.db.Exec(
"INSERT OR REPLACE INTO sessions (token, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)",
token,
u.ID,
createdAt,
expiresAt,
); err != nil {
t.Fatalf("seed session %s: %v", token, err)
}
}
}
func exportsByPath(exports []storageExport) map[string]storageExport {
byPath := make(map[string]storageExport, len(exports))
for _, export := range exports {
@ -820,10 +845,19 @@ func registerNode(t *testing.T, client *http.Client, endpoint string, token stri
return registeredNode{
Node: node,
NodeToken: strings.TrimSpace(response.Header.Get(controlPlaneNodeTokenKey)),
NodeToken: strings.TrimSpace(token),
}
}
func registerSessionUser(t *testing.T, client *http.Client, baseURL string, username string) authLoginResponse {
t.Helper()
return postJSONAuthCreated[authLoginResponse](t, client, "", baseURL+"/api/v1/auth/register", authRegisterRequest{
Username: username,
Password: testPassword,
})
}
func syncNodeExports(t *testing.T, client *http.Client, token string, endpoint string, payload nodeExportsRequest) []storageExport {
t.Helper()

View file

@ -17,8 +17,8 @@ import (
)
var (
errUsernameTaken = errors.New("username already taken")
errInvalidLogin = errors.New("invalid username or password")
errUsernameTaken = errors.New("username already taken")
errInvalidLogin = errors.New("invalid username or password")
errSessionExpired = errors.New("session expired or invalid")
)
@ -32,6 +32,7 @@ INSERT OR IGNORE INTO ordinals (name, value) VALUES ('node', 0), ('export', 0);
CREATE TABLE IF NOT EXISTS nodes (
id TEXT PRIMARY KEY,
machine_id TEXT NOT NULL UNIQUE,
owner_id TEXT REFERENCES users(id),
display_name TEXT NOT NULL DEFAULT '',
agent_version TEXT NOT NULL DEFAULT '',
status TEXT NOT NULL DEFAULT 'online',
@ -48,6 +49,7 @@ CREATE TABLE IF NOT EXISTS node_tokens (
CREATE TABLE IF NOT EXISTS exports (
id TEXT PRIMARY KEY,
node_id TEXT NOT NULL REFERENCES nodes(id),
owner_id TEXT REFERENCES users(id),
label TEXT NOT NULL DEFAULT '',
path TEXT NOT NULL,
mount_path TEXT NOT NULL DEFAULT '',
@ -101,10 +103,40 @@ func newSQLiteStore(dbPath string) (*sqliteStore, error) {
db.Close()
return nil, fmt.Errorf("initialize database schema: %w", err)
}
if err := migrateSQLiteSchema(db); err != nil {
db.Close()
return nil, err
}
return &sqliteStore{db: db}, nil
}
func migrateSQLiteSchema(db *sql.DB) error {
migrations := []string{
"ALTER TABLE nodes ADD COLUMN owner_id TEXT REFERENCES users(id)",
"ALTER TABLE exports ADD COLUMN owner_id TEXT REFERENCES users(id)",
}
for _, statement := range migrations {
if _, err := db.Exec(statement); err != nil && !strings.Contains(err.Error(), "duplicate column name") {
return fmt.Errorf("run sqlite migration %q: %w", statement, err)
}
}
if _, err := db.Exec(`
UPDATE exports
SET owner_id = (
SELECT owner_id
FROM nodes
WHERE nodes.id = exports.node_id
)
WHERE owner_id IS NULL
`); err != nil {
return fmt.Errorf("backfill export owners: %w", err)
}
return nil
}
func (s *sqliteStore) nextOrdinal(tx *sql.Tx, name string) (int, error) {
var value int
err := tx.QueryRow("UPDATE ordinals SET value = value + 1 WHERE name = ? RETURNING value", name).Scan(&value)
@ -128,7 +160,7 @@ func ordinalToExportID(ordinal int) string {
return fmt.Sprintf("dev-export-%d", ordinal)
}
func (s *sqliteStore) registerNode(request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
func (s *sqliteStore) registerNode(ownerID string, request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
tx, err := s.db.Begin()
if err != nil {
return nodeRegistrationResult{}, fmt.Errorf("begin transaction: %w", err)
@ -137,7 +169,8 @@ func (s *sqliteStore) registerNode(request nodeRegistrationRequest, registeredAt
// Check if machine already registered.
var nodeID string
err = tx.QueryRow("SELECT id FROM nodes WHERE machine_id = ?", request.MachineID).Scan(&nodeID)
var existingOwnerID sql.NullString
err = tx.QueryRow("SELECT id, owner_id FROM nodes WHERE machine_id = ?", request.MachineID).Scan(&nodeID, &existingOwnerID)
if err == sql.ErrNoRows {
ordinal, err := s.nextOrdinal(tx, "node")
if err != nil {
@ -146,57 +179,40 @@ func (s *sqliteStore) registerNode(request nodeRegistrationRequest, registeredAt
nodeID = ordinalToNodeID(ordinal)
} else if err != nil {
return nodeRegistrationResult{}, fmt.Errorf("lookup node by machine_id: %w", err)
} else if existingOwnerID.Valid && strings.TrimSpace(existingOwnerID.String) != "" && existingOwnerID.String != ownerID {
return nodeRegistrationResult{}, errNodeOwnedByAnotherUser
}
// Upsert node.
_, err = tx.Exec(`
INSERT INTO nodes (id, machine_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address)
VALUES (?, ?, ?, ?, 'online', ?, ?, ?)
INSERT INTO nodes (id, machine_id, owner_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address)
VALUES (?, ?, ?, ?, ?, 'online', ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
owner_id = excluded.owner_id,
display_name = excluded.display_name,
agent_version = excluded.agent_version,
status = 'online',
last_seen_at = excluded.last_seen_at,
direct_address = excluded.direct_address,
relay_address = excluded.relay_address
`, nodeID, request.MachineID, request.DisplayName, request.AgentVersion,
`, nodeID, request.MachineID, ownerID, request.DisplayName, request.AgentVersion,
registeredAt.UTC().Format(time.RFC3339),
nullableString(request.DirectAddress), nullableString(request.RelayAddress))
if err != nil {
return nodeRegistrationResult{}, fmt.Errorf("upsert node: %w", err)
}
// Issue token if none exists.
var issuedNodeToken string
var existingHash sql.NullString
_ = tx.QueryRow("SELECT token_hash FROM node_tokens WHERE node_id = ?", nodeID).Scan(&existingHash)
if !existingHash.Valid || strings.TrimSpace(existingHash.String) == "" {
nodeToken, err := newOpaqueToken()
if err != nil {
return nodeRegistrationResult{}, err
}
_, err = tx.Exec(
"INSERT OR REPLACE INTO node_tokens (node_id, token_hash) VALUES (?, ?)",
nodeID, hashOpaqueToken(nodeToken))
if err != nil {
return nodeRegistrationResult{}, fmt.Errorf("store node token: %w", err)
}
issuedNodeToken = nodeToken
}
if err := tx.Commit(); err != nil {
return nodeRegistrationResult{}, fmt.Errorf("commit registration: %w", err)
}
node, _ := s.nodeByID(nodeID)
return nodeRegistrationResult{
Node: node,
IssuedNodeToken: issuedNodeToken,
Node: node,
}, nil
}
func (s *sqliteStore) upsertExports(nodeID string, request nodeExportsRequest) ([]storageExport, error) {
func (s *sqliteStore) upsertExports(nodeID string, ownerID string, request nodeExportsRequest) ([]storageExport, error) {
tx, err := s.db.Begin()
if err != nil {
return nil, fmt.Errorf("begin transaction: %w", err)
@ -205,7 +221,7 @@ func (s *sqliteStore) upsertExports(nodeID string, request nodeExportsRequest) (
// Verify node exists.
var exists bool
err = tx.QueryRow("SELECT 1 FROM nodes WHERE id = ?", nodeID).Scan(&exists)
err = tx.QueryRow("SELECT 1 FROM nodes WHERE id = ? AND owner_id = ?", nodeID, ownerID).Scan(&exists)
if err != nil {
return nil, errNodeNotFound
}
@ -238,13 +254,14 @@ func (s *sqliteStore) upsertExports(nodeID string, request nodeExportsRequest) (
}
_, err = tx.Exec(`
INSERT INTO exports (id, node_id, label, path, mount_path, capacity_bytes)
VALUES (?, ?, ?, ?, ?, ?)
INSERT INTO exports (id, node_id, owner_id, label, path, mount_path, capacity_bytes)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
owner_id = excluded.owner_id,
label = excluded.label,
mount_path = excluded.mount_path,
capacity_bytes = excluded.capacity_bytes
`, exportID, nodeID, input.Label, input.Path, input.MountPath, nullableInt64(input.CapacityBytes))
`, exportID, nodeID, ownerID, input.Label, input.Path, input.MountPath, nullableInt64(input.CapacityBytes))
if err != nil {
return nil, fmt.Errorf("upsert export %q: %w", input.Path, err)
}
@ -288,10 +305,10 @@ func (s *sqliteStore) upsertExports(nodeID string, request nodeExportsRequest) (
return s.listExportsForNode(nodeID), nil
}
func (s *sqliteStore) recordHeartbeat(nodeID string, request nodeHeartbeatRequest) error {
func (s *sqliteStore) recordHeartbeat(nodeID string, ownerID string, request nodeHeartbeatRequest) error {
result, err := s.db.Exec(
"UPDATE nodes SET status = ?, last_seen_at = ? WHERE id = ?",
request.Status, request.LastSeenAt, nodeID)
"UPDATE nodes SET status = ?, last_seen_at = ? WHERE id = ? AND owner_id = ?",
request.Status, request.LastSeenAt, nodeID, ownerID)
if err != nil {
return fmt.Errorf("update heartbeat: %w", err)
}
@ -302,8 +319,8 @@ func (s *sqliteStore) recordHeartbeat(nodeID string, request nodeHeartbeatReques
return nil
}
func (s *sqliteStore) listExports() []storageExport {
rows, err := s.db.Query("SELECT id, node_id, label, path, mount_path, capacity_bytes FROM exports ORDER BY id")
func (s *sqliteStore) listExports(ownerID string) []storageExport {
rows, err := s.db.Query("SELECT id, node_id, owner_id, label, path, mount_path, capacity_bytes FROM exports WHERE owner_id = ? ORDER BY id", ownerID)
if err != nil {
return nil
}
@ -330,7 +347,7 @@ func (s *sqliteStore) listExports() []storageExport {
}
func (s *sqliteStore) listExportsForNode(nodeID string) []storageExport {
rows, err := s.db.Query("SELECT id, node_id, label, path, mount_path, capacity_bytes FROM exports WHERE node_id = ? ORDER BY id", nodeID)
rows, err := s.db.Query("SELECT id, node_id, owner_id, label, path, mount_path, capacity_bytes FROM exports WHERE node_id = ? ORDER BY id", nodeID)
if err != nil {
return nil
}
@ -356,15 +373,19 @@ func (s *sqliteStore) listExportsForNode(nodeID string) []storageExport {
return exports
}
func (s *sqliteStore) exportContext(exportID string) (exportContext, bool) {
func (s *sqliteStore) exportContext(exportID string, ownerID string) (exportContext, bool) {
var e storageExport
var capacityBytes sql.NullInt64
var exportOwnerID sql.NullString
err := s.db.QueryRow(
"SELECT id, node_id, label, path, mount_path, capacity_bytes FROM exports WHERE id = ?",
exportID).Scan(&e.ID, &e.NasNodeID, &e.Label, &e.Path, &e.MountPath, &capacityBytes)
"SELECT id, node_id, owner_id, label, path, mount_path, capacity_bytes FROM exports WHERE id = ? AND owner_id = ?",
exportID, ownerID).Scan(&e.ID, &e.NasNodeID, &exportOwnerID, &e.Label, &e.Path, &e.MountPath, &capacityBytes)
if err != nil {
return exportContext{}, false
}
if exportOwnerID.Valid {
e.OwnerID = exportOwnerID.String
}
if capacityBytes.Valid {
e.CapacityBytes = &capacityBytes.Int64
}
@ -383,12 +404,16 @@ func (s *sqliteStore) nodeByID(nodeID string) (nasNode, bool) {
var n nasNode
var directAddr, relayAddr sql.NullString
var lastSeenAt sql.NullString
var ownerID sql.NullString
err := s.db.QueryRow(
"SELECT id, machine_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address FROM nodes WHERE id = ?",
nodeID).Scan(&n.ID, &n.MachineID, &n.DisplayName, &n.AgentVersion, &n.Status, &lastSeenAt, &directAddr, &relayAddr)
"SELECT id, machine_id, owner_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address FROM nodes WHERE id = ?",
nodeID).Scan(&n.ID, &n.MachineID, &ownerID, &n.DisplayName, &n.AgentVersion, &n.Status, &lastSeenAt, &directAddr, &relayAddr)
if err != nil {
return nasNode{}, false
}
if ownerID.Valid {
n.OwnerID = ownerID.String
}
if lastSeenAt.Valid {
n.LastSeenAt = lastSeenAt.String
}
@ -442,9 +467,13 @@ func (s *sqliteStore) nodeAuthByID(nodeID string) (nodeAuthState, bool) {
func (s *sqliteStore) scanExport(rows *sql.Rows) storageExport {
var e storageExport
var capacityBytes sql.NullInt64
if err := rows.Scan(&e.ID, &e.NasNodeID, &e.Label, &e.Path, &e.MountPath, &capacityBytes); err != nil {
var ownerID sql.NullString
if err := rows.Scan(&e.ID, &e.NasNodeID, &ownerID, &e.Label, &e.Path, &e.MountPath, &capacityBytes); err != nil {
return storageExport{}
}
if ownerID.Valid {
e.OwnerID = ownerID.String
}
if capacityBytes.Valid {
e.CapacityBytes = &capacityBytes.Int64
}

View file

@ -18,24 +18,13 @@ func newTestSQLiteApp(t *testing.T, config appConfig) (*app, *httptest.Server) {
if config.version == "" {
config.version = "test-version"
}
if config.clientToken == "" {
config.clientToken = testClientToken
}
if config.nodeBootstrapToken == "" {
config.nodeBootstrapToken = testNodeBootstrapToken
}
if config.davAuthSecret == "" {
config.davAuthSecret = "test-dav-auth-secret"
}
if config.davCredentialTTL == 0 {
config.davCredentialTTL = time.Hour
}
app, err := newApp(config, testControlPlaneNow)
if err != nil {
t.Fatalf("new app: %v", err)
}
app.now = func() time.Time { return testControlPlaneNow }
seedDefaultSessionUser(t, app)
server := httptest.NewServer(app.handler())
return app, server
@ -79,7 +68,7 @@ func TestSQLiteRegistrationAndExports(t *testing.T) {
RelayAddress: nil,
})
if registration.NodeToken == "" {
t.Fatal("expected node registration to return a node token")
t.Fatal("expected node registration to preserve the session token")
}
if registration.Node.ID != "dev-node" {
t.Fatalf("expected node ID %q, got %q", "dev-node", registration.Node.ID)
@ -142,8 +131,8 @@ func TestSQLiteReRegistrationKeepsNodeID(t *testing.T) {
if second.Node.ID != first.Node.ID {
t.Fatalf("expected re-registration to keep node ID %q, got %q", first.Node.ID, second.Node.ID)
}
if second.NodeToken != "" {
t.Fatalf("expected re-registration to not issue new token, got %q", second.NodeToken)
if second.NodeToken != first.NodeToken {
t.Fatalf("expected re-registration to keep the existing session token %q, got %q", first.NodeToken, second.NodeToken)
}
if second.Node.DisplayName != "NAS Updated" {
t.Fatalf("expected updated display name, got %q", second.Node.DisplayName)

View file

@ -31,8 +31,7 @@ type memoryStore struct {
}
type nodeRegistrationResult struct {
Node nasNode
IssuedNodeToken string
Node nasNode
}
type nodeAuthState struct {
@ -153,12 +152,12 @@ func cloneStoreState(state storeState) storeState {
return cloned
}
func (s *memoryStore) registerNode(request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
func (s *memoryStore) registerNode(ownerID string, request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
s.mu.Lock()
defer s.mu.Unlock()
nextState := cloneStoreState(s.state)
result, err := registerNodeInState(&nextState, request, registeredAt)
result, err := registerNodeInState(&nextState, ownerID, request, registeredAt)
if err != nil {
return nodeRegistrationResult{}, err
}
@ -170,21 +169,14 @@ func (s *memoryStore) registerNode(request nodeRegistrationRequest, registeredAt
return result, nil
}
func registerNodeInState(state *storeState, request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
func registerNodeInState(state *storeState, ownerID string, request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
nodeID, ok := state.NodeIDByMachineID[request.MachineID]
if !ok {
nodeID = nextNodeID(state)
state.NodeIDByMachineID[request.MachineID] = nodeID
}
issuedNodeToken := ""
if stringsTrimmedEmpty(state.NodeTokenHashByID[nodeID]) {
nodeToken, err := newOpaqueToken()
if err != nil {
return nodeRegistrationResult{}, err
}
state.NodeTokenHashByID[nodeID] = hashOpaqueToken(nodeToken)
issuedNodeToken = nodeToken
if existingNode, exists := state.NodesByID[nodeID]; exists && existingNode.OwnerID != "" && existingNode.OwnerID != ownerID {
return nodeRegistrationResult{}, errNodeOwnedByAnotherUser
}
node := nasNode{
@ -196,21 +188,21 @@ func registerNodeInState(state *storeState, request nodeRegistrationRequest, reg
LastSeenAt: registeredAt.UTC().Format(time.RFC3339),
DirectAddress: copyStringPointer(request.DirectAddress),
RelayAddress: copyStringPointer(request.RelayAddress),
OwnerID: ownerID,
}
state.NodesByID[nodeID] = node
return nodeRegistrationResult{
Node: node,
IssuedNodeToken: issuedNodeToken,
Node: node,
}, nil
}
func (s *memoryStore) upsertExports(nodeID string, request nodeExportsRequest) ([]storageExport, error) {
func (s *memoryStore) upsertExports(nodeID string, ownerID string, request nodeExportsRequest) ([]storageExport, error) {
s.mu.Lock()
defer s.mu.Unlock()
nextState := cloneStoreState(s.state)
exports, err := upsertExportsInState(&nextState, nodeID, request.Exports)
exports, err := upsertExportsInState(&nextState, nodeID, ownerID, request.Exports)
if err != nil {
return nil, err
}
@ -222,8 +214,9 @@ func (s *memoryStore) upsertExports(nodeID string, request nodeExportsRequest) (
return exports, nil
}
func upsertExportsInState(state *storeState, nodeID string, exports []storageExportInput) ([]storageExport, error) {
if _, ok := state.NodesByID[nodeID]; !ok {
func upsertExportsInState(state *storeState, nodeID string, ownerID string, exports []storageExportInput) ([]storageExport, error) {
node, ok := state.NodesByID[nodeID]
if !ok || node.OwnerID != ownerID {
return nil, errNodeNotFound
}
@ -250,6 +243,7 @@ func upsertExportsInState(state *storeState, nodeID string, exports []storageExp
Protocols: copyStringSlice(export.Protocols),
CapacityBytes: copyInt64Pointer(export.CapacityBytes),
Tags: copyStringSlice(export.Tags),
OwnerID: ownerID,
}
keepPaths[export.Path] = struct{}{}
}
@ -278,12 +272,12 @@ func upsertExportsInState(state *storeState, nodeID string, exports []storageExp
return nodeExports, nil
}
func (s *memoryStore) recordHeartbeat(nodeID string, request nodeHeartbeatRequest) error {
func (s *memoryStore) recordHeartbeat(nodeID string, ownerID string, request nodeHeartbeatRequest) error {
s.mu.Lock()
defer s.mu.Unlock()
nextState := cloneStoreState(s.state)
if err := recordHeartbeatInState(&nextState, nodeID, request); err != nil {
if err := recordHeartbeatInState(&nextState, nodeID, ownerID, request); err != nil {
return err
}
if err := s.persistLocked(nextState); err != nil {
@ -294,9 +288,9 @@ func (s *memoryStore) recordHeartbeat(nodeID string, request nodeHeartbeatReques
return nil
}
func recordHeartbeatInState(state *storeState, nodeID string, request nodeHeartbeatRequest) error {
func recordHeartbeatInState(state *storeState, nodeID string, ownerID string, request nodeHeartbeatRequest) error {
node, ok := state.NodesByID[nodeID]
if !ok {
if !ok || node.OwnerID != ownerID {
return errNodeNotFound
}
@ -307,12 +301,15 @@ func recordHeartbeatInState(state *storeState, nodeID string, request nodeHeartb
return nil
}
func (s *memoryStore) listExports() []storageExport {
func (s *memoryStore) listExports(ownerID string) []storageExport {
s.mu.RLock()
defer s.mu.RUnlock()
exports := make([]storageExport, 0, len(s.state.ExportsByID))
for _, export := range s.state.ExportsByID {
if export.OwnerID != ownerID {
continue
}
exports = append(exports, copyStorageExport(export))
}
@ -323,17 +320,17 @@ func (s *memoryStore) listExports() []storageExport {
return exports
}
func (s *memoryStore) exportContext(exportID string) (exportContext, bool) {
func (s *memoryStore) exportContext(exportID string, ownerID string) (exportContext, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
export, ok := s.state.ExportsByID[exportID]
if !ok {
if !ok || export.OwnerID != ownerID {
return exportContext{}, false
}
node, ok := s.state.NodesByID[export.NasNodeID]
if !ok {
if !ok || node.OwnerID != ownerID {
return exportContext{}, false
}
@ -468,6 +465,7 @@ func copyNasNode(node nasNode) nasNode {
LastSeenAt: node.LastSeenAt,
DirectAddress: copyStringPointer(node.DirectAddress),
RelayAddress: copyStringPointer(node.RelayAddress),
OwnerID: node.OwnerID,
}
}
@ -481,6 +479,7 @@ func copyStorageExport(export storageExport) storageExport {
Protocols: copyStringSlice(export.Protocols),
CapacityBytes: copyInt64Pointer(export.CapacityBytes),
Tags: copyStringSlice(export.Tags),
OwnerID: export.OwnerID,
}
}

View file

@ -5,14 +5,12 @@ import "time"
// store defines the persistence interface for the control-plane.
type store interface {
// Node management
registerNode(request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error)
upsertExports(nodeID string, request nodeExportsRequest) ([]storageExport, error)
recordHeartbeat(nodeID string, request nodeHeartbeatRequest) error
listExports() []storageExport
exportContext(exportID string) (exportContext, bool)
registerNode(ownerID string, request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error)
upsertExports(nodeID string, ownerID string, request nodeExportsRequest) ([]storageExport, error)
recordHeartbeat(nodeID string, ownerID string, request nodeHeartbeatRequest) error
listExports(ownerID string) []storageExport
exportContext(exportID string, ownerID string) (exportContext, bool)
nodeByID(nodeID string) (nasNode, bool)
nodeAuthByMachineID(machineID string) (nodeAuthState, bool)
nodeAuthByID(nodeID string) (nodeAuthState, bool)
// User auth
createUser(username string, password string) (user, error)