Secure first-loop control-plane auth and mount routing.

Protect the control-plane API with explicit bearer auth, add node-scoped
registration/heartbeat credentials, and make export mount paths an explicit
contract field so mount profiles stay correct across runtimes.

Generated with [Devin](https://cli.devin.ai/docs)

Co-Authored-By: Devin <158243242+devin-ai-integration[bot]@users.noreply.github.com>
This commit is contained in:
Harivansh Rathi 2026-04-01 14:13:14 +00:00
parent a7f85f4871
commit ed40da7326
23 changed files with 3676 additions and 124 deletions

View file

@ -7,9 +7,16 @@ It is intentionally small for now:
- `GET /health`
- `GET /version`
- `POST /api/v1/nodes/register`
- `POST /api/v1/nodes/{nodeId}/heartbeat`
- `GET /api/v1/exports`
- `POST /api/v1/mount-profiles/issue`
- `POST /api/v1/cloud-profiles/issue`
The request and response shapes must follow the contracts in
[`packages/contracts`](../../packages/contracts).
`/api/v1/*` endpoints require bearer auth. New nodes register with
`BETTERNAS_CONTROL_PLANE_NODE_BOOTSTRAP_TOKEN`, client flows use
`BETTERNAS_CONTROL_PLANE_CLIENT_TOKEN`, and node registration returns an
`X-BetterNAS-Node-Token` header for subsequent node-scoped register and
heartbeat calls. Multi-export registrations should also send an explicit `mountPath` per export so mount profiles can stay stable across runtimes.

View file

@ -0,0 +1,176 @@
package main
import (
"errors"
"strings"
"time"
)
type appConfig struct {
version string
nextcloudBaseURL string
statePath string
clientToken string
nodeBootstrapToken string
}
type app struct {
startedAt time.Time
now func() time.Time
config appConfig
store *memoryStore
}
func newApp(config appConfig, startedAt time.Time) (*app, error) {
config.clientToken = strings.TrimSpace(config.clientToken)
if config.clientToken == "" {
return nil, errors.New("client token is required")
}
config.nodeBootstrapToken = strings.TrimSpace(config.nodeBootstrapToken)
if config.nodeBootstrapToken == "" {
return nil, errors.New("node bootstrap token is required")
}
store, err := newMemoryStore(config.statePath)
if err != nil {
return nil, err
}
return &app{
startedAt: startedAt,
now: time.Now,
config: config,
store: store,
}, nil
}
type nextcloudBackendStatus struct {
Configured bool `json:"configured"`
BaseURL string `json:"baseUrl"`
Provider string `json:"provider"`
}
type controlPlaneHealthResponse struct {
Service string `json:"service"`
Status string `json:"status"`
Timestamp string `json:"timestamp"`
UptimeSeconds int `json:"uptimeSeconds"`
Nextcloud nextcloudBackendStatus `json:"nextcloud"`
}
type controlPlaneVersionResponse struct {
Service string `json:"service"`
Version string `json:"version"`
APIVersion string `json:"apiVersion"`
}
type nasNode struct {
ID string `json:"id"`
MachineID string `json:"machineId"`
DisplayName string `json:"displayName"`
AgentVersion string `json:"agentVersion"`
Status string `json:"status"`
LastSeenAt string `json:"lastSeenAt"`
DirectAddress *string `json:"directAddress"`
RelayAddress *string `json:"relayAddress"`
}
type storageExport struct {
ID string `json:"id"`
NasNodeID string `json:"nasNodeId"`
Label string `json:"label"`
Path string `json:"path"`
MountPath string `json:"mountPath,omitempty"`
Protocols []string `json:"protocols"`
CapacityBytes *int64 `json:"capacityBytes"`
Tags []string `json:"tags"`
}
type mountProfile struct {
ID string `json:"id"`
ExportID string `json:"exportId"`
Protocol string `json:"protocol"`
DisplayName string `json:"displayName"`
MountURL string `json:"mountUrl"`
Readonly bool `json:"readonly"`
CredentialMode string `json:"credentialMode"`
}
type cloudProfile struct {
ID string `json:"id"`
ExportID string `json:"exportId"`
Provider string `json:"provider"`
BaseURL string `json:"baseUrl"`
Path string `json:"path"`
}
type storageExportInput struct {
Label string `json:"label"`
Path string `json:"path"`
MountPath string `json:"mountPath,omitempty"`
Protocols []string `json:"protocols"`
CapacityBytes *int64 `json:"capacityBytes"`
Tags []string `json:"tags"`
}
type nodeRegistrationRequest struct {
MachineID string `json:"machineId"`
DisplayName string `json:"displayName"`
AgentVersion string `json:"agentVersion"`
DirectAddress *string `json:"directAddress"`
RelayAddress *string `json:"relayAddress"`
Exports []storageExportInput `json:"exports"`
}
type nodeHeartbeatRequest struct {
NodeID string `json:"nodeId"`
Status string `json:"status"`
LastSeenAt string `json:"lastSeenAt"`
}
type mountProfileRequest struct {
UserID string `json:"userId"`
DeviceID string `json:"deviceId"`
ExportID string `json:"exportId"`
}
type cloudProfileRequest struct {
UserID string `json:"userId"`
ExportID string `json:"exportId"`
Provider string `json:"provider"`
}
type exportContext struct {
export storageExport
node nasNode
}
func copyStringPointer(value *string) *string {
if value == nil {
return nil
}
copied := *value
return &copied
}
func copyInt64Pointer(value *int64) *int64 {
if value == nil {
return nil
}
copied := *value
return &copied
}
func copyStringSlice(values []string) []string {
if len(values) == 0 {
return []string{}
}
copied := make([]string, len(values))
copy(copied, values)
return copied
}

View file

@ -0,0 +1,23 @@
package main
import (
"net/url"
"strings"
)
const (
defaultWebDAVPath = "/dav/"
nextcloudExportPagePrefix = "/apps/betternascontrolplane/exports/"
)
func mountProfilePathForExport(mountPath string) string {
if strings.TrimSpace(mountPath) == "" {
return defaultWebDAVPath
}
return mountPath
}
func cloudProfilePathForExport(exportID string) string {
return nextcloudExportPagePrefix + url.PathEscape(exportID)
}

View file

@ -1,82 +1,21 @@
package main
import (
"encoding/json"
"log"
"net/http"
"os"
"time"
)
type jsonObject map[string]any
func main() {
port := env("PORT", "8081")
startedAt := time.Now()
mux := http.NewServeMux()
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, jsonObject{
"service": "control-plane",
"status": "ok",
"timestamp": time.Now().UTC().Format(time.RFC3339),
"uptimeSeconds": int(time.Since(startedAt).Seconds()),
"nextcloud": jsonObject{
"configured": false,
"baseUrl": env("NEXTCLOUD_BASE_URL", ""),
"provider": "nextcloud",
},
})
})
mux.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, jsonObject{
"service": "control-plane",
"version": env("BETTERNAS_VERSION", "0.1.0-dev"),
"apiVersion": "v1",
})
})
mux.HandleFunc("/api/v1/exports", func(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, []jsonObject{})
})
mux.HandleFunc("/api/v1/mount-profiles/issue", func(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, jsonObject{
"id": "dev-profile",
"exportId": "dev-export",
"protocol": "webdav",
"displayName": "Example export",
"mountUrl": env("BETTERNAS_EXAMPLE_MOUNT_URL", "http://localhost:8090/dav/"),
"readonly": false,
"credentialMode": "session-token",
})
})
mux.HandleFunc("/api/v1/cloud-profiles/issue", func(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, jsonObject{
"id": "dev-cloud",
"exportId": "dev-export",
"provider": "nextcloud",
"baseUrl": env("NEXTCLOUD_BASE_URL", "http://localhost:8080"),
"path": "/apps/files/files",
})
})
mux.HandleFunc("/api/v1/nodes/register", func(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, jsonObject{
"id": "dev-node",
"machineId": "dev-machine",
"displayName": "Development NAS",
"agentVersion": "0.1.0-dev",
"status": "online",
"lastSeenAt": time.Now().UTC().Format(time.RFC3339),
"directAddress": env("BETTERNAS_NODE_DIRECT_ADDRESS", "http://localhost:8090"),
"relayAddress": nil,
})
})
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
})
app, err := newAppFromEnv(time.Now())
if err != nil {
log.Fatal(err)
}
server := &http.Server{
Addr: ":" + port,
Handler: mux,
Handler: app.handler(),
ReadHeaderTimeout: 5 * time.Second,
}
@ -84,20 +23,25 @@ func main() {
log.Fatal(server.ListenAndServe())
}
func env(key, fallback string) string {
value, ok := os.LookupEnv(key)
if !ok || value == "" {
return fallback
func newAppFromEnv(startedAt time.Time) (*app, error) {
clientToken, err := requiredEnv("BETTERNAS_CONTROL_PLANE_CLIENT_TOKEN")
if err != nil {
return nil, err
}
return value
}
func writeJSON(w http.ResponseWriter, statusCode int, payload any) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(statusCode)
if err := json.NewEncoder(w).Encode(payload); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
nodeBootstrapToken, err := requiredEnv("BETTERNAS_CONTROL_PLANE_NODE_BOOTSTRAP_TOKEN")
if err != nil {
return nil, err
}
return newApp(
appConfig{
version: env("BETTERNAS_VERSION", "0.1.0-dev"),
nextcloudBaseURL: env("NEXTCLOUD_BASE_URL", ""),
statePath: env("BETTERNAS_CONTROL_PLANE_STATE_PATH", ".state/control-plane/state.json"),
clientToken: clientToken,
nodeBootstrapToken: nodeBootstrapToken,
},
startedAt,
)
}

View file

@ -0,0 +1,591 @@
package main
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
"time"
)
var (
controlPlaneBinaryOnce sync.Once
controlPlaneBinaryPath string
controlPlaneBinaryErr error
nodeAgentBinaryOnce sync.Once
nodeAgentBinaryPath string
nodeAgentBinaryErr error
)
func TestControlPlaneBinaryMountLoopIntegration(t *testing.T) {
exportDir := t.TempDir()
writeExportFile(t, exportDir, "README.txt", "betterNAS export\n")
nextcloud := httptest.NewServer(http.NotFoundHandler())
defer nextcloud.Close()
nodeAgent := startNodeAgentBinary(t, exportDir)
controlPlane := startControlPlaneBinary(t, "runtime-test-version", nextcloud.URL)
client := &http.Client{Timeout: 2 * time.Second}
directAddress := nodeAgent.baseURL
registration := registerNode(t, client, controlPlane.baseURL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
MachineID: "machine-runtime-1",
DisplayName: "Runtime NAS",
AgentVersion: "1.2.3",
DirectAddress: &directAddress,
RelayAddress: nil,
Exports: []storageExportInput{{
Label: "Photos",
Path: exportDir,
MountPath: defaultWebDAVPath,
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"runtime"},
}},
})
if registration.Node.ID != "dev-node" {
t.Fatalf("expected node ID %q, got %q", "dev-node", registration.Node.ID)
}
if registration.NodeToken == "" {
t.Fatal("expected runtime registration to return a node token")
}
exports := getJSONAuth[[]storageExport](t, client, testClientToken, controlPlane.baseURL+"/api/v1/exports")
if len(exports) != 1 {
t.Fatalf("expected 1 export, got %d", len(exports))
}
if exports[0].ID != "dev-export" {
t.Fatalf("expected export ID %q, got %q", "dev-export", exports[0].ID)
}
if exports[0].Path != exportDir {
t.Fatalf("expected exported path %q, got %q", exportDir, exports[0].Path)
}
if exports[0].MountPath != defaultWebDAVPath {
t.Fatalf("expected mountPath %q, got %q", defaultWebDAVPath, exports[0].MountPath)
}
mount := postJSONAuth[mountProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{
UserID: "runtime-user",
DeviceID: "runtime-device",
ExportID: exports[0].ID,
})
if mount.MountURL != nodeAgent.baseURL+defaultWebDAVPath {
t.Fatalf("expected runtime mount URL %q, got %q", nodeAgent.baseURL+defaultWebDAVPath, mount.MountURL)
}
assertHTTPStatus(t, client, "PROPFIND", mount.MountURL, http.StatusMultiStatus)
assertMountedFileContents(t, client, mount.MountURL+"README.txt", "betterNAS export\n")
cloud := postJSONAuth[cloudProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: "runtime-user",
ExportID: exports[0].ID,
Provider: "nextcloud",
})
if cloud.BaseURL != nextcloud.URL {
t.Fatalf("expected runtime cloud baseUrl %q, got %q", nextcloud.URL, cloud.BaseURL)
}
expectedCloudPath := cloudProfilePathForExport(exports[0].ID)
if cloud.Path != expectedCloudPath {
t.Fatalf("expected runtime cloud path %q, got %q", expectedCloudPath, cloud.Path)
}
postJSONAuthStatus(t, client, registration.NodeToken, controlPlane.baseURL+"/api/v1/nodes/"+registration.Node.ID+"/heartbeat", nodeHeartbeatRequest{
NodeID: registration.Node.ID,
Status: "online",
LastSeenAt: "2025-01-02T03:04:05Z",
}, http.StatusNoContent)
}
func TestControlPlaneBinaryReRegistrationReconcilesExports(t *testing.T) {
nextcloud := httptest.NewServer(http.NotFoundHandler())
defer nextcloud.Close()
controlPlane := startControlPlaneBinary(t, "runtime-test-version", nextcloud.URL)
client := &http.Client{Timeout: 2 * time.Second}
directAddress := "http://nas.local:8090"
firstRegistration := registerNode(t, client, controlPlane.baseURL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
MachineID: "machine-runtime-2",
DisplayName: "Runtime NAS",
AgentVersion: "1.2.3",
DirectAddress: &directAddress,
RelayAddress: nil,
Exports: []storageExportInput{
{
Label: "Docs",
Path: "/srv/docs",
MountPath: "/dav/exports/docs/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"runtime"},
},
{
Label: "Media",
Path: "/srv/media",
MountPath: "/dav/exports/media/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"runtime"},
},
},
})
initialExports := exportsByPath(getJSONAuth[[]storageExport](t, client, testClientToken, controlPlane.baseURL+"/api/v1/exports"))
docsExport := initialExports["/srv/docs"]
if _, ok := initialExports["/srv/media"]; !ok {
t.Fatal("expected media export to be registered")
}
secondRegistration := registerNode(t, client, controlPlane.baseURL+"/api/v1/nodes/register", firstRegistration.NodeToken, nodeRegistrationRequest{
MachineID: "machine-runtime-2",
DisplayName: "Runtime NAS Updated",
AgentVersion: "1.2.4",
DirectAddress: &directAddress,
RelayAddress: nil,
Exports: []storageExportInput{
{
Label: "Docs v2",
Path: "/srv/docs",
MountPath: "/dav/exports/docs-v2/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"runtime", "updated"},
},
{
Label: "Backups",
Path: "/srv/backups",
MountPath: "/dav/exports/backups/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"runtime"},
},
},
})
if secondRegistration.Node.ID != firstRegistration.Node.ID {
t.Fatalf("expected node ID %q after re-registration, got %q", firstRegistration.Node.ID, secondRegistration.Node.ID)
}
updatedExports := exportsByPath(getJSONAuth[[]storageExport](t, client, testClientToken, controlPlane.baseURL+"/api/v1/exports"))
if len(updatedExports) != 2 {
t.Fatalf("expected 2 exports after re-registration, got %d", len(updatedExports))
}
if updatedExports["/srv/docs"].ID != docsExport.ID {
t.Fatalf("expected docs export to keep ID %q, got %q", docsExport.ID, updatedExports["/srv/docs"].ID)
}
if updatedExports["/srv/docs"].Label != "Docs v2" {
t.Fatalf("expected docs export label to update, got %q", updatedExports["/srv/docs"].Label)
}
if updatedExports["/srv/docs"].MountPath != "/dav/exports/docs-v2/" {
t.Fatalf("expected docs export mountPath to update, got %q", updatedExports["/srv/docs"].MountPath)
}
if _, ok := updatedExports["/srv/media"]; ok {
t.Fatal("expected stale media export to be removed")
}
if _, ok := updatedExports["/srv/backups"]; !ok {
t.Fatal("expected backups export to be present")
}
}
func TestControlPlaneBinaryMultiExportProfilesStayDistinct(t *testing.T) {
firstExportDir := t.TempDir()
secondExportDir := t.TempDir()
writeExportFile(t, firstExportDir, "README.txt", "first runtime export\n")
writeExportFile(t, secondExportDir, "README.txt", "second runtime export\n")
nextcloud := httptest.NewServer(http.NotFoundHandler())
defer nextcloud.Close()
nodeAgent := startNodeAgentBinaryWithExports(t, []string{firstExportDir, secondExportDir})
controlPlane := startControlPlaneBinary(t, "runtime-test-version", nextcloud.URL)
client := &http.Client{Timeout: 2 * time.Second}
firstMountPath := nodeAgentMountPathForExport(firstExportDir, 2)
secondMountPath := nodeAgentMountPathForExport(secondExportDir, 2)
directAddress := nodeAgent.baseURL
registerNode(t, client, controlPlane.baseURL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
MachineID: "machine-runtime-multi",
DisplayName: "Runtime Multi NAS",
AgentVersion: "1.2.3",
DirectAddress: &directAddress,
RelayAddress: nil,
Exports: []storageExportInput{
{
Label: "Docs",
Path: firstExportDir,
MountPath: firstMountPath,
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"runtime"},
},
{
Label: "Media",
Path: secondExportDir,
MountPath: secondMountPath,
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"runtime"},
},
},
})
exports := exportsByPath(getJSONAuth[[]storageExport](t, client, testClientToken, controlPlane.baseURL+"/api/v1/exports"))
firstExport := exports[firstExportDir]
secondExport := exports[secondExportDir]
firstMount := postJSONAuth[mountProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{
UserID: "runtime-user",
DeviceID: "runtime-device",
ExportID: firstExport.ID,
})
secondMount := postJSONAuth[mountProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{
UserID: "runtime-user",
DeviceID: "runtime-device",
ExportID: secondExport.ID,
})
if firstMount.MountURL == secondMount.MountURL {
t.Fatalf("expected distinct runtime mount URLs, got %q", firstMount.MountURL)
}
if firstMount.MountURL != nodeAgent.baseURL+firstMountPath {
t.Fatalf("expected first runtime mount URL %q, got %q", nodeAgent.baseURL+firstMountPath, firstMount.MountURL)
}
if secondMount.MountURL != nodeAgent.baseURL+secondMountPath {
t.Fatalf("expected second runtime mount URL %q, got %q", nodeAgent.baseURL+secondMountPath, secondMount.MountURL)
}
assertHTTPStatus(t, client, "PROPFIND", firstMount.MountURL, http.StatusMultiStatus)
assertHTTPStatus(t, client, "PROPFIND", secondMount.MountURL, http.StatusMultiStatus)
assertMountedFileContents(t, client, firstMount.MountURL+"README.txt", "first runtime export\n")
assertMountedFileContents(t, client, secondMount.MountURL+"README.txt", "second runtime export\n")
firstCloud := postJSONAuth[cloudProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: "runtime-user",
ExportID: firstExport.ID,
Provider: "nextcloud",
})
secondCloud := postJSONAuth[cloudProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: "runtime-user",
ExportID: secondExport.ID,
Provider: "nextcloud",
})
if firstCloud.Path == secondCloud.Path {
t.Fatalf("expected distinct runtime cloud paths, got %q", firstCloud.Path)
}
if firstCloud.Path != cloudProfilePathForExport(firstExport.ID) {
t.Fatalf("expected first runtime cloud path %q, got %q", cloudProfilePathForExport(firstExport.ID), firstCloud.Path)
}
if secondCloud.Path != cloudProfilePathForExport(secondExport.ID) {
t.Fatalf("expected second runtime cloud path %q, got %q", cloudProfilePathForExport(secondExport.ID), secondCloud.Path)
}
}
type runningBinary struct {
baseURL string
logPath string
}
func startControlPlaneBinary(t *testing.T, version string, nextcloudBaseURL string) runningBinary {
t.Helper()
port := reserveTCPPort(t)
logPath := filepath.Join(t.TempDir(), "control-plane.log")
statePath := filepath.Join(t.TempDir(), "control-plane-state.json")
logFile, err := os.Create(logPath)
if err != nil {
t.Fatalf("create control-plane log file: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
cmd := exec.CommandContext(ctx, buildControlPlaneBinary(t))
cmd.Env = append(
os.Environ(),
"PORT="+port,
"BETTERNAS_VERSION="+version,
"NEXTCLOUD_BASE_URL="+nextcloudBaseURL,
"BETTERNAS_CONTROL_PLANE_STATE_PATH="+statePath,
"BETTERNAS_CONTROL_PLANE_CLIENT_TOKEN="+testClientToken,
"BETTERNAS_CONTROL_PLANE_NODE_BOOTSTRAP_TOKEN="+testNodeBootstrapToken,
)
cmd.Stdout = logFile
cmd.Stderr = logFile
if err := cmd.Start(); err != nil {
_ = logFile.Close()
t.Fatalf("start control-plane binary: %v", err)
}
waitDone := make(chan error, 1)
go func() {
waitDone <- cmd.Wait()
}()
baseURL := fmt.Sprintf("http://127.0.0.1:%s", port)
waitForHTTPStatus(t, baseURL+"/health", waitDone, logPath, http.StatusOK)
registerProcessCleanup(t, ctx, cancel, cmd, waitDone, logFile, logPath, "control-plane")
return runningBinary{
baseURL: baseURL,
logPath: logPath,
}
}
func startNodeAgentBinary(t *testing.T, exportPath string) runningBinary {
return startNodeAgentBinaryWithExports(t, []string{exportPath})
}
func startNodeAgentBinaryWithExports(t *testing.T, exportPaths []string) runningBinary {
t.Helper()
port := reserveTCPPort(t)
logPath := filepath.Join(t.TempDir(), "node-agent.log")
logFile, err := os.Create(logPath)
if err != nil {
t.Fatalf("create node-agent log file: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
cmd := exec.CommandContext(ctx, buildNodeAgentBinary(t))
rawExportPaths, err := json.Marshal(exportPaths)
if err != nil {
_ = logFile.Close()
t.Fatalf("marshal export paths: %v", err)
}
cmd.Env = append(
os.Environ(),
"PORT="+port,
"BETTERNAS_EXPORT_PATHS_JSON="+string(rawExportPaths),
)
cmd.Stdout = logFile
cmd.Stderr = logFile
if err := cmd.Start(); err != nil {
_ = logFile.Close()
t.Fatalf("start node-agent binary: %v", err)
}
waitDone := make(chan error, 1)
go func() {
waitDone <- cmd.Wait()
}()
baseURL := fmt.Sprintf("http://127.0.0.1:%s", port)
waitForHTTPStatus(t, baseURL+"/health", waitDone, logPath, http.StatusOK)
registerProcessCleanup(t, ctx, cancel, cmd, waitDone, logFile, logPath, "node-agent")
return runningBinary{
baseURL: baseURL,
logPath: logPath,
}
}
func buildControlPlaneBinary(t *testing.T) string {
t.Helper()
controlPlaneBinaryOnce.Do(func() {
_, filename, _, ok := runtime.Caller(0)
if !ok {
controlPlaneBinaryErr = errors.New("locate control-plane package directory")
return
}
tempDir, err := os.MkdirTemp("", "betternas-control-plane-*")
if err != nil {
controlPlaneBinaryErr = fmt.Errorf("create build temp dir: %w", err)
return
}
controlPlaneBinaryPath = filepath.Join(tempDir, "control-plane")
cmd := exec.Command("go", "build", "-o", controlPlaneBinaryPath, ".")
cmd.Dir = filepath.Dir(filename)
output, err := cmd.CombinedOutput()
if err != nil {
controlPlaneBinaryErr = fmt.Errorf("build control-plane binary: %w\n%s", err, output)
}
})
if controlPlaneBinaryErr != nil {
t.Fatal(controlPlaneBinaryErr)
}
return controlPlaneBinaryPath
}
func buildNodeAgentBinary(t *testing.T) string {
t.Helper()
nodeAgentBinaryOnce.Do(func() {
_, filename, _, ok := runtime.Caller(0)
if !ok {
nodeAgentBinaryErr = errors.New("locate control-plane package directory")
return
}
tempDir, err := os.MkdirTemp("", "betternas-node-agent-*")
if err != nil {
nodeAgentBinaryErr = fmt.Errorf("create build temp dir: %w", err)
return
}
nodeAgentBinaryPath = filepath.Join(tempDir, "node-agent")
cmd := exec.Command("go", "build", "-o", nodeAgentBinaryPath, "./cmd/node-agent")
cmd.Dir = filepath.Clean(filepath.Join(filepath.Dir(filename), "../../../node-agent"))
output, err := cmd.CombinedOutput()
if err != nil {
nodeAgentBinaryErr = fmt.Errorf("build node-agent binary: %w\n%s", err, output)
}
})
if nodeAgentBinaryErr != nil {
t.Fatal(nodeAgentBinaryErr)
}
return nodeAgentBinaryPath
}
func reserveTCPPort(t *testing.T) string {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("reserve tcp port: %v", err)
}
defer listener.Close()
_, port, err := net.SplitHostPort(listener.Addr().String())
if err != nil {
t.Fatalf("split host port: %v", err)
}
return port
}
func waitForHTTPStatus(t *testing.T, endpoint string, waitDone <-chan error, logPath string, expectedStatus int) {
t.Helper()
deadline := time.Now().Add(10 * time.Second)
client := &http.Client{Timeout: 500 * time.Millisecond}
for time.Now().Before(deadline) {
select {
case err := <-waitDone:
logOutput, _ := os.ReadFile(logPath)
t.Fatalf("process exited before %s returned %d: %v\n%s", endpoint, expectedStatus, err, logOutput)
default:
}
response, err := client.Get(endpoint)
if err == nil {
_ = response.Body.Close()
if response.StatusCode == expectedStatus {
return
}
}
time.Sleep(100 * time.Millisecond)
}
logOutput, _ := os.ReadFile(logPath)
t.Fatalf("endpoint %s did not return %d in time\n%s", endpoint, expectedStatus, logOutput)
}
func registerProcessCleanup(t *testing.T, ctx context.Context, cancel context.CancelFunc, cmd *exec.Cmd, waitDone <-chan error, logFile *os.File, logPath string, processName string) {
t.Helper()
t.Cleanup(func() {
cancel()
defer func() {
_ = logFile.Close()
if t.Failed() {
if logOutput, err := os.ReadFile(logPath); err == nil {
t.Logf("%s logs:\n%s", processName, logOutput)
}
}
}()
select {
case err := <-waitDone:
if err != nil && ctx.Err() == nil {
t.Fatalf("%s exited unexpectedly: %v", processName, err)
}
case <-time.After(5 * time.Second):
if killErr := cmd.Process.Kill(); killErr != nil {
t.Fatalf("kill %s: %v", processName, killErr)
}
if err := <-waitDone; err != nil && ctx.Err() == nil {
t.Fatalf("%s exited unexpectedly after kill: %v", processName, err)
}
}
})
}
func assertMountedFileContents(t *testing.T, client *http.Client, endpoint string, expected string) {
t.Helper()
response, err := client.Get(endpoint)
if err != nil {
t.Fatalf("get %s: %v", endpoint, err)
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
t.Fatalf("get %s: expected status 200, got %d", endpoint, response.StatusCode)
}
body, err := io.ReadAll(response.Body)
if err != nil {
t.Fatalf("read %s response: %v", endpoint, err)
}
if string(body) != expected {
t.Fatalf("expected %s body %q, got %q", endpoint, expected, string(body))
}
}
func assertHTTPStatus(t *testing.T, client *http.Client, method string, endpoint string, expectedStatus int) {
t.Helper()
request, err := http.NewRequest(method, endpoint, nil)
if err != nil {
t.Fatalf("build %s request for %s: %v", method, endpoint, err)
}
response, err := client.Do(request)
if err != nil {
t.Fatalf("%s %s: %v", method, endpoint, err)
}
defer response.Body.Close()
if response.StatusCode != expectedStatus {
t.Fatalf("%s %s: expected status %d, got %d", method, endpoint, expectedStatus, response.StatusCode)
}
}
func writeExportFile(t *testing.T, directory string, name string, contents string) {
t.Helper()
if err := os.WriteFile(filepath.Join(directory, name), []byte(contents), 0o644); err != nil {
t.Fatalf("write export file %s: %v", name, err)
}
}
func nodeAgentMountPathForExport(exportPath string, exportCount int) string {
if exportCount <= 1 {
return defaultWebDAVPath
}
sum := sha256.Sum256([]byte(strings.TrimSpace(exportPath)))
return "/dav/exports/" + hex.EncodeToString(sum[:]) + "/"
}

View file

@ -0,0 +1,902 @@
package main
import (
"bytes"
"crypto/subtle"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
"strings"
"time"
)
var (
errCloudProfileUnavailable = errors.New("nextcloud base URL is not configured")
errExportNotFound = errors.New("export not found")
errMountTargetUnavailable = errors.New("mount target is not available")
errNodeIDMismatch = errors.New("node id path and body must match")
errNodeNotFound = errors.New("node not found")
)
const (
authorizationHeader = "Authorization"
controlPlaneNodeTokenKey = "X-BetterNAS-Node-Token"
bearerScheme = "Bearer"
)
func (a *app) handler() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("GET /health", a.handleHealth)
mux.HandleFunc("GET /version", a.handleVersion)
mux.HandleFunc("POST /api/v1/nodes/register", a.handleNodeRegister)
mux.HandleFunc("POST /api/v1/nodes/{nodeId}/heartbeat", a.handleNodeHeartbeat)
mux.HandleFunc("GET /api/v1/exports", a.handleExportsList)
mux.HandleFunc("POST /api/v1/mount-profiles/issue", a.handleMountProfileIssue)
mux.HandleFunc("POST /api/v1/cloud-profiles/issue", a.handleCloudProfileIssue)
return mux
}
func (a *app) handleHealth(w http.ResponseWriter, _ *http.Request) {
now := a.now().UTC()
writeJSON(w, http.StatusOK, controlPlaneHealthResponse{
Service: "control-plane",
Status: "ok",
Timestamp: now.Format(time.RFC3339),
UptimeSeconds: int(now.Sub(a.startedAt).Seconds()),
Nextcloud: nextcloudBackendStatus{
Configured: hasConfiguredNextcloudBaseURL(a.config.nextcloudBaseURL),
BaseURL: a.config.nextcloudBaseURL,
Provider: "nextcloud",
},
})
}
func (a *app) handleVersion(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, controlPlaneVersionResponse{
Service: "control-plane",
Version: a.config.version,
APIVersion: "v1",
})
}
func (a *app) handleNodeRegister(w http.ResponseWriter, r *http.Request) {
request, err := decodeNodeRegistrationRequest(w, r)
if err != nil {
writeDecodeError(w, err)
return
}
if err := validateNodeRegistrationRequest(&request); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if !a.authorizeNodeRegistration(w, r, request.MachineID) {
return
}
result, err := a.store.registerNode(request, a.now())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if result.IssuedNodeToken != "" {
w.Header().Set(controlPlaneNodeTokenKey, result.IssuedNodeToken)
}
writeJSON(w, http.StatusOK, result.Node)
}
func (a *app) handleNodeHeartbeat(w http.ResponseWriter, r *http.Request) {
nodeID := r.PathValue("nodeId")
var request nodeHeartbeatRequest
if err := decodeJSON(w, r, &request); err != nil {
writeDecodeError(w, err)
return
}
if err := validateNodeHeartbeatRequest(nodeID, request); err != nil {
statusCode := http.StatusBadRequest
if errors.Is(err, errNodeNotFound) {
statusCode = http.StatusNotFound
}
http.Error(w, err.Error(), statusCode)
return
}
if !a.authorizeNode(w, r, nodeID) {
return
}
if err := a.store.recordHeartbeat(nodeID, request); err != nil {
statusCode := http.StatusInternalServerError
if errors.Is(err, errNodeNotFound) {
statusCode = http.StatusNotFound
}
http.Error(w, err.Error(), statusCode)
return
}
w.WriteHeader(http.StatusNoContent)
}
func (a *app) handleExportsList(w http.ResponseWriter, r *http.Request) {
if !a.requireClientAuth(w, r) {
return
}
writeJSON(w, http.StatusOK, a.store.listExports())
}
func (a *app) handleMountProfileIssue(w http.ResponseWriter, r *http.Request) {
if !a.requireClientAuth(w, r) {
return
}
var request mountProfileRequest
if err := decodeJSON(w, r, &request); err != nil {
writeDecodeError(w, err)
return
}
if err := validateMountProfileRequest(request); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
context, ok := a.store.exportContext(request.ExportID)
if !ok {
http.Error(w, errExportNotFound.Error(), http.StatusNotFound)
return
}
mountURL, err := buildMountURL(context)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
writeJSON(w, http.StatusOK, mountProfile{
ID: fmt.Sprintf("mount-%s-%s", request.DeviceID, context.export.ID),
ExportID: context.export.ID,
Protocol: "webdav",
DisplayName: context.export.Label,
MountURL: mountURL,
Readonly: false,
CredentialMode: "session-token",
})
}
func (a *app) handleCloudProfileIssue(w http.ResponseWriter, r *http.Request) {
if !a.requireClientAuth(w, r) {
return
}
var request cloudProfileRequest
if err := decodeJSON(w, r, &request); err != nil {
writeDecodeError(w, err)
return
}
if err := validateCloudProfileRequest(request); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
context, ok := a.store.exportContext(request.ExportID)
if !ok {
http.Error(w, errExportNotFound.Error(), http.StatusNotFound)
return
}
baseURL, err := buildCloudProfileBaseURL(a.config.nextcloudBaseURL)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
writeJSON(w, http.StatusOK, cloudProfile{
ID: fmt.Sprintf("cloud-%s-%s", request.UserID, context.export.ID),
ExportID: context.export.ID,
Provider: "nextcloud",
BaseURL: baseURL,
Path: buildCloudProfilePath(context.export.ID),
})
}
type rawObject map[string]json.RawMessage
const maxRequestBodyBytes = 1 << 20
func decodeNodeRegistrationRequest(w http.ResponseWriter, r *http.Request) (nodeRegistrationRequest, error) {
object, err := decodeRawObjectRequest(w, r)
if err != nil {
return nodeRegistrationRequest{}, err
}
if err := object.validateRequiredKeys(
"machineId",
"displayName",
"agentVersion",
"directAddress",
"relayAddress",
"exports",
); err != nil {
return nodeRegistrationRequest{}, err
}
request := nodeRegistrationRequest{}
request.MachineID, err = object.stringField("machineId")
if err != nil {
return nodeRegistrationRequest{}, err
}
request.DisplayName, err = object.stringField("displayName")
if err != nil {
return nodeRegistrationRequest{}, err
}
request.AgentVersion, err = object.stringField("agentVersion")
if err != nil {
return nodeRegistrationRequest{}, err
}
request.DirectAddress, err = object.nullableStringField("directAddress")
if err != nil {
return nodeRegistrationRequest{}, err
}
request.RelayAddress, err = object.nullableStringField("relayAddress")
if err != nil {
return nodeRegistrationRequest{}, err
}
request.Exports, err = object.storageExportInputsField("exports")
if err != nil {
return nodeRegistrationRequest{}, err
}
return request, nil
}
func decodeRawObjectRequest(w http.ResponseWriter, r *http.Request) (rawObject, error) {
var object rawObject
if err := decodeJSON(w, r, &object); err != nil {
return nil, err
}
if object == nil {
return nil, errors.New("request body must be a JSON object")
}
return object, nil
}
func decodeStorageExportInput(data json.RawMessage) (storageExportInput, error) {
object, err := decodeRawObject(data)
if err != nil {
return storageExportInput{}, err
}
if err := object.validateRequiredKeys(
"label",
"path",
"protocols",
"capacityBytes",
"tags",
); err != nil {
return storageExportInput{}, err
}
input := storageExportInput{}
input.Label, err = object.stringField("label")
if err != nil {
return storageExportInput{}, err
}
input.Path, err = object.stringField("path")
if err != nil {
return storageExportInput{}, err
}
input.MountPath, err = object.optionalStringField("mountPath")
if err != nil {
return storageExportInput{}, err
}
input.Protocols, err = object.stringSliceField("protocols")
if err != nil {
return storageExportInput{}, err
}
input.CapacityBytes, err = object.nullableInt64Field("capacityBytes")
if err != nil {
return storageExportInput{}, err
}
input.Tags, err = object.stringSliceField("tags")
if err != nil {
return storageExportInput{}, err
}
return input, nil
}
func decodeRawObject(data json.RawMessage) (rawObject, error) {
var object rawObject
if err := json.Unmarshal(data, &object); err != nil {
return nil, errors.New("must be a JSON object")
}
if object == nil {
return nil, errors.New("must be a JSON object")
}
return object, nil
}
func (o rawObject) validateRequiredKeys(fieldNames ...string) error {
for _, fieldName := range fieldNames {
if _, ok := o[fieldName]; !ok {
return fmt.Errorf("%s is required", fieldName)
}
}
return nil
}
func (o rawObject) rawField(name string) (json.RawMessage, error) {
raw, ok := o[name]
if !ok {
return nil, fmt.Errorf("%s is required", name)
}
return raw, nil
}
func (o rawObject) stringField(name string) (string, error) {
raw, err := o.rawField(name)
if err != nil {
return "", err
}
var value string
if err := json.Unmarshal(raw, &value); err != nil {
return "", fmt.Errorf("%s must be a string", name)
}
return value, nil
}
func (o rawObject) nullableStringField(name string) (*string, error) {
raw, err := o.rawField(name)
if err != nil {
return nil, err
}
if isJSONNull(raw) {
return nil, nil
}
var value string
if err := json.Unmarshal(raw, &value); err != nil {
return nil, fmt.Errorf("%s must be a string or null", name)
}
return &value, nil
}
func (o rawObject) optionalStringField(name string) (string, error) {
raw, ok := o[name]
if !ok || isJSONNull(raw) {
return "", nil
}
var value string
if err := json.Unmarshal(raw, &value); err != nil {
return "", fmt.Errorf("%s must be a string", name)
}
return value, nil
}
func (o rawObject) stringSliceField(name string) ([]string, error) {
raw, err := o.rawField(name)
if err != nil {
return nil, err
}
if isJSONNull(raw) {
return nil, fmt.Errorf("%s must be an array of strings", name)
}
var values []string
if err := json.Unmarshal(raw, &values); err != nil {
return nil, fmt.Errorf("%s must be an array of strings", name)
}
return values, nil
}
func (o rawObject) nullableInt64Field(name string) (*int64, error) {
raw, err := o.rawField(name)
if err != nil {
return nil, err
}
if isJSONNull(raw) {
return nil, nil
}
var value int64
if err := json.Unmarshal(raw, &value); err != nil {
return nil, fmt.Errorf("%s must be an integer or null", name)
}
return &value, nil
}
func (o rawObject) storageExportInputsField(name string) ([]storageExportInput, error) {
raw, err := o.rawField(name)
if err != nil {
return nil, err
}
if isJSONNull(raw) {
return nil, fmt.Errorf("%s must be an array", name)
}
var rawExports []json.RawMessage
if err := json.Unmarshal(raw, &rawExports); err != nil {
return nil, fmt.Errorf("%s must be an array", name)
}
exports := make([]storageExportInput, len(rawExports))
for index, rawExport := range rawExports {
export, err := decodeStorageExportInput(rawExport)
if err != nil {
return nil, fmt.Errorf("%s[%d].%w", name, index, err)
}
exports[index] = export
}
return exports, nil
}
func isJSONNull(raw json.RawMessage) bool {
return bytes.Equal(bytes.TrimSpace(raw), []byte("null"))
}
func validateNodeRegistrationRequest(request *nodeRegistrationRequest) error {
request.MachineID = strings.TrimSpace(request.MachineID)
if request.MachineID == "" {
return errors.New("machineId is required")
}
request.DisplayName = strings.TrimSpace(request.DisplayName)
if request.DisplayName == "" {
return errors.New("displayName is required")
}
request.AgentVersion = strings.TrimSpace(request.AgentVersion)
if request.AgentVersion == "" {
return errors.New("agentVersion is required")
}
var err error
request.DirectAddress, err = normalizeOptionalAbsoluteHTTPURL("directAddress", request.DirectAddress)
if err != nil {
return err
}
request.RelayAddress, err = normalizeOptionalAbsoluteHTTPURL("relayAddress", request.RelayAddress)
if err != nil {
return err
}
seenPaths := make(map[string]struct{}, len(request.Exports))
seenMountPaths := make(map[string]struct{}, len(request.Exports))
for index := range request.Exports {
export := &request.Exports[index]
export.Label = strings.TrimSpace(export.Label)
if export.Label == "" {
return fmt.Errorf("exports[%d].label is required", index)
}
export.Path = strings.TrimSpace(export.Path)
if export.Path == "" {
return fmt.Errorf("exports[%d].path is required", index)
}
if _, ok := seenPaths[export.Path]; ok {
return fmt.Errorf("exports[%d].path must be unique", index)
}
seenPaths[export.Path] = struct{}{}
export.MountPath = strings.TrimSpace(export.MountPath)
if len(request.Exports) > 1 && export.MountPath == "" {
return fmt.Errorf("exports[%d].mountPath is required when registering multiple exports", index)
}
if export.MountPath != "" {
normalizedMountPath, err := normalizeAbsoluteURLPath(export.MountPath)
if err != nil {
return fmt.Errorf("exports[%d].mountPath %w", index, err)
}
export.MountPath = normalizedMountPath
if _, ok := seenMountPaths[export.MountPath]; ok {
return fmt.Errorf("exports[%d].mountPath must be unique", index)
}
seenMountPaths[export.MountPath] = struct{}{}
}
if len(export.Protocols) == 0 {
return fmt.Errorf("exports[%d].protocols must not be empty", index)
}
for protocolIndex, protocol := range export.Protocols {
if protocol != "webdav" {
return fmt.Errorf("exports[%d].protocols[%d] must be webdav", index, protocolIndex)
}
}
if export.CapacityBytes != nil && *export.CapacityBytes < 0 {
return fmt.Errorf("exports[%d].capacityBytes must be greater than or equal to 0", index)
}
}
return nil
}
func validateNodeHeartbeatRequest(nodeID string, request nodeHeartbeatRequest) error {
if strings.TrimSpace(nodeID) == "" {
return errNodeNotFound
}
if strings.TrimSpace(request.NodeID) == "" {
return errors.New("nodeId is required")
}
if request.NodeID != nodeID {
return errNodeIDMismatch
}
if request.Status != "online" && request.Status != "offline" && request.Status != "degraded" {
return errors.New("status must be one of online, offline, or degraded")
}
if _, err := time.Parse(time.RFC3339, request.LastSeenAt); err != nil {
return errors.New("lastSeenAt must be a valid RFC3339 timestamp")
}
return nil
}
func validateMountProfileRequest(request mountProfileRequest) error {
if strings.TrimSpace(request.UserID) == "" {
return errors.New("userId is required")
}
if strings.TrimSpace(request.DeviceID) == "" {
return errors.New("deviceId is required")
}
if strings.TrimSpace(request.ExportID) == "" {
return errors.New("exportId is required")
}
return nil
}
func validateCloudProfileRequest(request cloudProfileRequest) error {
if strings.TrimSpace(request.UserID) == "" {
return errors.New("userId is required")
}
if strings.TrimSpace(request.ExportID) == "" {
return errors.New("exportId is required")
}
if request.Provider != "nextcloud" {
return errors.New("provider must be nextcloud")
}
return nil
}
func normalizeAbsoluteURLPath(raw string) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("must be an absolute URL path")
}
parsed, err := url.Parse(trimmed)
if err != nil {
return "", errors.New("must be an absolute URL path")
}
if parsed.Scheme != "" || parsed.Opaque != "" || parsed.Host != "" || parsed.User != nil || parsed.RawQuery != "" || parsed.Fragment != "" {
return "", errors.New("must be an absolute URL path")
}
if !strings.HasPrefix(parsed.Path, "/") {
return "", errors.New("must be an absolute URL path")
}
normalized := path.Clean(parsed.Path)
if !strings.HasPrefix(normalized, "/") {
return "", errors.New("must be an absolute URL path")
}
if !strings.HasSuffix(normalized, "/") {
normalized += "/"
}
return normalized, nil
}
func normalizeOptionalAbsoluteHTTPURL(fieldName string, value *string) (*string, error) {
if value == nil {
return nil, nil
}
normalized, err := normalizeAbsoluteHTTPURL(*value)
if err != nil {
return nil, fmt.Errorf("%s %w", fieldName, err)
}
return &normalized, nil
}
func hasConfiguredNextcloudBaseURL(baseURL string) bool {
if strings.TrimSpace(baseURL) == "" {
return false
}
_, err := normalizeAbsoluteHTTPURL(baseURL)
return err == nil
}
func buildMountURL(context exportContext) (string, error) {
address, ok := firstAddress(context.node.DirectAddress, context.node.RelayAddress)
if !ok {
return "", errMountTargetUnavailable
}
mountURL, err := buildAbsoluteHTTPURLWithPath(address, mountProfilePathForExport(context.export.MountPath))
if err != nil {
return "", errMountTargetUnavailable
}
return mountURL, nil
}
func buildCloudProfileBaseURL(baseURL string) (string, error) {
if strings.TrimSpace(baseURL) == "" {
return "", errCloudProfileUnavailable
}
normalized, err := normalizeAbsoluteHTTPURL(baseURL)
if err != nil {
return "", errCloudProfileUnavailable
}
return normalized, nil
}
func buildCloudProfilePath(exportID string) string {
return cloudProfilePathForExport(exportID)
}
func firstAddress(addresses ...*string) (string, bool) {
for _, address := range addresses {
if address == nil {
continue
}
normalized, err := normalizeAbsoluteHTTPURL(*address)
if err == nil {
return normalized, true
}
}
return "", false
}
func buildAbsoluteHTTPURLWithPath(baseAddress string, absolutePath string) (string, error) {
parsedBaseAddress, err := parseAbsoluteHTTPURL(baseAddress)
if err != nil {
return "", err
}
normalizedPath, err := joinAbsoluteURLPaths(parsedBaseAddress.Path, absolutePath)
if err != nil {
return "", err
}
parsedBaseAddress.Path = normalizedPath
parsedBaseAddress.RawPath = ""
return parsedBaseAddress.String(), nil
}
func joinAbsoluteURLPaths(basePath string, suffixPath string) (string, error) {
if strings.TrimSpace(basePath) == "" {
basePath = "/"
}
normalizedBasePath, err := normalizeAbsoluteURLPath(basePath)
if err != nil {
return "", err
}
normalizedSuffixPath, err := normalizeAbsoluteURLPath(suffixPath)
if err != nil {
return "", err
}
return normalizeAbsoluteURLPath(
path.Join(normalizedBasePath, strings.TrimPrefix(normalizedSuffixPath, "/")),
)
}
func normalizeAbsoluteHTTPURL(raw string) (string, error) {
parsed, err := parseAbsoluteHTTPURL(raw)
if err != nil {
return "", err
}
return parsed.String(), nil
}
func parseAbsoluteHTTPURL(raw string) (*url.URL, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return nil, errors.New("must be null or an absolute http(s) URL")
}
parsed, err := url.Parse(trimmed)
if err != nil {
return nil, errors.New("must be null or an absolute http(s) URL")
}
if parsed.Opaque != "" || parsed.Host == "" || (parsed.Scheme != "http" && parsed.Scheme != "https") {
return nil, errors.New("must be null or an absolute http(s) URL")
}
if parsed.User != nil || parsed.RawQuery != "" || parsed.Fragment != "" {
return nil, errors.New("must be null or an absolute http(s) URL without user info, query, or fragment")
}
return parsed, nil
}
func env(key, fallback string) string {
value, ok := os.LookupEnv(key)
if !ok || value == "" {
return fallback
}
return value
}
func requiredEnv(key string) (string, error) {
value, ok := os.LookupEnv(key)
if !ok || strings.TrimSpace(value) == "" {
return "", fmt.Errorf("%s is required", key)
}
return value, nil
}
func decodeJSON(w http.ResponseWriter, r *http.Request, destination any) error {
defer r.Body.Close()
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes)
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(destination); err != nil {
return err
}
var extraValue struct{}
if err := decoder.Decode(&extraValue); err != io.EOF {
return errors.New("request body must contain a single JSON object")
}
return nil
}
func writeDecodeError(w http.ResponseWriter, err error) {
var maxBytesErr *http.MaxBytesError
statusCode := http.StatusBadRequest
if errors.As(err, &maxBytesErr) {
statusCode = http.StatusRequestEntityTooLarge
}
http.Error(w, err.Error(), statusCode)
}
func writeJSON(w http.ResponseWriter, statusCode int, payload any) {
var buffer bytes.Buffer
encoder := json.NewEncoder(&buffer)
if err := encoder.Encode(payload); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(statusCode)
if _, err := w.Write(buffer.Bytes()); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func (a *app) requireClientAuth(w http.ResponseWriter, r *http.Request) bool {
presentedToken, ok := bearerToken(r)
if !ok || !secureStringEquals(a.config.clientToken, presentedToken) {
writeUnauthorized(w)
return false
}
return true
}
func (a *app) authorizeNodeRegistration(w http.ResponseWriter, r *http.Request, machineID string) bool {
presentedToken, ok := bearerToken(r)
if !ok {
writeUnauthorized(w)
return false
}
authState, exists := a.store.nodeAuthByMachineID(machineID)
if !exists || strings.TrimSpace(authState.TokenHash) == "" {
if !secureStringEquals(a.config.nodeBootstrapToken, presentedToken) {
writeUnauthorized(w)
return false
}
return true
}
if !tokenHashMatches(authState.TokenHash, presentedToken) {
writeUnauthorized(w)
return false
}
return true
}
func (a *app) authorizeNode(w http.ResponseWriter, r *http.Request, nodeID string) bool {
presentedToken, ok := bearerToken(r)
if !ok {
writeUnauthorized(w)
return false
}
authState, exists := a.store.nodeAuthByID(nodeID)
if !exists {
http.Error(w, errNodeNotFound.Error(), http.StatusNotFound)
return false
}
if strings.TrimSpace(authState.TokenHash) == "" || !tokenHashMatches(authState.TokenHash, presentedToken) {
writeUnauthorized(w)
return false
}
return true
}
func bearerToken(r *http.Request) (string, bool) {
authorization := strings.TrimSpace(r.Header.Get(authorizationHeader))
if authorization == "" {
return "", false
}
scheme, token, ok := strings.Cut(authorization, " ")
if !ok || !strings.EqualFold(strings.TrimSpace(scheme), bearerScheme) {
return "", false
}
token = strings.TrimSpace(token)
if token == "" {
return "", false
}
return token, true
}
func writeUnauthorized(w http.ResponseWriter) {
w.Header().Set("WWW-Authenticate", bearerScheme)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
func secureStringEquals(expected string, actual string) bool {
return subtle.ConstantTimeCompare([]byte(expected), []byte(actual)) == 1
}
func tokenHashMatches(expectedHash string, token string) bool {
return secureStringEquals(expectedHash, hashOpaqueToken(token))
}

View file

@ -0,0 +1,910 @@
package main
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
)
var testControlPlaneNow = time.Date(2025, time.January, 1, 12, 0, 0, 0, time.UTC)
const (
testClientToken = "test-client-token"
testNodeBootstrapToken = "test-node-bootstrap-token"
)
type registeredNode struct {
Node nasNode
NodeToken string
}
func TestControlPlaneHealthAndVersion(t *testing.T) {
t.Parallel()
_, server := newTestControlPlaneServer(t, appConfig{
version: "test-version",
nextcloudBaseURL: "http://nextcloud.test",
})
defer server.Close()
health := getJSON[controlPlaneHealthResponse](t, server.Client(), server.URL+"/health")
if health.Service != "control-plane" {
t.Fatalf("expected service control-plane, got %q", health.Service)
}
if health.Status != "ok" {
t.Fatalf("expected status ok, got %q", health.Status)
}
if health.Timestamp != testControlPlaneNow.Format(time.RFC3339) {
t.Fatalf("expected timestamp %q, got %q", testControlPlaneNow.Format(time.RFC3339), health.Timestamp)
}
if health.UptimeSeconds != 0 {
t.Fatalf("expected uptimeSeconds 0, got %d", health.UptimeSeconds)
}
if !health.Nextcloud.Configured {
t.Fatal("expected nextcloud.configured to be true")
}
if health.Nextcloud.BaseURL != "http://nextcloud.test" {
t.Fatalf("expected baseUrl http://nextcloud.test, got %q", health.Nextcloud.BaseURL)
}
if health.Nextcloud.Provider != "nextcloud" {
t.Fatalf("expected provider nextcloud, got %q", health.Nextcloud.Provider)
}
version := getJSON[controlPlaneVersionResponse](t, server.Client(), server.URL+"/version")
if version.Service != "control-plane" {
t.Fatalf("expected version service control-plane, got %q", version.Service)
}
if version.Version != "test-version" {
t.Fatalf("expected version test-version, got %q", version.Version)
}
if version.APIVersion != "v1" {
t.Fatalf("expected apiVersion v1, got %q", version.APIVersion)
}
getStatusWithAuth(t, server.Client(), "", server.URL+"/api/v1/exports", http.StatusUnauthorized)
exports := getJSONAuth[[]storageExport](t, server.Client(), testClientToken, server.URL+"/api/v1/exports")
if len(exports) != 0 {
t.Fatalf("expected no exports before registration, got %d", len(exports))
}
}
func TestControlPlaneRegistrationProfilesAndHeartbeat(t *testing.T) {
t.Parallel()
app, server := newTestControlPlaneServer(t, appConfig{
version: "test-version",
nextcloudBaseURL: "http://nextcloud.test",
})
defer server.Close()
directAddress := "http://nas.local:8090"
relayAddress := "http://nas.internal:8090"
registration := registerNode(t, server.Client(), server.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
MachineID: "machine-1",
DisplayName: "Primary NAS",
AgentVersion: "1.2.3",
DirectAddress: &directAddress,
RelayAddress: &relayAddress,
Exports: []storageExportInput{{
Label: "Photos",
Path: "/srv/photos",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"family"},
}},
})
if registration.NodeToken == "" {
t.Fatal("expected node registration to return a node token")
}
node := registration.Node
if node.ID != "dev-node" {
t.Fatalf("expected first node ID %q, got %q", "dev-node", node.ID)
}
if node.Status != "online" {
t.Fatalf("expected registered node to be online, got %q", node.Status)
}
if node.LastSeenAt != testControlPlaneNow.Format(time.RFC3339) {
t.Fatalf("expected lastSeenAt %q, got %q", testControlPlaneNow.Format(time.RFC3339), node.LastSeenAt)
}
if node.DirectAddress == nil || *node.DirectAddress != directAddress {
t.Fatalf("expected directAddress %q, got %#v", directAddress, node.DirectAddress)
}
if node.RelayAddress == nil || *node.RelayAddress != relayAddress {
t.Fatalf("expected relayAddress %q, got %#v", relayAddress, node.RelayAddress)
}
exports := getJSONAuth[[]storageExport](t, server.Client(), testClientToken, server.URL+"/api/v1/exports")
if len(exports) != 1 {
t.Fatalf("expected 1 export, got %d", len(exports))
}
if exports[0].ID != "dev-export" {
t.Fatalf("expected first export ID %q, got %q", "dev-export", exports[0].ID)
}
if exports[0].NasNodeID != node.ID {
t.Fatalf("expected export to belong to %q, got %q", node.ID, exports[0].NasNodeID)
}
if exports[0].Label != "Photos" {
t.Fatalf("expected export label Photos, got %q", exports[0].Label)
}
if exports[0].Path != "/srv/photos" {
t.Fatalf("expected export path %q, got %q", "/srv/photos", exports[0].Path)
}
if exports[0].MountPath != "" {
t.Fatalf("expected empty mountPath for default export, got %q", exports[0].MountPath)
}
mount := postJSONAuth[mountProfile](t, server.Client(), testClientToken, server.URL+"/api/v1/mount-profiles/issue", mountProfileRequest{
UserID: "user-1",
DeviceID: "device-1",
ExportID: exports[0].ID,
})
if mount.ExportID != exports[0].ID {
t.Fatalf("expected mount profile exportId %q, got %q", exports[0].ID, mount.ExportID)
}
if mount.Protocol != "webdav" {
t.Fatalf("expected mount protocol webdav, got %q", mount.Protocol)
}
if mount.DisplayName != "Photos" {
t.Fatalf("expected mount display name Photos, got %q", mount.DisplayName)
}
if mount.MountURL != "http://nas.local:8090/dav/" {
t.Fatalf("expected mount URL %q, got %q", "http://nas.local:8090/dav/", mount.MountURL)
}
if mount.Readonly {
t.Fatal("expected mount profile to be read-write")
}
if mount.CredentialMode != "session-token" {
t.Fatalf("expected credentialMode session-token, got %q", mount.CredentialMode)
}
cloud := postJSONAuth[cloudProfile](t, server.Client(), testClientToken, server.URL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: "user-1",
ExportID: exports[0].ID,
Provider: "nextcloud",
})
if cloud.ExportID != exports[0].ID {
t.Fatalf("expected cloud profile exportId %q, got %q", exports[0].ID, cloud.ExportID)
}
if cloud.Provider != "nextcloud" {
t.Fatalf("expected provider nextcloud, got %q", cloud.Provider)
}
if cloud.BaseURL != "http://nextcloud.test" {
t.Fatalf("expected baseUrl http://nextcloud.test, got %q", cloud.BaseURL)
}
expectedCloudPath := cloudProfilePathForExport(exports[0].ID)
if cloud.Path != expectedCloudPath {
t.Fatalf("expected cloud profile path %q, got %q", expectedCloudPath, cloud.Path)
}
postJSONAuthStatus(t, server.Client(), registration.NodeToken, server.URL+"/api/v1/nodes/"+node.ID+"/heartbeat", nodeHeartbeatRequest{
NodeID: node.ID,
Status: "degraded",
LastSeenAt: "2025-01-02T03:04:05Z",
}, http.StatusNoContent)
updatedNode, ok := app.store.nodeByID(node.ID)
if !ok {
t.Fatalf("expected node %q to exist after heartbeat", node.ID)
}
if updatedNode.Status != "degraded" {
t.Fatalf("expected heartbeat to update status to degraded, got %q", updatedNode.Status)
}
if updatedNode.LastSeenAt != "2025-01-02T03:04:05Z" {
t.Fatalf("expected heartbeat to update lastSeenAt, got %q", updatedNode.LastSeenAt)
}
}
func TestControlPlaneReRegistrationReconcilesExportsAndKeepsStableIDs(t *testing.T) {
t.Parallel()
app, server := newTestControlPlaneServer(t, appConfig{version: "test-version"})
defer server.Close()
directAddress := "http://nas.local:8090"
firstRegistration := registerNode(t, server.Client(), server.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
MachineID: "machine-1",
DisplayName: "Primary NAS",
AgentVersion: "1.2.3",
DirectAddress: &directAddress,
RelayAddress: nil,
Exports: []storageExportInput{
{
Label: "Docs",
Path: "/srv/docs",
MountPath: "/dav/exports/docs/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"work"},
},
{
Label: "Media",
Path: "/srv/media",
MountPath: "/dav/exports/media/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"personal"},
},
},
})
postJSONAuthStatus(t, server.Client(), testNodeBootstrapToken, server.URL+"/api/v1/nodes/register", nodeRegistrationRequest{
MachineID: "machine-1",
DisplayName: "Unauthorized Re-register",
AgentVersion: "1.2.3",
DirectAddress: &directAddress,
RelayAddress: nil,
Exports: []storageExportInput{{
Label: "Docs",
Path: "/srv/docs",
MountPath: "/dav/exports/docs/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"work"},
}},
}, http.StatusUnauthorized)
initialExports := exportsByPath(getJSONAuth[[]storageExport](t, server.Client(), testClientToken, server.URL+"/api/v1/exports"))
docsExport := initialExports["/srv/docs"]
mediaExport := initialExports["/srv/media"]
secondRegistration := registerNode(t, server.Client(), server.URL+"/api/v1/nodes/register", firstRegistration.NodeToken, nodeRegistrationRequest{
MachineID: "machine-1",
DisplayName: "Primary NAS Updated",
AgentVersion: "1.2.4",
DirectAddress: &directAddress,
RelayAddress: nil,
Exports: []storageExportInput{
{
Label: "Docs v2",
Path: "/srv/docs",
MountPath: "/dav/exports/docs-v2/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"work", "updated"},
},
{
Label: "Backups",
Path: "/srv/backups",
MountPath: "/dav/exports/backups/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"system"},
},
},
})
if secondRegistration.Node.ID != firstRegistration.Node.ID {
t.Fatalf("expected re-registration to keep node ID %q, got %q", firstRegistration.Node.ID, secondRegistration.Node.ID)
}
if secondRegistration.NodeToken != "" {
t.Fatalf("expected re-registration to keep the existing node token, got %q", secondRegistration.NodeToken)
}
updatedExports := exportsByPath(getJSONAuth[[]storageExport](t, server.Client(), testClientToken, server.URL+"/api/v1/exports"))
if len(updatedExports) != 2 {
t.Fatalf("expected 2 exports after re-registration, got %d", len(updatedExports))
}
if updatedExports["/srv/docs"].ID != docsExport.ID {
t.Fatalf("expected docs export to keep ID %q, got %q", docsExport.ID, updatedExports["/srv/docs"].ID)
}
if updatedExports["/srv/docs"].Label != "Docs v2" {
t.Fatalf("expected docs export label to update, got %q", updatedExports["/srv/docs"].Label)
}
if updatedExports["/srv/docs"].MountPath != "/dav/exports/docs-v2/" {
t.Fatalf("expected docs export mountPath to update, got %q", updatedExports["/srv/docs"].MountPath)
}
if _, ok := updatedExports["/srv/media"]; ok {
t.Fatalf("expected stale media export %q to be removed", mediaExport.ID)
}
if updatedExports["/srv/backups"].ID == docsExport.ID {
t.Fatal("expected backups export to get a distinct ID")
}
storedNode, ok := app.store.nodeByID(firstRegistration.Node.ID)
if !ok {
t.Fatalf("expected node %q to exist after re-registration", firstRegistration.Node.ID)
}
if storedNode.DisplayName != "Primary NAS Updated" {
t.Fatalf("expected updated display name, got %q", storedNode.DisplayName)
}
if storedNode.AgentVersion != "1.2.4" {
t.Fatalf("expected updated agent version, got %q", storedNode.AgentVersion)
}
}
func TestControlPlaneProfilesRemainExportSpecificForConfiguredMountPaths(t *testing.T) {
t.Parallel()
_, server := newTestControlPlaneServer(t, appConfig{
version: "test-version",
nextcloudBaseURL: "http://nextcloud.test",
})
defer server.Close()
directAddress := "http://nas.local:8090"
registerNode(t, server.Client(), server.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
MachineID: "machine-multi",
DisplayName: "Multi Export NAS",
AgentVersion: "1.2.3",
DirectAddress: &directAddress,
RelayAddress: nil,
Exports: []storageExportInput{
{
Label: "Docs",
Path: "/srv/docs",
MountPath: "/dav/exports/docs/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"work"},
},
{
Label: "Media",
Path: "/srv/media",
MountPath: "/dav/exports/media/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"personal"},
},
},
})
exports := exportsByPath(getJSONAuth[[]storageExport](t, server.Client(), testClientToken, server.URL+"/api/v1/exports"))
docsExport := exports["/srv/docs"]
mediaExport := exports["/srv/media"]
docsMount := postJSONAuth[mountProfile](t, server.Client(), testClientToken, server.URL+"/api/v1/mount-profiles/issue", mountProfileRequest{
UserID: "user-1",
DeviceID: "device-1",
ExportID: docsExport.ID,
})
mediaMount := postJSONAuth[mountProfile](t, server.Client(), testClientToken, server.URL+"/api/v1/mount-profiles/issue", mountProfileRequest{
UserID: "user-1",
DeviceID: "device-1",
ExportID: mediaExport.ID,
})
if docsMount.MountURL == mediaMount.MountURL {
t.Fatalf("expected distinct mount URLs for configured export paths, got %q", docsMount.MountURL)
}
if docsMount.MountURL != "http://nas.local:8090/dav/exports/docs/" {
t.Fatalf("expected docs mount URL %q, got %q", "http://nas.local:8090/dav/exports/docs/", docsMount.MountURL)
}
if mediaMount.MountURL != "http://nas.local:8090/dav/exports/media/" {
t.Fatalf("expected media mount URL %q, got %q", "http://nas.local:8090/dav/exports/media/", mediaMount.MountURL)
}
docsCloud := postJSONAuth[cloudProfile](t, server.Client(), testClientToken, server.URL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: "user-1",
ExportID: docsExport.ID,
Provider: "nextcloud",
})
mediaCloud := postJSONAuth[cloudProfile](t, server.Client(), testClientToken, server.URL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: "user-1",
ExportID: mediaExport.ID,
Provider: "nextcloud",
})
if docsCloud.Path == mediaCloud.Path {
t.Fatalf("expected distinct cloud profile paths for multi-export node, got %q", docsCloud.Path)
}
if docsCloud.Path != cloudProfilePathForExport(docsExport.ID) {
t.Fatalf("expected docs cloud path %q, got %q", cloudProfilePathForExport(docsExport.ID), docsCloud.Path)
}
if mediaCloud.Path != cloudProfilePathForExport(mediaExport.ID) {
t.Fatalf("expected media cloud path %q, got %q", cloudProfilePathForExport(mediaExport.ID), mediaCloud.Path)
}
}
func TestControlPlaneMountProfilesUseRelayAndPreserveBasePath(t *testing.T) {
t.Parallel()
_, server := newTestControlPlaneServer(t, appConfig{version: "test-version"})
defer server.Close()
relayAddress := "https://nas.example.test/control"
registerNode(t, server.Client(), server.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
MachineID: "machine-relay",
DisplayName: "Relay NAS",
AgentVersion: "1.2.3",
DirectAddress: nil,
RelayAddress: &relayAddress,
Exports: []storageExportInput{{
Label: "Relay",
Path: "/srv/relay",
MountPath: "/dav/relay/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{},
}},
})
mount := postJSONAuth[mountProfile](t, server.Client(), testClientToken, server.URL+"/api/v1/mount-profiles/issue", mountProfileRequest{
UserID: "user-1",
DeviceID: "device-1",
ExportID: "dev-export",
})
if mount.MountURL != "https://nas.example.test/control/dav/relay/" {
t.Fatalf("expected relay mount URL %q, got %q", "https://nas.example.test/control/dav/relay/", mount.MountURL)
}
registerNode(t, server.Client(), server.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
MachineID: "machine-no-target",
DisplayName: "No Target NAS",
AgentVersion: "1.2.3",
DirectAddress: nil,
RelayAddress: nil,
Exports: []storageExportInput{{
Label: "Offline",
Path: "/srv/offline",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{},
}},
})
postJSONAuthStatus(t, server.Client(), testClientToken, server.URL+"/api/v1/mount-profiles/issue", mountProfileRequest{
UserID: "user-1",
DeviceID: "device-2",
ExportID: "dev-export-2",
}, http.StatusServiceUnavailable)
}
func TestControlPlaneCloudProfilesRequireConfiguredBaseURLAndExistingExport(t *testing.T) {
t.Parallel()
_, server := newTestControlPlaneServer(t, appConfig{version: "test-version"})
defer server.Close()
directAddress := "http://nas.local:8090"
registerNode(t, server.Client(), server.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
MachineID: "machine-cloud",
DisplayName: "Cloud NAS",
AgentVersion: "1.2.3",
DirectAddress: &directAddress,
RelayAddress: nil,
Exports: []storageExportInput{{
Label: "Photos",
Path: "/srv/photos",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{},
}},
})
postJSONAuthStatus(t, server.Client(), testClientToken, server.URL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: "user-1",
ExportID: "dev-export",
Provider: "nextcloud",
}, http.StatusServiceUnavailable)
_, serverWithNextcloud := newTestControlPlaneServer(t, appConfig{
version: "test-version",
nextcloudBaseURL: "http://nextcloud.test",
})
defer serverWithNextcloud.Close()
postJSONAuthStatus(t, serverWithNextcloud.Client(), testClientToken, serverWithNextcloud.URL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: "user-1",
ExportID: "missing-export",
Provider: "nextcloud",
}, http.StatusNotFound)
}
func TestControlPlanePersistsRegistryAcrossAppRestart(t *testing.T) {
t.Parallel()
statePath := filepath.Join(t.TempDir(), "control-plane-state.json")
directAddress := "http://nas.local:8090"
_, firstServer := newTestControlPlaneServer(t, appConfig{
version: "test-version",
statePath: statePath,
})
registration := registerNode(t, firstServer.Client(), firstServer.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
MachineID: "machine-persisted",
DisplayName: "Persisted NAS",
AgentVersion: "1.2.3",
DirectAddress: &directAddress,
RelayAddress: nil,
Exports: []storageExportInput{{
Label: "Docs",
Path: "/srv/docs",
MountPath: "/dav/persisted/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"work"},
}},
})
firstServer.Close()
_, secondServer := newTestControlPlaneServer(t, appConfig{
version: "test-version",
statePath: statePath,
})
defer secondServer.Close()
exports := getJSONAuth[[]storageExport](t, secondServer.Client(), testClientToken, secondServer.URL+"/api/v1/exports")
if len(exports) != 1 {
t.Fatalf("expected persisted export after restart, got %d", len(exports))
}
if exports[0].ID != "dev-export" {
t.Fatalf("expected persisted export ID %q, got %q", "dev-export", exports[0].ID)
}
if exports[0].MountPath != "/dav/persisted/" {
t.Fatalf("expected persisted mountPath %q, got %q", "/dav/persisted/", exports[0].MountPath)
}
mount := postJSONAuth[mountProfile](t, secondServer.Client(), testClientToken, secondServer.URL+"/api/v1/mount-profiles/issue", mountProfileRequest{
UserID: "user-1",
DeviceID: "device-1",
ExportID: exports[0].ID,
})
if mount.MountURL != "http://nas.local:8090/dav/persisted/" {
t.Fatalf("expected persisted mount URL %q, got %q", "http://nas.local:8090/dav/persisted/", mount.MountURL)
}
reRegistration := registerNode(t, secondServer.Client(), secondServer.URL+"/api/v1/nodes/register", registration.NodeToken, nodeRegistrationRequest{
MachineID: "machine-persisted",
DisplayName: "Persisted NAS Updated",
AgentVersion: "1.2.4",
DirectAddress: &directAddress,
RelayAddress: nil,
Exports: []storageExportInput{{
Label: "Docs Updated",
Path: "/srv/docs",
MountPath: "/dav/persisted/",
Protocols: []string{"webdav"},
CapacityBytes: nil,
Tags: []string{"work"},
}},
})
if reRegistration.Node.ID != registration.Node.ID {
t.Fatalf("expected persisted node ID %q, got %q", registration.Node.ID, reRegistration.Node.ID)
}
}
func TestControlPlaneRejectsInvalidRequestsAndEnforcesAuth(t *testing.T) {
t.Parallel()
_, server := newTestControlPlaneServer(t, appConfig{version: "test-version"})
defer server.Close()
postRawJSONStatus(t, server.Client(), server.URL+"/api/v1/nodes/register", `{
"machineId":"machine-1",
"displayName":"Primary NAS",
"agentVersion":"1.2.3",
"directAddress":"http://nas.local:8090",
"relayAddress":null,
"exports":[{"label":"Docs","path":"/srv/docs","protocols":["webdav"],"capacityBytes":null,"tags":[]}]
}`, http.StatusUnauthorized)
postRawJSONAuthStatus(t, server.Client(), testNodeBootstrapToken, server.URL+"/api/v1/nodes/register", `{
"machineId":"machine-1",
"displayName":"Primary NAS",
"agentVersion":"1.2.3",
"relayAddress":null,
"exports":[{"label":"Docs","path":"/srv/docs","protocols":["webdav"],"capacityBytes":null,"tags":[]}]
}`, http.StatusBadRequest)
postRawJSONAuthStatus(t, server.Client(), testNodeBootstrapToken, server.URL+"/api/v1/nodes/register", `{
"machineId":"machine-1",
"displayName":"Primary NAS",
"agentVersion":"1.2.3",
"directAddress":"nas.local:8090",
"relayAddress":null,
"exports":[{"label":"Docs","path":"/srv/docs","protocols":["webdav"],"capacityBytes":null,"tags":[]}]
}`, http.StatusBadRequest)
postRawJSONAuthStatus(t, server.Client(), testNodeBootstrapToken, server.URL+"/api/v1/nodes/register", `{
"machineId":"machine-1",
"displayName":"Primary NAS",
"agentVersion":"1.2.3",
"directAddress":"http://nas.local:8090",
"relayAddress":null,
"exports":[
{"label":"Docs","path":"/srv/docs","mountPath":"/dav/docs/","protocols":["webdav"],"capacityBytes":null,"tags":[]},
{"label":"Docs Duplicate","path":"/srv/docs-2","mountPath":"/dav/docs/","protocols":["webdav"],"capacityBytes":null,"tags":[]}
]
}`, http.StatusBadRequest)
postRawJSONAuthStatus(t, server.Client(), testNodeBootstrapToken, server.URL+"/api/v1/nodes/register", `{
"machineId":"machine-1",
"displayName":"Primary NAS",
"agentVersion":"1.2.3",
"directAddress":"http://nas.local:8090",
"relayAddress":null,
"exports":[
{"label":"Docs","path":"/srv/docs","mountPath":"/dav/docs/","protocols":["webdav"],"capacityBytes":null,"tags":[]},
{"label":"Media","path":"/srv/media","protocols":["webdav"],"capacityBytes":null,"tags":[]}
]
}`, http.StatusBadRequest)
response := postRawJSONAuth(t, server.Client(), testNodeBootstrapToken, server.URL+"/api/v1/nodes/register", `{
"machineId":"machine-1",
"displayName":"Primary NAS",
"agentVersion":"1.2.3",
"directAddress":"http://nas.local:8090",
"relayAddress":null,
"ignoredTopLevel":"ok",
"exports":[{"label":"Docs","path":"/srv/docs","mountPath":"/dav/docs/","protocols":["webdav"],"capacityBytes":null,"tags":[],"ignoredNested":"ok"}]
}`)
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
t.Fatalf("post %s: expected status 200, got %d: %s", server.URL+"/api/v1/nodes/register", response.StatusCode, body)
}
var node nasNode
if err := json.NewDecoder(response.Body).Decode(&node); err != nil {
t.Fatalf("decode registration response: %v", err)
}
nodeToken := strings.TrimSpace(response.Header.Get(controlPlaneNodeTokenKey))
if nodeToken == "" {
t.Fatal("expected node registration to return a node token")
}
if node.ID != "dev-node" {
t.Fatalf("expected node ID %q, got %q", "dev-node", node.ID)
}
postJSONAuthStatus(t, server.Client(), testClientToken, server.URL+"/api/v1/nodes/"+node.ID+"/heartbeat", nodeHeartbeatRequest{
NodeID: node.ID,
Status: "online",
LastSeenAt: "2025-01-02T03:04:05Z",
}, http.StatusUnauthorized)
postJSONAuthStatus(t, server.Client(), nodeToken, server.URL+"/api/v1/nodes/"+node.ID+"/heartbeat", nodeHeartbeatRequest{
NodeID: "node-other",
Status: "online",
LastSeenAt: "2025-01-02T03:04:05Z",
}, http.StatusBadRequest)
postJSONAuthStatus(t, server.Client(), nodeToken, server.URL+"/api/v1/nodes/"+node.ID+"/heartbeat", nodeHeartbeatRequest{
NodeID: node.ID,
Status: "broken",
LastSeenAt: "2025-01-02T03:04:05Z",
}, http.StatusBadRequest)
postJSONAuthStatus(t, server.Client(), nodeToken, server.URL+"/api/v1/nodes/"+node.ID+"/heartbeat", nodeHeartbeatRequest{
NodeID: node.ID,
Status: "online",
LastSeenAt: "not-a-timestamp",
}, http.StatusBadRequest)
postJSONAuthStatus(t, server.Client(), nodeToken, server.URL+"/api/v1/nodes/missing-node/heartbeat", nodeHeartbeatRequest{
NodeID: "missing-node",
Status: "online",
LastSeenAt: "2025-01-02T03:04:05Z",
}, http.StatusNotFound)
getStatusWithAuth(t, server.Client(), "", server.URL+"/api/v1/exports", http.StatusUnauthorized)
getStatusWithAuth(t, server.Client(), "wrong-client-token", server.URL+"/api/v1/exports", http.StatusUnauthorized)
postJSONAuthStatus(t, server.Client(), testClientToken, server.URL+"/api/v1/mount-profiles/issue", mountProfileRequest{
UserID: "user-1",
DeviceID: "device-1",
ExportID: "missing-export",
}, http.StatusNotFound)
postJSONAuthStatus(t, server.Client(), testClientToken, server.URL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
UserID: "user-1",
ExportID: "missing-export",
Provider: "nextcloud",
}, http.StatusNotFound)
}
func newTestControlPlaneServer(t *testing.T, config appConfig) (*app, *httptest.Server) {
t.Helper()
if config.version == "" {
config.version = "test-version"
}
if config.clientToken == "" {
config.clientToken = testClientToken
}
if config.nodeBootstrapToken == "" {
config.nodeBootstrapToken = testNodeBootstrapToken
}
app, err := newApp(config, testControlPlaneNow)
if err != nil {
t.Fatalf("new app: %v", err)
}
app.now = func() time.Time {
return testControlPlaneNow
}
server := httptest.NewServer(app.handler())
return app, server
}
func exportsByPath(exports []storageExport) map[string]storageExport {
byPath := make(map[string]storageExport, len(exports))
for _, export := range exports {
byPath[export.Path] = export
}
return byPath
}
func registerNode(t *testing.T, client *http.Client, endpoint string, token string, payload nodeRegistrationRequest) registeredNode {
t.Helper()
response := postJSONAuthResponse(t, client, token, endpoint, payload)
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(response.Body)
t.Fatalf("post %s: expected status 200, got %d: %s", endpoint, response.StatusCode, responseBody)
}
var node nasNode
if err := json.NewDecoder(response.Body).Decode(&node); err != nil {
t.Fatalf("decode %s response: %v", endpoint, err)
}
return registeredNode{
Node: node,
NodeToken: strings.TrimSpace(response.Header.Get(controlPlaneNodeTokenKey)),
}
}
func getJSON[T any](t *testing.T, client *http.Client, endpoint string) T {
t.Helper()
response := doRequest(t, client, http.MethodGet, endpoint, nil, nil)
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
t.Fatalf("get %s: expected status 200, got %d: %s", endpoint, response.StatusCode, body)
}
var payload T
if err := json.NewDecoder(response.Body).Decode(&payload); err != nil {
t.Fatalf("decode %s response: %v", endpoint, err)
}
return payload
}
func getJSONAuth[T any](t *testing.T, client *http.Client, token string, endpoint string) T {
t.Helper()
response := doRequest(t, client, http.MethodGet, endpoint, nil, authHeaders(token))
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
t.Fatalf("get %s: expected status 200, got %d: %s", endpoint, response.StatusCode, body)
}
var payload T
if err := json.NewDecoder(response.Body).Decode(&payload); err != nil {
t.Fatalf("decode %s response: %v", endpoint, err)
}
return payload
}
func getStatusWithAuth(t *testing.T, client *http.Client, token string, endpoint string, expectedStatus int) {
t.Helper()
response := doRequest(t, client, http.MethodGet, endpoint, nil, authHeaders(token))
defer response.Body.Close()
if response.StatusCode != expectedStatus {
body, _ := io.ReadAll(response.Body)
t.Fatalf("get %s: expected status %d, got %d: %s", endpoint, expectedStatus, response.StatusCode, body)
}
}
func postJSONAuth[T any](t *testing.T, client *http.Client, token string, endpoint string, payload any) T {
t.Helper()
response := postJSONAuthResponse(t, client, token, endpoint, payload)
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(response.Body)
t.Fatalf("post %s: expected status 200, got %d: %s", endpoint, response.StatusCode, responseBody)
}
var decoded T
if err := json.NewDecoder(response.Body).Decode(&decoded); err != nil {
t.Fatalf("decode %s response: %v", endpoint, err)
}
return decoded
}
func postJSONAuthStatus(t *testing.T, client *http.Client, token string, endpoint string, payload any, expectedStatus int) {
t.Helper()
response := postJSONAuthResponse(t, client, token, endpoint, payload)
defer response.Body.Close()
if response.StatusCode != expectedStatus {
body, _ := io.ReadAll(response.Body)
t.Fatalf("post %s: expected status %d, got %d: %s", endpoint, expectedStatus, response.StatusCode, body)
}
}
func postJSONAuthResponse(t *testing.T, client *http.Client, token string, endpoint string, payload any) *http.Response {
t.Helper()
body, err := json.Marshal(payload)
if err != nil {
t.Fatalf("marshal payload for %s: %v", endpoint, err)
}
return doRequest(t, client, http.MethodPost, endpoint, bytes.NewReader(body), authHeaders(token))
}
func postRawJSONAuthStatus(t *testing.T, client *http.Client, token string, endpoint string, raw string, expectedStatus int) {
t.Helper()
response := postRawJSONAuth(t, client, token, endpoint, raw)
defer response.Body.Close()
if response.StatusCode != expectedStatus {
body, _ := io.ReadAll(response.Body)
t.Fatalf("post %s: expected status %d, got %d: %s", endpoint, expectedStatus, response.StatusCode, body)
}
}
func postRawJSONStatus(t *testing.T, client *http.Client, endpoint string, raw string, expectedStatus int) {
t.Helper()
response := doRequest(t, client, http.MethodPost, endpoint, strings.NewReader(raw), nil)
defer response.Body.Close()
if response.StatusCode != expectedStatus {
body, _ := io.ReadAll(response.Body)
t.Fatalf("post %s: expected status %d, got %d: %s", endpoint, expectedStatus, response.StatusCode, body)
}
}
func postRawJSONAuth(t *testing.T, client *http.Client, token string, endpoint string, raw string) *http.Response {
t.Helper()
return doRequest(t, client, http.MethodPost, endpoint, strings.NewReader(raw), authHeaders(token))
}
func doRequest(t *testing.T, client *http.Client, method string, endpoint string, body io.Reader, headers map[string]string) *http.Response {
t.Helper()
request, err := http.NewRequest(method, endpoint, body)
if err != nil {
t.Fatalf("build %s request for %s: %v", method, endpoint, err)
}
if body != nil {
request.Header.Set("Content-Type", "application/json")
}
for key, value := range headers {
request.Header.Set(key, value)
}
response, err := client.Do(request)
if err != nil {
t.Fatalf("%s %s: %v", method, endpoint, err)
}
return response
}
func authHeaders(token string) map[string]string {
if strings.TrimSpace(token) == "" {
return nil
}
return map[string]string{
authorizationHeader: bearerScheme + " " + token,
}
}

View file

@ -0,0 +1,465 @@
package main
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"sync"
"time"
)
type storeState struct {
NextNodeOrdinal int `json:"nextNodeOrdinal"`
NextExportOrdinal int `json:"nextExportOrdinal"`
NodeIDByMachineID map[string]string `json:"nodeIdByMachineId"`
NodesByID map[string]nasNode `json:"nodesById"`
NodeTokenHashByID map[string]string `json:"nodeTokenHashById"`
ExportIDsByNodePath map[string]map[string]string `json:"exportIdsByNodePath"`
ExportsByID map[string]storageExport `json:"exportsById"`
}
type memoryStore struct {
mu sync.RWMutex
statePath string
state storeState
}
type nodeRegistrationResult struct {
Node nasNode
IssuedNodeToken string
}
type nodeAuthState struct {
NodeID string
TokenHash string
}
func newMemoryStore(statePath string) (*memoryStore, error) {
store := &memoryStore{
statePath: statePath,
state: newDefaultStoreState(),
}
if statePath == "" {
return store, nil
}
loadedState, err := loadStoreState(statePath)
if err != nil {
return nil, err
}
store.state = loadedState
return store, nil
}
func newDefaultStoreState() storeState {
return storeState{
NextNodeOrdinal: 1,
NextExportOrdinal: 1,
NodeIDByMachineID: make(map[string]string),
NodesByID: make(map[string]nasNode),
NodeTokenHashByID: make(map[string]string),
ExportIDsByNodePath: make(map[string]map[string]string),
ExportsByID: make(map[string]storageExport),
}
}
func loadStoreState(statePath string) (storeState, error) {
data, err := os.ReadFile(statePath)
if err != nil {
if os.IsNotExist(err) {
return newDefaultStoreState(), nil
}
return storeState{}, fmt.Errorf("read control-plane state %s: %w", statePath, err)
}
var state storeState
if err := json.Unmarshal(data, &state); err != nil {
return storeState{}, fmt.Errorf("decode control-plane state %s: %w", statePath, err)
}
return normalizeStoreState(state), nil
}
func normalizeStoreState(state storeState) storeState {
if state.NextNodeOrdinal < 1 {
state.NextNodeOrdinal = len(state.NodesByID) + 1
}
if state.NextExportOrdinal < 1 {
state.NextExportOrdinal = len(state.ExportsByID) + 1
}
if state.NodeIDByMachineID == nil {
state.NodeIDByMachineID = make(map[string]string)
}
if state.NodesByID == nil {
state.NodesByID = make(map[string]nasNode)
}
if state.NodeTokenHashByID == nil {
state.NodeTokenHashByID = make(map[string]string)
}
if state.ExportIDsByNodePath == nil {
state.ExportIDsByNodePath = make(map[string]map[string]string)
}
if state.ExportsByID == nil {
state.ExportsByID = make(map[string]storageExport)
}
return cloneStoreState(state)
}
func cloneStoreState(state storeState) storeState {
cloned := storeState{
NextNodeOrdinal: state.NextNodeOrdinal,
NextExportOrdinal: state.NextExportOrdinal,
NodeIDByMachineID: make(map[string]string, len(state.NodeIDByMachineID)),
NodesByID: make(map[string]nasNode, len(state.NodesByID)),
NodeTokenHashByID: make(map[string]string, len(state.NodeTokenHashByID)),
ExportIDsByNodePath: make(map[string]map[string]string, len(state.ExportIDsByNodePath)),
ExportsByID: make(map[string]storageExport, len(state.ExportsByID)),
}
for machineID, nodeID := range state.NodeIDByMachineID {
cloned.NodeIDByMachineID[machineID] = nodeID
}
for nodeID, node := range state.NodesByID {
cloned.NodesByID[nodeID] = copyNasNode(node)
}
for nodeID, tokenHash := range state.NodeTokenHashByID {
cloned.NodeTokenHashByID[nodeID] = tokenHash
}
for nodeID, exportIDsByPath := range state.ExportIDsByNodePath {
clonedExportIDsByPath := make(map[string]string, len(exportIDsByPath))
for exportPath, exportID := range exportIDsByPath {
clonedExportIDsByPath[exportPath] = exportID
}
cloned.ExportIDsByNodePath[nodeID] = clonedExportIDsByPath
}
for exportID, export := range state.ExportsByID {
cloned.ExportsByID[exportID] = copyStorageExport(export)
}
return cloned
}
func (s *memoryStore) registerNode(request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
s.mu.Lock()
defer s.mu.Unlock()
nextState := cloneStoreState(s.state)
result, err := registerNodeInState(&nextState, request, registeredAt)
if err != nil {
return nodeRegistrationResult{}, err
}
if err := s.persistLocked(nextState); err != nil {
return nodeRegistrationResult{}, err
}
s.state = nextState
return result, nil
}
func registerNodeInState(state *storeState, request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
nodeID, ok := state.NodeIDByMachineID[request.MachineID]
if !ok {
nodeID = nextNodeID(state)
state.NodeIDByMachineID[request.MachineID] = nodeID
}
issuedNodeToken := ""
if stringsTrimmedEmpty(state.NodeTokenHashByID[nodeID]) {
nodeToken, err := newOpaqueToken()
if err != nil {
return nodeRegistrationResult{}, err
}
state.NodeTokenHashByID[nodeID] = hashOpaqueToken(nodeToken)
issuedNodeToken = nodeToken
}
node := nasNode{
ID: nodeID,
MachineID: request.MachineID,
DisplayName: request.DisplayName,
AgentVersion: request.AgentVersion,
Status: "online",
LastSeenAt: registeredAt.UTC().Format(time.RFC3339),
DirectAddress: copyStringPointer(request.DirectAddress),
RelayAddress: copyStringPointer(request.RelayAddress),
}
exportIDsByPath, ok := state.ExportIDsByNodePath[nodeID]
if !ok {
exportIDsByPath = make(map[string]string)
state.ExportIDsByNodePath[nodeID] = exportIDsByPath
}
keepPaths := make(map[string]struct{}, len(request.Exports))
for _, export := range request.Exports {
exportID, ok := exportIDsByPath[export.Path]
if !ok {
exportID = nextExportID(state)
exportIDsByPath[export.Path] = exportID
}
state.ExportsByID[exportID] = storageExport{
ID: exportID,
NasNodeID: nodeID,
Label: export.Label,
Path: export.Path,
MountPath: export.MountPath,
Protocols: copyStringSlice(export.Protocols),
CapacityBytes: copyInt64Pointer(export.CapacityBytes),
Tags: copyStringSlice(export.Tags),
}
keepPaths[export.Path] = struct{}{}
}
for exportPath, exportID := range exportIDsByPath {
if _, ok := keepPaths[exportPath]; ok {
continue
}
delete(exportIDsByPath, exportPath)
delete(state.ExportsByID, exportID)
}
state.NodesByID[nodeID] = node
return nodeRegistrationResult{
Node: node,
IssuedNodeToken: issuedNodeToken,
}, nil
}
func (s *memoryStore) recordHeartbeat(nodeID string, request nodeHeartbeatRequest) error {
s.mu.Lock()
defer s.mu.Unlock()
nextState := cloneStoreState(s.state)
if err := recordHeartbeatInState(&nextState, nodeID, request); err != nil {
return err
}
if err := s.persistLocked(nextState); err != nil {
return err
}
s.state = nextState
return nil
}
func recordHeartbeatInState(state *storeState, nodeID string, request nodeHeartbeatRequest) error {
node, ok := state.NodesByID[nodeID]
if !ok {
return errNodeNotFound
}
node.Status = request.Status
node.LastSeenAt = request.LastSeenAt
state.NodesByID[nodeID] = node
return nil
}
func (s *memoryStore) listExports() []storageExport {
s.mu.RLock()
defer s.mu.RUnlock()
exports := make([]storageExport, 0, len(s.state.ExportsByID))
for _, export := range s.state.ExportsByID {
exports = append(exports, copyStorageExport(export))
}
sort.Slice(exports, func(i, j int) bool {
return exports[i].ID < exports[j].ID
})
return exports
}
func (s *memoryStore) exportContext(exportID string) (exportContext, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
export, ok := s.state.ExportsByID[exportID]
if !ok {
return exportContext{}, false
}
node, ok := s.state.NodesByID[export.NasNodeID]
if !ok {
return exportContext{}, false
}
return exportContext{
export: copyStorageExport(export),
node: copyNasNode(node),
}, true
}
func (s *memoryStore) nodeByID(nodeID string) (nasNode, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
node, ok := s.state.NodesByID[nodeID]
if !ok {
return nasNode{}, false
}
return copyNasNode(node), true
}
func (s *memoryStore) nodeAuthByMachineID(machineID string) (nodeAuthState, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
nodeID, ok := s.state.NodeIDByMachineID[machineID]
if !ok {
return nodeAuthState{}, false
}
return nodeAuthState{
NodeID: nodeID,
TokenHash: s.state.NodeTokenHashByID[nodeID],
}, true
}
func (s *memoryStore) nodeAuthByID(nodeID string) (nodeAuthState, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
if _, ok := s.state.NodesByID[nodeID]; !ok {
return nodeAuthState{}, false
}
return nodeAuthState{
NodeID: nodeID,
TokenHash: s.state.NodeTokenHashByID[nodeID],
}, true
}
func (s *memoryStore) persistLocked(state storeState) error {
if s.statePath == "" {
return nil
}
return saveStoreState(s.statePath, state)
}
func saveStoreState(statePath string, state storeState) error {
payload, err := json.MarshalIndent(state, "", " ")
if err != nil {
return fmt.Errorf("encode control-plane state %s: %w", statePath, err)
}
payload = append(payload, '\n')
stateDir := filepath.Dir(statePath)
if err := os.MkdirAll(stateDir, 0o750); err != nil {
return fmt.Errorf("create control-plane state directory %s: %w", stateDir, err)
}
tempFile, err := os.CreateTemp(stateDir, ".control-plane-state-*.tmp")
if err != nil {
return fmt.Errorf("create control-plane state temp file in %s: %w", stateDir, err)
}
tempFilePath := tempFile.Name()
cleanupTempFile := true
defer func() {
if cleanupTempFile {
_ = os.Remove(tempFilePath)
}
}()
if err := tempFile.Chmod(0o600); err != nil {
_ = tempFile.Close()
return fmt.Errorf("chmod control-plane state temp file %s: %w", tempFilePath, err)
}
if _, err := tempFile.Write(payload); err != nil {
_ = tempFile.Close()
return fmt.Errorf("write control-plane state temp file %s: %w", tempFilePath, err)
}
if err := tempFile.Close(); err != nil {
return fmt.Errorf("close control-plane state temp file %s: %w", tempFilePath, err)
}
if err := os.Rename(tempFilePath, statePath); err != nil {
return fmt.Errorf("replace control-plane state %s: %w", statePath, err)
}
cleanupTempFile = false
return nil
}
func nextNodeID(state *storeState) string {
ordinal := state.NextNodeOrdinal
state.NextNodeOrdinal++
if ordinal == 1 {
return "dev-node"
}
return fmt.Sprintf("dev-node-%d", ordinal)
}
func nextExportID(state *storeState) string {
ordinal := state.NextExportOrdinal
state.NextExportOrdinal++
if ordinal == 1 {
return "dev-export"
}
return fmt.Sprintf("dev-export-%d", ordinal)
}
func copyNasNode(node nasNode) nasNode {
return nasNode{
ID: node.ID,
MachineID: node.MachineID,
DisplayName: node.DisplayName,
AgentVersion: node.AgentVersion,
Status: node.Status,
LastSeenAt: node.LastSeenAt,
DirectAddress: copyStringPointer(node.DirectAddress),
RelayAddress: copyStringPointer(node.RelayAddress),
}
}
func copyStorageExport(export storageExport) storageExport {
return storageExport{
ID: export.ID,
NasNodeID: export.NasNodeID,
Label: export.Label,
Path: export.Path,
MountPath: export.MountPath,
Protocols: copyStringSlice(export.Protocols),
CapacityBytes: copyInt64Pointer(export.CapacityBytes),
Tags: copyStringSlice(export.Tags),
}
}
func newOpaqueToken() (string, error) {
raw := make([]byte, 32)
if _, err := rand.Read(raw); err != nil {
return "", fmt.Errorf("generate node token: %w", err)
}
return base64.RawURLEncoding.EncodeToString(raw), nil
}
func hashOpaqueToken(token string) string {
sum := sha256.Sum256([]byte(token))
return base64.RawURLEncoding.EncodeToString(sum[:])
}
func stringsTrimmedEmpty(value string) bool {
return len(value) == 0
}

View file

@ -39,5 +39,21 @@ class PageController extends Controller {
],
);
}
}
#[NoCSRFRequired]
#[NoAdminRequired]
#[OpenAPI(OpenAPI::SCOPE_IGNORE)]
#[FrontpageRoute(verb: 'GET', url: '/exports/{exportId}')]
public function showExport(string $exportId): TemplateResponse {
return new TemplateResponse(
Application::APP_ID,
'export',
[
'appName' => 'betterNAS Control Plane',
'controlPlaneUrl' => $this->controlPlaneConfig->getBaseUrl(),
'exportId' => $exportId,
'export' => $this->controlPlaneClient->fetchExport($exportId),
],
);
}
}

View file

@ -23,8 +23,8 @@ class ControlPlaneClient {
$baseUrl = $this->controlPlaneConfig->getBaseUrl();
try {
$healthResponse = $this->request($baseUrl . '/health');
$versionResponse = $this->request($baseUrl . '/version');
$healthResponse = $this->requestObject($baseUrl . '/health');
$versionResponse = $this->requestObject($baseUrl . '/version');
return [
'available' => $healthResponse['statusCode'] === 200,
@ -46,32 +46,88 @@ class ControlPlaneClient {
}
}
/**
* @return array<string, mixed>|null
*/
public function fetchExport(string $exportId): ?array {
$baseUrl = $this->controlPlaneConfig->getBaseUrl();
try {
$exportsResponse = $this->requestList($baseUrl . '/api/v1/exports', true);
} catch (\Throwable $exception) {
$this->logger->warning('Failed to fetch betterNAS exports', [
'exception' => $exception,
'url' => $baseUrl,
'exportId' => $exportId,
]);
return null;
}
foreach ($exportsResponse['body'] as $export) {
if (!is_array($export)) {
continue;
}
if (($export['id'] ?? null) === $exportId) {
return $export;
}
}
return null;
}
/**
* @return array{statusCode: int, body: array<string, mixed>}
*/
private function request(string $url): array {
private function requestObject(string $url, bool $authenticated = false): array {
$response = $this->request($url, $authenticated);
return [
'statusCode' => $response->getStatusCode(),
'body' => $this->decodeObjectBody($response),
];
}
/**
* @return array{statusCode: int, body: array<int, array<string, mixed>>}
*/
private function requestList(string $url, bool $authenticated = false): array {
$response = $this->request($url, $authenticated);
return [
'statusCode' => $response->getStatusCode(),
'body' => $this->decodeListBody($response),
];
}
private function request(string $url, bool $authenticated = false): IResponse {
$headers = [
'Accept' => 'application/json',
];
if ($authenticated) {
$apiToken = $this->controlPlaneConfig->getApiToken();
if ($apiToken === '') {
throw new \RuntimeException('Missing betterNAS control plane API token');
}
$headers['Authorization'] = 'Bearer ' . $apiToken;
}
$client = $this->clientService->newClient();
$response = $client->get($url, [
'headers' => [
'Accept' => 'application/json',
],
return $client->get($url, [
'headers' => $headers,
'http_errors' => false,
'timeout' => 2,
'nextcloud' => [
'allow_local_address' => true,
],
]);
return [
'statusCode' => $response->getStatusCode(),
'body' => $this->decodeBody($response),
];
}
/**
* @return array<string, mixed>
*/
private function decodeBody(IResponse $response): array {
private function decodeObjectBody(IResponse $response): array {
$body = $response->getBody();
if ($body === '') {
return [];
@ -84,5 +140,29 @@ class ControlPlaneClient {
return $decoded;
}
}
/**
* @return array<int, array<string, mixed>>
*/
private function decodeListBody(IResponse $response): array {
$body = $response->getBody();
if ($body === '') {
return [];
}
$decoded = json_decode($body, true, 512, JSON_THROW_ON_ERROR);
if (!is_array($decoded)) {
return [];
}
$exports = [];
foreach ($decoded as $export) {
if (!is_array($export)) {
continue;
}
$exports[] = $export;
}
return $exports;
}
}

View file

@ -27,5 +27,17 @@ class ControlPlaneConfig {
return rtrim($configuredUrl, '/');
}
}
public function getApiToken(): string {
$environmentToken = getenv('BETTERNAS_CONTROL_PLANE_API_TOKEN');
if (is_string($environmentToken) && $environmentToken !== '') {
return $environmentToken;
}
return $this->appConfig->getValueString(
Application::APP_ID,
'control_plane_api_token',
'',
);
}
}

View file

@ -0,0 +1,54 @@
<?php
declare(strict_types=1);
use OCA\BetterNasControlplane\AppInfo\Application;
use OCP\Util;
Util::addStyle(Application::APP_ID, 'betternascontrolplane');
$export = $_['export'];
$exportId = $_['exportId'];
?>
<div class="betternas-shell">
<div class="betternas-shell__hero">
<p class="betternas-shell__eyebrow">betterNAS export</p>
<h1 class="betternas-shell__title">Export <?php p($exportId); ?></h1>
<p class="betternas-shell__copy">
This Nextcloud route is export-specific so cloud profiles can land on a concrete betterNAS surface without inventing new API shapes.
</p>
</div>
<div class="betternas-shell__grid">
<section class="betternas-shell__card">
<h2>Control plane</h2>
<dl>
<dt>Configured URL</dt>
<dd><code><?php p($_['controlPlaneUrl']); ?></code></dd>
<dt>Export ID</dt>
<dd><code><?php p($exportId); ?></code></dd>
<?php if (is_array($export)): ?>
<dt>Label</dt>
<dd><?php p((string)($export['label'] ?? '')); ?></dd>
<dt>Path</dt>
<dd><code><?php p((string)($export['path'] ?? '')); ?></code></dd>
<dt>Protocols</dt>
<dd><?php p(implode(', ', array_map('strval', (array)($export['protocols'] ?? [])))); ?></dd>
<?php else: ?>
<dt>Status</dt>
<dd>Export unavailable</dd>
<?php endif; ?>
</dl>
</section>
<section class="betternas-shell__card">
<h2>Boundary</h2>
<ul>
<li>Control-plane registry decides which export this page represents.</li>
<li>Nextcloud stays a thin cloud-facing adapter.</li>
<li>Mount-mode still flows directly to the NAS WebDAV endpoint.</li>
</ul>
</section>
</div>
</div>

View file

@ -6,5 +6,6 @@ For the scaffold it does two things:
- serves `GET /health`
- serves a WebDAV export at `/dav/`
- optionally serves multiple configured exports at deterministic `/dav/exports/<slug>/` paths via `BETTERNAS_EXPORT_PATHS_JSON`
This is the first real storage-facing surface in the monorepo.

View file

@ -0,0 +1,152 @@
package main
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strings"
"golang.org/x/net/webdav"
)
const (
defaultWebDAVPath = "/dav/"
exportScopedWebDAVPrefix = "/dav/exports/"
)
type appConfig struct {
exportPaths []string
}
type app struct {
exportMounts []exportMount
}
type exportMount struct {
exportPath string
mountPath string
}
func newApp(config appConfig) (*app, error) {
exportMounts, err := buildExportMounts(config.exportPaths)
if err != nil {
return nil, err
}
return &app{exportMounts: exportMounts}, nil
}
func newAppFromEnv() (*app, error) {
exportPaths, err := exportPathsFromEnv()
if err != nil {
return nil, err
}
return newApp(appConfig{exportPaths: exportPaths})
}
func exportPathsFromEnv() ([]string, error) {
rawValue, _ := os.LookupEnv("BETTERNAS_EXPORT_PATHS_JSON")
raw := strings.TrimSpace(rawValue)
if raw == "" {
return []string{env("BETTERNAS_EXPORT_PATH", ".")}, nil
}
var exportPaths []string
if err := json.Unmarshal([]byte(raw), &exportPaths); err != nil {
return nil, fmt.Errorf("BETTERNAS_EXPORT_PATHS_JSON must be a JSON array of strings: %w", err)
}
if len(exportPaths) == 0 {
return nil, errors.New("BETTERNAS_EXPORT_PATHS_JSON must not be empty")
}
return exportPaths, nil
}
func buildExportMounts(exportPaths []string) ([]exportMount, error) {
if len(exportPaths) == 0 {
return nil, errors.New("at least one export path is required")
}
normalizedPaths := make([]string, len(exportPaths))
seenPaths := make(map[string]struct{}, len(exportPaths))
for index, exportPath := range exportPaths {
normalizedPath := strings.TrimSpace(exportPath)
if normalizedPath == "" {
return nil, fmt.Errorf("exportPaths[%d] is required", index)
}
if _, ok := seenPaths[normalizedPath]; ok {
return nil, fmt.Errorf("exportPaths[%d] must be unique", index)
}
seenPaths[normalizedPath] = struct{}{}
normalizedPaths[index] = normalizedPath
}
mounts := make([]exportMount, 0, len(normalizedPaths)+1)
if len(normalizedPaths) == 1 {
singleExportPath := normalizedPaths[0]
mounts = append(mounts, exportMount{
exportPath: singleExportPath,
mountPath: defaultWebDAVPath,
})
mounts = append(mounts, exportMount{
exportPath: singleExportPath,
mountPath: scopedMountPathForExport(singleExportPath),
})
return mounts, nil
}
for _, exportPath := range normalizedPaths {
mounts = append(mounts, exportMount{
exportPath: exportPath,
mountPath: scopedMountPathForExport(exportPath),
})
}
return mounts, nil
}
func (a *app) handler() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
_, _ = w.Write([]byte("ok\n"))
})
for _, mount := range a.exportMounts {
mountPathPrefix := strings.TrimSuffix(mount.mountPath, "/")
dav := &webdav.Handler{
Prefix: mountPathPrefix,
FileSystem: webdav.Dir(mount.exportPath),
LockSystem: webdav.NewMemLS(),
}
mux.Handle(mount.mountPath, dav)
}
return mux
}
func mountProfilePathForExport(exportPath string, exportCount int) string {
// Keep /dav/ stable for the common single-export case while exposing distinct
// scoped roots when a node serves more than one export.
if exportCount <= 1 {
return defaultWebDAVPath
}
return scopedMountPathForExport(exportPath)
}
func scopedMountPathForExport(exportPath string) string {
return exportScopedWebDAVPrefix + exportRouteSlug(exportPath) + "/"
}
func exportRouteSlug(exportPath string) string {
sum := sha256.Sum256([]byte(strings.TrimSpace(exportPath)))
return hex.EncodeToString(sum[:])
}

View file

@ -0,0 +1,130 @@
package main
import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
)
func TestSingleExportServesDefaultAndScopedMountPaths(t *testing.T) {
t.Parallel()
exportDir := t.TempDir()
writeExportFile(t, exportDir, "README.txt", "single export\n")
app, err := newApp(appConfig{exportPaths: []string{exportDir}})
if err != nil {
t.Fatalf("new app: %v", err)
}
server := httptest.NewServer(app.handler())
defer server.Close()
assertHTTPStatus(t, server.Client(), "PROPFIND", server.URL+defaultWebDAVPath, http.StatusMultiStatus)
assertHTTPStatus(t, server.Client(), "PROPFIND", server.URL+scopedMountPathForExport(exportDir), http.StatusMultiStatus)
assertMountedFileContents(t, server.Client(), server.URL+defaultWebDAVPath+"README.txt", "single export\n")
assertMountedFileContents(t, server.Client(), server.URL+scopedMountPathForExport(exportDir)+"README.txt", "single export\n")
}
func TestMultipleExportsServeDistinctScopedMountPaths(t *testing.T) {
t.Parallel()
firstExportDir := t.TempDir()
secondExportDir := t.TempDir()
writeExportFile(t, firstExportDir, "README.txt", "first export\n")
writeExportFile(t, secondExportDir, "README.txt", "second export\n")
app, err := newApp(appConfig{exportPaths: []string{firstExportDir, secondExportDir}})
if err != nil {
t.Fatalf("new app: %v", err)
}
server := httptest.NewServer(app.handler())
defer server.Close()
firstMountPath := mountProfilePathForExport(firstExportDir, 2)
secondMountPath := mountProfilePathForExport(secondExportDir, 2)
if firstMountPath == secondMountPath {
t.Fatal("expected distinct mount paths for multiple exports")
}
assertHTTPStatus(t, server.Client(), "PROPFIND", server.URL+firstMountPath, http.StatusMultiStatus)
assertHTTPStatus(t, server.Client(), "PROPFIND", server.URL+secondMountPath, http.StatusMultiStatus)
assertMountedFileContents(t, server.Client(), server.URL+firstMountPath+"README.txt", "first export\n")
assertMountedFileContents(t, server.Client(), server.URL+secondMountPath+"README.txt", "second export\n")
response, err := server.Client().Get(server.URL + defaultWebDAVPath)
if err != nil {
t.Fatalf("get default multi-export mount path: %v", err)
}
defer response.Body.Close()
if response.StatusCode != http.StatusNotFound {
t.Fatalf("expected %s to return 404 for multi-export config, got %d", defaultWebDAVPath, response.StatusCode)
}
}
func TestBuildExportMountsRejectsInvalidConfigs(t *testing.T) {
t.Parallel()
if _, err := buildExportMounts(nil); err == nil {
t.Fatal("expected empty export paths to fail")
}
if _, err := buildExportMounts([]string{" "}); err == nil {
t.Fatal("expected blank export path to fail")
}
if _, err := buildExportMounts([]string{"/srv/docs", "/srv/docs"}); err == nil {
t.Fatal("expected duplicate export paths to fail")
}
}
func assertHTTPStatus(t *testing.T, client *http.Client, method string, endpoint string, expectedStatus int) {
t.Helper()
request, err := http.NewRequest(method, endpoint, nil)
if err != nil {
t.Fatalf("build %s request for %s: %v", method, endpoint, err)
}
response, err := client.Do(request)
if err != nil {
t.Fatalf("%s %s: %v", method, endpoint, err)
}
defer response.Body.Close()
if response.StatusCode != expectedStatus {
t.Fatalf("%s %s: expected status %d, got %d", method, endpoint, expectedStatus, response.StatusCode)
}
}
func assertMountedFileContents(t *testing.T, client *http.Client, endpoint string, expected string) {
t.Helper()
response, err := client.Get(endpoint)
if err != nil {
t.Fatalf("get %s: %v", endpoint, err)
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
t.Fatalf("get %s: expected status 200, got %d", endpoint, response.StatusCode)
}
body, err := io.ReadAll(response.Body)
if err != nil {
t.Fatalf("read %s response: %v", endpoint, err)
}
if string(body) != expected {
t.Fatalf("expected %s body %q, got %q", endpoint, expected, string(body))
}
}
func writeExportFile(t *testing.T, directory string, name string, contents string) {
t.Helper()
if err := os.WriteFile(filepath.Join(directory, name), []byte(contents), 0o644); err != nil {
t.Fatalf("write export file %s: %v", name, err)
}
}

View file

@ -5,34 +5,22 @@ import (
"net/http"
"os"
"time"
"golang.org/x/net/webdav"
)
func main() {
port := env("PORT", "8090")
exportPath := env("BETTERNAS_EXPORT_PATH", ".")
mux := http.NewServeMux()
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
_, _ = w.Write([]byte("ok\n"))
})
dav := &webdav.Handler{
Prefix: "/dav",
FileSystem: webdav.Dir(exportPath),
LockSystem: webdav.NewMemLS(),
app, err := newAppFromEnv()
if err != nil {
log.Fatal(err)
}
mux.Handle("/dav/", dav)
server := &http.Server{
Addr: ":" + port,
Handler: mux,
Handler: app.handler(),
ReadHeaderTimeout: 5 * time.Second,
}
log.Printf("betterNAS node agent serving %s on :%s", exportPath, port)
log.Printf("betterNAS node agent listening on :%s", port)
log.Fatal(server.ListenAndServe())
}