mirror of
https://github.com/harivansh-afk/betterNAS.git
synced 2026-04-17 11:04:54 +00:00
user-owned DAVs
This commit is contained in:
parent
ca5014750b
commit
0a3234d617
35 changed files with 732 additions and 777 deletions
|
|
@ -1,8 +1,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -11,10 +9,6 @@ type appConfig struct {
|
|||
nextcloudBaseURL string
|
||||
statePath string
|
||||
dbPath string
|
||||
clientToken string
|
||||
nodeBootstrapToken string
|
||||
davAuthSecret string
|
||||
davCredentialTTL time.Duration
|
||||
sessionTTL time.Duration
|
||||
registrationEnabled bool
|
||||
corsOrigin string
|
||||
|
|
@ -28,21 +22,6 @@ type app struct {
|
|||
}
|
||||
|
||||
func newApp(config appConfig, startedAt time.Time) (*app, error) {
|
||||
config.clientToken = strings.TrimSpace(config.clientToken)
|
||||
|
||||
config.nodeBootstrapToken = strings.TrimSpace(config.nodeBootstrapToken)
|
||||
if config.nodeBootstrapToken == "" {
|
||||
return nil, errors.New("node bootstrap token is required")
|
||||
}
|
||||
|
||||
config.davAuthSecret = strings.TrimSpace(config.davAuthSecret)
|
||||
if config.davAuthSecret == "" {
|
||||
return nil, errors.New("dav auth secret is required")
|
||||
}
|
||||
if config.davCredentialTTL <= 0 {
|
||||
return nil, errors.New("dav credential ttl must be greater than 0")
|
||||
}
|
||||
|
||||
var s store
|
||||
var err error
|
||||
if config.dbPath != "" {
|
||||
|
|
@ -91,6 +70,7 @@ type nasNode struct {
|
|||
LastSeenAt string `json:"lastSeenAt"`
|
||||
DirectAddress *string `json:"directAddress"`
|
||||
RelayAddress *string `json:"relayAddress"`
|
||||
OwnerID string `json:"-"`
|
||||
}
|
||||
|
||||
type storageExport struct {
|
||||
|
|
@ -102,6 +82,7 @@ type storageExport struct {
|
|||
Protocols []string `json:"protocols"`
|
||||
CapacityBytes *int64 `json:"capacityBytes"`
|
||||
Tags []string `json:"tags"`
|
||||
OwnerID string `json:"-"`
|
||||
}
|
||||
|
||||
type mountProfile struct {
|
||||
|
|
|
|||
|
|
@ -196,21 +196,25 @@ func TestAuthSessionUsedForClientEndpoints(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAuthStaticTokenFallback(t *testing.T) {
|
||||
func TestAuthSessionIsTheOnlyClientAuthPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{
|
||||
version: "test-version",
|
||||
clientToken: "static-fallback-token",
|
||||
version: "test-version",
|
||||
registrationEnabled: true,
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
// Static token should work for client endpoints.
|
||||
exports := getJSONAuth[[]storageExport](t, server.Client(), "static-fallback-token", server.URL+"/api/v1/exports")
|
||||
reg := postJSONAuthCreated[authLoginResponse](t, server.Client(), "", server.URL+"/api/v1/auth/register", authRegisterRequest{
|
||||
Username: "sessiononly",
|
||||
Password: "password123",
|
||||
})
|
||||
|
||||
exports := getJSONAuth[[]storageExport](t, server.Client(), reg.Token, server.URL+"/api/v1/exports")
|
||||
if exports == nil {
|
||||
t.Fatal("expected exports list, got nil")
|
||||
}
|
||||
|
||||
// Wrong token should fail.
|
||||
getStatusWithAuth(t, server.Client(), "static-fallback-token", server.URL+"/api/v1/exports", http.StatusUnauthorized)
|
||||
getStatusWithAuth(t, server.Client(), "wrong", server.URL+"/api/v1/exports", http.StatusUnauthorized)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
|
@ -25,44 +26,52 @@ func main() {
|
|||
}
|
||||
|
||||
func newAppFromEnv(startedAt time.Time) (*app, error) {
|
||||
nodeBootstrapToken, err := requiredEnv("BETTERNAS_CONTROL_PLANE_NODE_BOOTSTRAP_TOKEN")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
davAuthSecret, err := requiredEnv("BETTERNAS_DAV_AUTH_SECRET")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
davCredentialTTL, err := parseRequiredDurationEnv("BETTERNAS_DAV_CREDENTIAL_TTL")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sessionTTL time.Duration
|
||||
rawSessionTTL := strings.TrimSpace(env("BETTERNAS_SESSION_TTL", "720h"))
|
||||
if rawSessionTTL != "" {
|
||||
sessionTTL, err = time.ParseDuration(rawSessionTTL)
|
||||
parsedSessionTTL, err := time.ParseDuration(rawSessionTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionTTL = parsedSessionTTL
|
||||
}
|
||||
|
||||
return newApp(
|
||||
app, err := newApp(
|
||||
appConfig{
|
||||
version: env("BETTERNAS_VERSION", "0.1.0-dev"),
|
||||
nextcloudBaseURL: env("NEXTCLOUD_BASE_URL", ""),
|
||||
statePath: env("BETTERNAS_CONTROL_PLANE_STATE_PATH", ".state/control-plane/state.json"),
|
||||
dbPath: env("BETTERNAS_CONTROL_PLANE_DB_PATH", ""),
|
||||
clientToken: env("BETTERNAS_CONTROL_PLANE_CLIENT_TOKEN", ""),
|
||||
nodeBootstrapToken: nodeBootstrapToken,
|
||||
davAuthSecret: davAuthSecret,
|
||||
davCredentialTTL: davCredentialTTL,
|
||||
dbPath: env("BETTERNAS_CONTROL_PLANE_DB_PATH", ".state/control-plane/betternas.db"),
|
||||
sessionTTL: sessionTTL,
|
||||
registrationEnabled: env("BETTERNAS_REGISTRATION_ENABLED", "true") == "true",
|
||||
corsOrigin: env("BETTERNAS_CORS_ORIGIN", ""),
|
||||
},
|
||||
startedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := seedDefaultUserFromEnv(app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
||||
func seedDefaultUserFromEnv(app *app) error {
|
||||
username := strings.TrimSpace(env("BETTERNAS_USERNAME", ""))
|
||||
password := env("BETTERNAS_PASSWORD", "")
|
||||
if username == "" || password == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := app.store.createUser(username, password); err != nil {
|
||||
if errors.Is(err, errUsernameTaken) {
|
||||
_, authErr := app.store.authenticateUser(username, password)
|
||||
return authErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,89 +1,12 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const mountCredentialModeBasicAuth = "basic-auth"
|
||||
|
||||
// mountCredentialUsernameTokenBytes controls the random token size in mount
|
||||
// credential usernames (e.g. "mount-<token>"). The username is also embedded
|
||||
// inside the signed password payload, so longer tokens produce longer
|
||||
// passwords. macOS WebDAVFS truncates Basic Auth passwords at 255 bytes,
|
||||
// which corrupts the HMAC signature and causes auth failures. 24 bytes
|
||||
// (32 base64url chars) keeps the total password under 250 characters with
|
||||
// margin for longer node IDs and mount paths.
|
||||
const mountCredentialUsernameTokenBytes = 24
|
||||
|
||||
type signedMountCredentialClaims struct {
|
||||
Version int `json:"v"`
|
||||
NodeID string `json:"nodeId"`
|
||||
MountPath string `json:"mountPath"`
|
||||
Username string `json:"username"`
|
||||
Readonly bool `json:"readonly"`
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
}
|
||||
|
||||
func issueMountCredential(secret string, nodeID string, mountPath string, readonly bool, issuedAt time.Time, ttl time.Duration) (string, mountCredential, error) {
|
||||
credentialID, err := newOpaqueToken()
|
||||
if err != nil {
|
||||
return "", mountCredential{}, err
|
||||
}
|
||||
|
||||
usernameToken, err := newMountCredentialUsernameToken()
|
||||
if err != nil {
|
||||
return "", mountCredential{}, err
|
||||
}
|
||||
|
||||
claims := signedMountCredentialClaims{
|
||||
Version: 1,
|
||||
NodeID: nodeID,
|
||||
MountPath: mountPath,
|
||||
Username: "mount-" + usernameToken,
|
||||
Readonly: readonly,
|
||||
ExpiresAt: issuedAt.UTC().Add(ttl).Format(time.RFC3339),
|
||||
}
|
||||
|
||||
password, err := signMountCredentialClaims(secret, claims)
|
||||
if err != nil {
|
||||
return "", mountCredential{}, err
|
||||
}
|
||||
|
||||
return "mount-" + credentialID, mountCredential{
|
||||
func buildAccountMountCredential(username string) mountCredential {
|
||||
return mountCredential{
|
||||
Mode: mountCredentialModeBasicAuth,
|
||||
Username: claims.Username,
|
||||
Password: password,
|
||||
ExpiresAt: claims.ExpiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newMountCredentialUsernameToken() (string, error) {
|
||||
raw := make([]byte, mountCredentialUsernameTokenBytes)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
return "", fmt.Errorf("generate mount credential username token: %w", err)
|
||||
Username: username,
|
||||
Password: "",
|
||||
ExpiresAt: "",
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(raw), nil
|
||||
}
|
||||
|
||||
func signMountCredentialClaims(secret string, claims signedMountCredentialClaims) (string, error) {
|
||||
payload, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encode mount credential claims: %w", err)
|
||||
}
|
||||
|
||||
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
|
||||
signature := signMountCredentialPayload(secret, encodedPayload)
|
||||
return encodedPayload + "." + signature, nil
|
||||
}
|
||||
|
||||
func signMountCredentialPayload(secret string, encodedPayload string) string {
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
_, _ = mac.Write([]byte(encodedPayload))
|
||||
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,10 +31,7 @@ var (
|
|||
nodeAgentBinaryErr error
|
||||
)
|
||||
|
||||
const (
|
||||
runtimeDAVAuthSecret = "runtime-dav-auth-secret"
|
||||
runtimeDAVCredentialTTL = "1h"
|
||||
)
|
||||
const runtimeUsername = "runtime-user"
|
||||
|
||||
func TestControlPlaneBinaryMountLoopIntegration(t *testing.T) {
|
||||
exportDir := t.TempDir()
|
||||
|
|
@ -47,7 +44,7 @@ func TestControlPlaneBinaryMountLoopIntegration(t *testing.T) {
|
|||
nodeAgent := startNodeAgentBinaryWithExports(t, controlPlane.baseURL, []string{exportDir}, "machine-runtime-1")
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
|
||||
exports := waitForExportsByPath(t, client, controlPlane.baseURL+"/api/v1/exports", []string{exportDir})
|
||||
exports := waitForExportsByPath(t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/exports", []string{exportDir})
|
||||
export := exports[exportDir]
|
||||
if export.ID != "dev-export" {
|
||||
t.Fatalf("expected export ID %q, got %q", "dev-export", export.ID)
|
||||
|
|
@ -56,7 +53,7 @@ func TestControlPlaneBinaryMountLoopIntegration(t *testing.T) {
|
|||
t.Fatalf("expected mountPath %q, got %q", defaultWebDAVPath, export.MountPath)
|
||||
}
|
||||
|
||||
mount := postJSONAuth[mountProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{
|
||||
mount := postJSONAuth[mountProfile](t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{
|
||||
ExportID: export.ID,
|
||||
})
|
||||
if mount.MountURL != nodeAgent.baseURL+defaultWebDAVPath {
|
||||
|
|
@ -66,11 +63,11 @@ func TestControlPlaneBinaryMountLoopIntegration(t *testing.T) {
|
|||
t.Fatalf("expected mount credential mode %q, got %q", mountCredentialModeBasicAuth, mount.Credential.Mode)
|
||||
}
|
||||
|
||||
assertHTTPStatusWithBasicAuth(t, client, "PROPFIND", mount.MountURL, mount.Credential.Username, mount.Credential.Password, http.StatusMultiStatus)
|
||||
assertMountedFileContentsWithBasicAuth(t, client, mount.MountURL+"README.txt", mount.Credential.Username, mount.Credential.Password, "betterNAS export\n")
|
||||
assertHTTPStatusWithBasicAuth(t, client, "PROPFIND", mount.MountURL, controlPlane.username, controlPlane.password, http.StatusMultiStatus)
|
||||
assertMountedFileContentsWithBasicAuth(t, client, mount.MountURL+"README.txt", controlPlane.username, controlPlane.password, "betterNAS export\n")
|
||||
|
||||
cloud := postJSONAuth[cloudProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
|
||||
UserID: "runtime-user",
|
||||
cloud := postJSONAuth[cloudProfile](t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
|
||||
UserID: controlPlane.userID,
|
||||
ExportID: export.ID,
|
||||
Provider: "nextcloud",
|
||||
})
|
||||
|
|
@ -97,12 +94,12 @@ func TestControlPlaneBinaryMultiExportProfilesStayDistinct(t *testing.T) {
|
|||
|
||||
firstMountPath := nodeAgentMountPathForExport(firstExportDir, 2)
|
||||
secondMountPath := nodeAgentMountPathForExport(secondExportDir, 2)
|
||||
exports := waitForExportsByPath(t, client, controlPlane.baseURL+"/api/v1/exports", []string{firstExportDir, secondExportDir})
|
||||
exports := waitForExportsByPath(t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/exports", []string{firstExportDir, secondExportDir})
|
||||
firstExport := exports[firstExportDir]
|
||||
secondExport := exports[secondExportDir]
|
||||
|
||||
firstMount := postJSONAuth[mountProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{ExportID: firstExport.ID})
|
||||
secondMount := postJSONAuth[mountProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{ExportID: secondExport.ID})
|
||||
firstMount := postJSONAuth[mountProfile](t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{ExportID: firstExport.ID})
|
||||
secondMount := postJSONAuth[mountProfile](t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/mount-profiles/issue", mountProfileRequest{ExportID: secondExport.ID})
|
||||
if firstMount.MountURL == secondMount.MountURL {
|
||||
t.Fatalf("expected distinct runtime mount URLs, got %q", firstMount.MountURL)
|
||||
}
|
||||
|
|
@ -113,18 +110,18 @@ func TestControlPlaneBinaryMultiExportProfilesStayDistinct(t *testing.T) {
|
|||
t.Fatalf("expected second runtime mount URL %q, got %q", nodeAgent.baseURL+secondMountPath, secondMount.MountURL)
|
||||
}
|
||||
|
||||
assertHTTPStatusWithBasicAuth(t, client, "PROPFIND", firstMount.MountURL, firstMount.Credential.Username, firstMount.Credential.Password, http.StatusMultiStatus)
|
||||
assertHTTPStatusWithBasicAuth(t, client, "PROPFIND", secondMount.MountURL, secondMount.Credential.Username, secondMount.Credential.Password, http.StatusMultiStatus)
|
||||
assertMountedFileContentsWithBasicAuth(t, client, firstMount.MountURL+"README.txt", firstMount.Credential.Username, firstMount.Credential.Password, "first runtime export\n")
|
||||
assertMountedFileContentsWithBasicAuth(t, client, secondMount.MountURL+"README.txt", secondMount.Credential.Username, secondMount.Credential.Password, "second runtime export\n")
|
||||
assertHTTPStatusWithBasicAuth(t, client, "PROPFIND", firstMount.MountURL, controlPlane.username, controlPlane.password, http.StatusMultiStatus)
|
||||
assertHTTPStatusWithBasicAuth(t, client, "PROPFIND", secondMount.MountURL, controlPlane.username, controlPlane.password, http.StatusMultiStatus)
|
||||
assertMountedFileContentsWithBasicAuth(t, client, firstMount.MountURL+"README.txt", controlPlane.username, controlPlane.password, "first runtime export\n")
|
||||
assertMountedFileContentsWithBasicAuth(t, client, secondMount.MountURL+"README.txt", controlPlane.username, controlPlane.password, "second runtime export\n")
|
||||
|
||||
firstCloud := postJSONAuth[cloudProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
|
||||
UserID: "runtime-user",
|
||||
firstCloud := postJSONAuth[cloudProfile](t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
|
||||
UserID: controlPlane.userID,
|
||||
ExportID: firstExport.ID,
|
||||
Provider: "nextcloud",
|
||||
})
|
||||
secondCloud := postJSONAuth[cloudProfile](t, client, testClientToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
|
||||
UserID: "runtime-user",
|
||||
secondCloud := postJSONAuth[cloudProfile](t, client, controlPlane.sessionToken, controlPlane.baseURL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
|
||||
UserID: controlPlane.userID,
|
||||
ExportID: secondExport.ID,
|
||||
Provider: "nextcloud",
|
||||
})
|
||||
|
|
@ -140,8 +137,12 @@ func TestControlPlaneBinaryMultiExportProfilesStayDistinct(t *testing.T) {
|
|||
}
|
||||
|
||||
type runningBinary struct {
|
||||
baseURL string
|
||||
logPath string
|
||||
baseURL string
|
||||
logPath string
|
||||
sessionToken string
|
||||
username string
|
||||
password string
|
||||
userID string
|
||||
}
|
||||
|
||||
func startControlPlaneBinary(t *testing.T, version string, nextcloudBaseURL string) runningBinary {
|
||||
|
|
@ -149,7 +150,7 @@ func startControlPlaneBinary(t *testing.T, version string, nextcloudBaseURL stri
|
|||
|
||||
port := reserveTCPPort(t)
|
||||
logPath := filepath.Join(t.TempDir(), "control-plane.log")
|
||||
statePath := filepath.Join(t.TempDir(), "control-plane-state.json")
|
||||
dbPath := filepath.Join(t.TempDir(), "control-plane.db")
|
||||
logFile, err := os.Create(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("create control-plane log file: %v", err)
|
||||
|
|
@ -162,11 +163,8 @@ func startControlPlaneBinary(t *testing.T, version string, nextcloudBaseURL stri
|
|||
"PORT="+port,
|
||||
"BETTERNAS_VERSION="+version,
|
||||
"NEXTCLOUD_BASE_URL="+nextcloudBaseURL,
|
||||
"BETTERNAS_CONTROL_PLANE_STATE_PATH="+statePath,
|
||||
"BETTERNAS_CONTROL_PLANE_CLIENT_TOKEN="+testClientToken,
|
||||
"BETTERNAS_CONTROL_PLANE_NODE_BOOTSTRAP_TOKEN="+testNodeBootstrapToken,
|
||||
"BETTERNAS_DAV_AUTH_SECRET="+runtimeDAVAuthSecret,
|
||||
"BETTERNAS_DAV_CREDENTIAL_TTL="+runtimeDAVCredentialTTL,
|
||||
"BETTERNAS_CONTROL_PLANE_DB_PATH="+dbPath,
|
||||
"BETTERNAS_REGISTRATION_ENABLED=true",
|
||||
)
|
||||
cmd.Stdout = logFile
|
||||
cmd.Stderr = logFile
|
||||
|
|
@ -183,11 +181,16 @@ func startControlPlaneBinary(t *testing.T, version string, nextcloudBaseURL stri
|
|||
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%s", port)
|
||||
waitForHTTPStatus(t, baseURL+"/health", waitDone, logPath, http.StatusOK)
|
||||
session := registerRuntimeUser(t, &http.Client{Timeout: 2 * time.Second}, baseURL)
|
||||
registerProcessCleanup(t, ctx, cancel, cmd, waitDone, logFile, logPath, "control-plane")
|
||||
|
||||
return runningBinary{
|
||||
baseURL: baseURL,
|
||||
logPath: logPath,
|
||||
baseURL: baseURL,
|
||||
logPath: logPath,
|
||||
sessionToken: session.Token,
|
||||
username: runtimeUsername,
|
||||
password: testPassword,
|
||||
userID: session.User.ID,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -197,7 +200,6 @@ func startNodeAgentBinaryWithExports(t *testing.T, controlPlaneBaseURL string, e
|
|||
port := reserveTCPPort(t)
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%s", port)
|
||||
logPath := filepath.Join(t.TempDir(), "node-agent.log")
|
||||
nodeTokenPath := filepath.Join(t.TempDir(), "node-token")
|
||||
logFile, err := os.Create(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("create node-agent log file: %v", err)
|
||||
|
|
@ -215,12 +217,11 @@ func startNodeAgentBinaryWithExports(t *testing.T, controlPlaneBaseURL string, e
|
|||
"PORT="+port,
|
||||
"BETTERNAS_EXPORT_PATHS_JSON="+string(rawExportPaths),
|
||||
"BETTERNAS_CONTROL_PLANE_URL="+controlPlaneBaseURL,
|
||||
"BETTERNAS_CONTROL_PLANE_NODE_BOOTSTRAP_TOKEN="+testNodeBootstrapToken,
|
||||
"BETTERNAS_NODE_TOKEN_PATH="+nodeTokenPath,
|
||||
"BETTERNAS_USERNAME="+runtimeUsername,
|
||||
"BETTERNAS_PASSWORD="+testPassword,
|
||||
"BETTERNAS_NODE_MACHINE_ID="+machineID,
|
||||
"BETTERNAS_NODE_DISPLAY_NAME="+machineID,
|
||||
"BETTERNAS_NODE_DIRECT_ADDRESS="+baseURL,
|
||||
"BETTERNAS_DAV_AUTH_SECRET="+runtimeDAVAuthSecret,
|
||||
"BETTERNAS_VERSION=runtime-test-version",
|
||||
)
|
||||
cmd.Stdout = logFile
|
||||
|
|
@ -245,12 +246,12 @@ func startNodeAgentBinaryWithExports(t *testing.T, controlPlaneBaseURL string, e
|
|||
}
|
||||
}
|
||||
|
||||
func waitForExportsByPath(t *testing.T, client *http.Client, endpoint string, expectedPaths []string) map[string]storageExport {
|
||||
func waitForExportsByPath(t *testing.T, client *http.Client, token string, endpoint string, expectedPaths []string) map[string]storageExport {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(10 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
exports := getJSONAuth[[]storageExport](t, client, testClientToken, endpoint)
|
||||
exports := getJSONAuth[[]storageExport](t, client, token, endpoint)
|
||||
exportsByPath := exportsByPath(exports)
|
||||
allPresent := true
|
||||
for _, expectedPath := range expectedPaths {
|
||||
|
|
@ -269,6 +270,15 @@ func waitForExportsByPath(t *testing.T, client *http.Client, endpoint string, ex
|
|||
return nil
|
||||
}
|
||||
|
||||
func registerRuntimeUser(t *testing.T, client *http.Client, baseURL string) authLoginResponse {
|
||||
t.Helper()
|
||||
|
||||
return postJSONAuthCreated[authLoginResponse](t, client, "", baseURL+"/api/v1/auth/register", authRegisterRequest{
|
||||
Username: runtimeUsername,
|
||||
Password: testPassword,
|
||||
})
|
||||
}
|
||||
|
||||
func buildControlPlaneBinary(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package main
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -21,12 +20,12 @@ var (
|
|||
errMountTargetUnavailable = errors.New("mount target is not available")
|
||||
errNodeIDMismatch = errors.New("node id path and body must match")
|
||||
errNodeNotFound = errors.New("node not found")
|
||||
errNodeOwnedByAnotherUser = errors.New("node is already owned by another user")
|
||||
)
|
||||
|
||||
const (
|
||||
authorizationHeader = "Authorization"
|
||||
controlPlaneNodeTokenKey = "X-BetterNAS-Node-Token"
|
||||
bearerScheme = "Bearer"
|
||||
authorizationHeader = "Authorization"
|
||||
bearerScheme = "Bearer"
|
||||
)
|
||||
|
||||
func (a *app) handler() http.Handler {
|
||||
|
|
@ -76,6 +75,11 @@ func (a *app) handleVersion(w http.ResponseWriter, _ *http.Request) {
|
|||
}
|
||||
|
||||
func (a *app) handleNodeRegister(w http.ResponseWriter, r *http.Request) {
|
||||
currentUser, ok := a.requireSessionUser(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
request, err := decodeNodeRegistrationRequest(w, r)
|
||||
if err != nil {
|
||||
writeDecodeError(w, err)
|
||||
|
|
@ -87,23 +91,25 @@ func (a *app) handleNodeRegister(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if !a.authorizeNodeRegistration(w, r, request.MachineID) {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := a.store.registerNode(request, a.now())
|
||||
result, err := a.store.registerNode(currentUser.ID, request, a.now())
|
||||
if err != nil {
|
||||
if errors.Is(err, errNodeOwnedByAnotherUser) {
|
||||
http.Error(w, err.Error(), http.StatusConflict)
|
||||
return
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if result.IssuedNodeToken != "" {
|
||||
w.Header().Set(controlPlaneNodeTokenKey, result.IssuedNodeToken)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, result.Node)
|
||||
}
|
||||
|
||||
func (a *app) handleNodeHeartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
currentUser, ok := a.requireSessionUser(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
nodeID := r.PathValue("nodeId")
|
||||
|
||||
var request nodeHeartbeatRequest
|
||||
|
|
@ -121,11 +127,7 @@ func (a *app) handleNodeHeartbeat(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if !a.authorizeNode(w, r, nodeID) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := a.store.recordHeartbeat(nodeID, request); err != nil {
|
||||
if err := a.store.recordHeartbeat(nodeID, currentUser.ID, request); err != nil {
|
||||
statusCode := http.StatusInternalServerError
|
||||
if errors.Is(err, errNodeNotFound) {
|
||||
statusCode = http.StatusNotFound
|
||||
|
|
@ -138,6 +140,11 @@ func (a *app) handleNodeHeartbeat(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (a *app) handleNodeExports(w http.ResponseWriter, r *http.Request) {
|
||||
currentUser, ok := a.requireSessionUser(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
nodeID := r.PathValue("nodeId")
|
||||
|
||||
request, err := decodeNodeExportsRequest(w, r)
|
||||
|
|
@ -151,11 +158,7 @@ func (a *app) handleNodeExports(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if !a.authorizeNode(w, r, nodeID) {
|
||||
return
|
||||
}
|
||||
|
||||
exports, err := a.store.upsertExports(nodeID, request)
|
||||
exports, err := a.store.upsertExports(nodeID, currentUser.ID, request)
|
||||
if err != nil {
|
||||
statusCode := http.StatusInternalServerError
|
||||
if errors.Is(err, errNodeNotFound) {
|
||||
|
|
@ -169,15 +172,17 @@ func (a *app) handleNodeExports(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (a *app) handleExportsList(w http.ResponseWriter, r *http.Request) {
|
||||
if !a.requireClientAuth(w, r) {
|
||||
currentUser, ok := a.requireSessionUser(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, a.store.listExports())
|
||||
writeJSON(w, http.StatusOK, a.store.listExports(currentUser.ID))
|
||||
}
|
||||
|
||||
func (a *app) handleMountProfileIssue(w http.ResponseWriter, r *http.Request) {
|
||||
if !a.requireClientAuth(w, r) {
|
||||
currentUser, ok := a.requireSessionUser(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -192,8 +197,8 @@ func (a *app) handleMountProfileIssue(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
context, ok := a.store.exportContext(request.ExportID)
|
||||
if !ok {
|
||||
context, found := a.store.exportContext(request.ExportID, currentUser.ID)
|
||||
if !found {
|
||||
http.Error(w, errExportNotFound.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
|
@ -204,32 +209,20 @@ func (a *app) handleMountProfileIssue(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
credentialID, credential, err := issueMountCredential(
|
||||
a.config.davAuthSecret,
|
||||
context.node.ID,
|
||||
mountProfilePathForExport(context.export.MountPath),
|
||||
false,
|
||||
a.now(),
|
||||
a.config.davCredentialTTL,
|
||||
)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, mountProfile{
|
||||
ID: credentialID,
|
||||
ID: context.export.ID,
|
||||
ExportID: context.export.ID,
|
||||
Protocol: "webdav",
|
||||
DisplayName: context.export.Label,
|
||||
MountURL: mountURL,
|
||||
Readonly: false,
|
||||
Credential: credential,
|
||||
Credential: buildAccountMountCredential(currentUser.Username),
|
||||
})
|
||||
}
|
||||
|
||||
func (a *app) handleCloudProfileIssue(w http.ResponseWriter, r *http.Request) {
|
||||
if !a.requireClientAuth(w, r) {
|
||||
currentUser, ok := a.requireSessionUser(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -244,8 +237,8 @@ func (a *app) handleCloudProfileIssue(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
context, ok := a.store.exportContext(request.ExportID)
|
||||
if !ok {
|
||||
context, found := a.store.exportContext(request.ExportID, currentUser.ID)
|
||||
if !found {
|
||||
http.Error(w, errExportNotFound.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
|
@ -257,7 +250,7 @@ func (a *app) handleCloudProfileIssue(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, cloudProfile{
|
||||
ID: fmt.Sprintf("cloud-%s-%s", request.UserID, context.export.ID),
|
||||
ID: fmt.Sprintf("cloud-%s-%s", currentUser.ID, context.export.ID),
|
||||
ExportID: context.export.ID,
|
||||
Provider: "nextcloud",
|
||||
BaseURL: baseURL,
|
||||
|
|
@ -1034,71 +1027,22 @@ func corsMiddleware(allowedOrigin string, next http.Handler) http.Handler {
|
|||
})
|
||||
}
|
||||
|
||||
// --- client auth ---
|
||||
// --- session auth ---
|
||||
|
||||
func (a *app) requireClientAuth(w http.ResponseWriter, r *http.Request) bool {
|
||||
func (a *app) requireSessionUser(w http.ResponseWriter, r *http.Request) (user, bool) {
|
||||
presentedToken, ok := bearerToken(r)
|
||||
if !ok {
|
||||
writeUnauthorized(w)
|
||||
return false
|
||||
return user{}, false
|
||||
}
|
||||
|
||||
// Session-based auth (SQLite).
|
||||
if _, err := a.store.validateSession(presentedToken); err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fall back to static client token for backwards compatibility.
|
||||
if a.config.clientToken != "" && secureStringEquals(a.config.clientToken, presentedToken) {
|
||||
return true
|
||||
}
|
||||
|
||||
writeUnauthorized(w)
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *app) authorizeNodeRegistration(w http.ResponseWriter, r *http.Request, machineID string) bool {
|
||||
presentedToken, ok := bearerToken(r)
|
||||
if !ok {
|
||||
currentUser, err := a.store.validateSession(presentedToken)
|
||||
if err != nil {
|
||||
writeUnauthorized(w)
|
||||
return false
|
||||
return user{}, false
|
||||
}
|
||||
|
||||
authState, exists := a.store.nodeAuthByMachineID(machineID)
|
||||
if !exists || strings.TrimSpace(authState.TokenHash) == "" {
|
||||
if !secureStringEquals(a.config.nodeBootstrapToken, presentedToken) {
|
||||
writeUnauthorized(w)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if !tokenHashMatches(authState.TokenHash, presentedToken) {
|
||||
writeUnauthorized(w)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (a *app) authorizeNode(w http.ResponseWriter, r *http.Request, nodeID string) bool {
|
||||
presentedToken, ok := bearerToken(r)
|
||||
if !ok {
|
||||
writeUnauthorized(w)
|
||||
return false
|
||||
}
|
||||
|
||||
authState, exists := a.store.nodeAuthByID(nodeID)
|
||||
if !exists {
|
||||
http.Error(w, errNodeNotFound.Error(), http.StatusNotFound)
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(authState.TokenHash) == "" || !tokenHashMatches(authState.TokenHash, presentedToken) {
|
||||
writeUnauthorized(w)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return currentUser, true
|
||||
}
|
||||
|
||||
func bearerToken(r *http.Request) (string, bool) {
|
||||
|
|
@ -1124,11 +1068,3 @@ func writeUnauthorized(w http.ResponseWriter) {
|
|||
w.Header().Set("WWW-Authenticate", bearerScheme)
|
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func secureStringEquals(expected string, actual string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(expected), []byte(actual)) == 1
|
||||
}
|
||||
|
||||
func tokenHashMatches(expectedHash string, token string) bool {
|
||||
return secureStringEquals(expectedHash, hashOpaqueToken(token))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
|
@ -15,8 +16,9 @@ import (
|
|||
var testControlPlaneNow = time.Date(2025, time.January, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
const (
|
||||
testPassword = "password123"
|
||||
testClientToken = "test-client-token"
|
||||
testNodeBootstrapToken = "test-node-bootstrap-token"
|
||||
testNodeBootstrapToken = "test-node-session-token"
|
||||
)
|
||||
|
||||
type registeredNode struct {
|
||||
|
|
@ -94,7 +96,7 @@ func TestControlPlaneRegistrationProfilesAndHeartbeat(t *testing.T) {
|
|||
RelayAddress: &relayAddress,
|
||||
})
|
||||
if registration.NodeToken == "" {
|
||||
t.Fatal("expected node registration to return a node token")
|
||||
t.Fatal("expected node registration to preserve the session token")
|
||||
}
|
||||
|
||||
syncedExports := syncNodeExports(t, server.Client(), registration.NodeToken, server.URL+"/api/v1/nodes/"+registration.Node.ID+"/exports", nodeExportsRequest{
|
||||
|
|
@ -169,14 +171,14 @@ func TestControlPlaneRegistrationProfilesAndHeartbeat(t *testing.T) {
|
|||
if mount.Credential.Mode != mountCredentialModeBasicAuth {
|
||||
t.Fatalf("expected credential mode %q, got %q", mountCredentialModeBasicAuth, mount.Credential.Mode)
|
||||
}
|
||||
if mount.Credential.Username == "" {
|
||||
t.Fatal("expected mount credential username to be set")
|
||||
if mount.Credential.Username != "fixture" {
|
||||
t.Fatalf("expected mount credential username %q, got %q", "fixture", mount.Credential.Username)
|
||||
}
|
||||
if mount.Credential.Password == "" {
|
||||
t.Fatal("expected mount credential password to be set")
|
||||
if mount.Credential.Password != "" {
|
||||
t.Fatalf("expected mount credential password to be blank, got %q", mount.Credential.Password)
|
||||
}
|
||||
if mount.Credential.ExpiresAt == "" {
|
||||
t.Fatal("expected mount credential expiry to be set")
|
||||
if mount.Credential.ExpiresAt != "" {
|
||||
t.Fatalf("expected mount credential expiry to be blank, got %q", mount.Credential.ExpiresAt)
|
||||
}
|
||||
|
||||
cloud := postJSONAuth[cloudProfile](t, server.Client(), testClientToken, server.URL+"/api/v1/cloud-profiles/issue", cloudProfileRequest{
|
||||
|
|
@ -231,7 +233,7 @@ func TestControlPlaneExportSyncReconcilesExportsAndKeepsStableIDs(t *testing.T)
|
|||
RelayAddress: nil,
|
||||
})
|
||||
|
||||
putJSONAuthStatus(t, server.Client(), testNodeBootstrapToken, server.URL+"/api/v1/nodes/"+firstRegistration.Node.ID+"/exports", nodeExportsRequest{
|
||||
putJSONAuthStatus(t, server.Client(), "wrong-session-token", server.URL+"/api/v1/nodes/"+firstRegistration.Node.ID+"/exports", nodeExportsRequest{
|
||||
Exports: []storageExportInput{
|
||||
{
|
||||
Label: "Docs",
|
||||
|
|
@ -285,7 +287,7 @@ func TestControlPlaneExportSyncReconcilesExportsAndKeepsStableIDs(t *testing.T)
|
|||
RelayAddress: nil,
|
||||
})
|
||||
|
||||
putJSONAuthStatus(t, server.Client(), testClientToken, server.URL+"/api/v1/nodes/"+firstRegistration.Node.ID+"/exports", nodeExportsRequest{
|
||||
putJSONAuthStatus(t, server.Client(), "wrong-session-token", server.URL+"/api/v1/nodes/"+firstRegistration.Node.ID+"/exports", nodeExportsRequest{
|
||||
Exports: []storageExportInput{
|
||||
{
|
||||
Label: "Docs v2",
|
||||
|
|
@ -330,8 +332,8 @@ func TestControlPlaneExportSyncReconcilesExportsAndKeepsStableIDs(t *testing.T)
|
|||
if secondRegistration.Node.ID != firstRegistration.Node.ID {
|
||||
t.Fatalf("expected re-registration to keep node ID %q, got %q", firstRegistration.Node.ID, secondRegistration.Node.ID)
|
||||
}
|
||||
if secondRegistration.NodeToken != "" {
|
||||
t.Fatalf("expected re-registration to keep the existing node token, got %q", secondRegistration.NodeToken)
|
||||
if secondRegistration.NodeToken != firstRegistration.NodeToken {
|
||||
t.Fatalf("expected re-registration to keep the existing session token %q, got %q", firstRegistration.NodeToken, secondRegistration.NodeToken)
|
||||
}
|
||||
|
||||
updatedExports := exportsByPath(getJSONAuth[[]storageExport](t, server.Client(), testClientToken, server.URL+"/api/v1/exports"))
|
||||
|
|
@ -539,12 +541,12 @@ func TestControlPlaneCloudProfilesRequireConfiguredBaseURLAndExistingExport(t *t
|
|||
func TestControlPlanePersistsRegistryAcrossAppRestart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
statePath := filepath.Join(t.TempDir(), "control-plane-state.json")
|
||||
dbPath := filepath.Join(t.TempDir(), "control-plane.db")
|
||||
directAddress := "http://nas.local:8090"
|
||||
|
||||
_, firstServer := newTestControlPlaneServer(t, appConfig{
|
||||
version: "test-version",
|
||||
statePath: statePath,
|
||||
version: "test-version",
|
||||
dbPath: dbPath,
|
||||
})
|
||||
registration := registerNode(t, firstServer.Client(), firstServer.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
|
||||
MachineID: "machine-persisted",
|
||||
|
|
@ -566,8 +568,8 @@ func TestControlPlanePersistsRegistryAcrossAppRestart(t *testing.T) {
|
|||
firstServer.Close()
|
||||
|
||||
_, secondServer := newTestControlPlaneServer(t, appConfig{
|
||||
version: "test-version",
|
||||
statePath: statePath,
|
||||
version: "test-version",
|
||||
dbPath: dbPath,
|
||||
})
|
||||
defer secondServer.Close()
|
||||
|
||||
|
|
@ -656,15 +658,12 @@ func TestControlPlaneRejectsInvalidRequestsAndEnforcesAuth(t *testing.T) {
|
|||
if err := json.NewDecoder(response.Body).Decode(&node); err != nil {
|
||||
t.Fatalf("decode registration response: %v", err)
|
||||
}
|
||||
nodeToken := strings.TrimSpace(response.Header.Get(controlPlaneNodeTokenKey))
|
||||
if nodeToken == "" {
|
||||
t.Fatal("expected node registration to return a node token")
|
||||
}
|
||||
nodeToken := testNodeBootstrapToken
|
||||
if node.ID != "dev-node" {
|
||||
t.Fatalf("expected node ID %q, got %q", "dev-node", node.ID)
|
||||
}
|
||||
|
||||
putJSONAuthStatus(t, server.Client(), testClientToken, server.URL+"/api/v1/nodes/"+node.ID+"/exports", nodeExportsRequest{
|
||||
putJSONAuthStatus(t, server.Client(), "wrong-session-token", server.URL+"/api/v1/nodes/"+node.ID+"/exports", nodeExportsRequest{
|
||||
Exports: []storageExportInput{{
|
||||
Label: "Docs",
|
||||
Path: "/srv/docs",
|
||||
|
|
@ -716,7 +715,7 @@ func TestControlPlaneRejectsInvalidRequestsAndEnforcesAuth(t *testing.T) {
|
|||
},
|
||||
}, http.StatusBadRequest)
|
||||
|
||||
postJSONAuthStatus(t, server.Client(), testClientToken, server.URL+"/api/v1/nodes/"+node.ID+"/heartbeat", nodeHeartbeatRequest{
|
||||
postJSONAuthStatus(t, server.Client(), "wrong-session-token", server.URL+"/api/v1/nodes/"+node.ID+"/heartbeat", nodeHeartbeatRequest{
|
||||
NodeID: node.ID,
|
||||
Status: "online",
|
||||
LastSeenAt: "2025-01-02T03:04:05Z",
|
||||
|
|
@ -765,21 +764,12 @@ func TestControlPlaneRejectsInvalidRequestsAndEnforcesAuth(t *testing.T) {
|
|||
func newTestControlPlaneServer(t *testing.T, config appConfig) (*app, *httptest.Server) {
|
||||
t.Helper()
|
||||
|
||||
if config.dbPath == "" {
|
||||
config.dbPath = filepath.Join(t.TempDir(), "test.db")
|
||||
}
|
||||
if config.version == "" {
|
||||
config.version = "test-version"
|
||||
}
|
||||
if config.clientToken == "" {
|
||||
config.clientToken = testClientToken
|
||||
}
|
||||
if config.nodeBootstrapToken == "" {
|
||||
config.nodeBootstrapToken = testNodeBootstrapToken
|
||||
}
|
||||
if config.davAuthSecret == "" {
|
||||
config.davAuthSecret = "test-dav-auth-secret"
|
||||
}
|
||||
if config.davCredentialTTL == 0 {
|
||||
config.davCredentialTTL = time.Hour
|
||||
}
|
||||
|
||||
app, err := newApp(config, testControlPlaneNow)
|
||||
if err != nil {
|
||||
|
|
@ -788,11 +778,46 @@ func newTestControlPlaneServer(t *testing.T, config appConfig) (*app, *httptest.
|
|||
app.now = func() time.Time {
|
||||
return testControlPlaneNow
|
||||
}
|
||||
seedDefaultSessionUser(t, app)
|
||||
|
||||
server := httptest.NewServer(app.handler())
|
||||
return app, server
|
||||
}
|
||||
|
||||
func seedDefaultSessionUser(t *testing.T, app *app) {
|
||||
t.Helper()
|
||||
|
||||
u, err := app.store.createUser("fixture", testPassword)
|
||||
if err != nil && !errors.Is(err, errUsernameTaken) {
|
||||
t.Fatalf("seed default test user: %v", err)
|
||||
}
|
||||
if errors.Is(err, errUsernameTaken) {
|
||||
u, err = app.store.authenticateUser("fixture", testPassword)
|
||||
if err != nil {
|
||||
t.Fatalf("authenticate seeded test user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
sqliteStore, ok := app.store.(*sqliteStore)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
createdAt := time.Now().UTC().Format(time.RFC3339)
|
||||
expiresAt := time.Now().UTC().Add(24 * time.Hour).Format(time.RFC3339)
|
||||
for _, token := range []string{testClientToken, testNodeBootstrapToken} {
|
||||
if _, err := sqliteStore.db.Exec(
|
||||
"INSERT OR REPLACE INTO sessions (token, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)",
|
||||
token,
|
||||
u.ID,
|
||||
createdAt,
|
||||
expiresAt,
|
||||
); err != nil {
|
||||
t.Fatalf("seed session %s: %v", token, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func exportsByPath(exports []storageExport) map[string]storageExport {
|
||||
byPath := make(map[string]storageExport, len(exports))
|
||||
for _, export := range exports {
|
||||
|
|
@ -820,10 +845,19 @@ func registerNode(t *testing.T, client *http.Client, endpoint string, token stri
|
|||
|
||||
return registeredNode{
|
||||
Node: node,
|
||||
NodeToken: strings.TrimSpace(response.Header.Get(controlPlaneNodeTokenKey)),
|
||||
NodeToken: strings.TrimSpace(token),
|
||||
}
|
||||
}
|
||||
|
||||
func registerSessionUser(t *testing.T, client *http.Client, baseURL string, username string) authLoginResponse {
|
||||
t.Helper()
|
||||
|
||||
return postJSONAuthCreated[authLoginResponse](t, client, "", baseURL+"/api/v1/auth/register", authRegisterRequest{
|
||||
Username: username,
|
||||
Password: testPassword,
|
||||
})
|
||||
}
|
||||
|
||||
func syncNodeExports(t *testing.T, client *http.Client, token string, endpoint string, payload nodeExportsRequest) []storageExport {
|
||||
t.Helper()
|
||||
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
errUsernameTaken = errors.New("username already taken")
|
||||
errInvalidLogin = errors.New("invalid username or password")
|
||||
errUsernameTaken = errors.New("username already taken")
|
||||
errInvalidLogin = errors.New("invalid username or password")
|
||||
errSessionExpired = errors.New("session expired or invalid")
|
||||
)
|
||||
|
||||
|
|
@ -32,6 +32,7 @@ INSERT OR IGNORE INTO ordinals (name, value) VALUES ('node', 0), ('export', 0);
|
|||
CREATE TABLE IF NOT EXISTS nodes (
|
||||
id TEXT PRIMARY KEY,
|
||||
machine_id TEXT NOT NULL UNIQUE,
|
||||
owner_id TEXT REFERENCES users(id),
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
agent_version TEXT NOT NULL DEFAULT '',
|
||||
status TEXT NOT NULL DEFAULT 'online',
|
||||
|
|
@ -48,6 +49,7 @@ CREATE TABLE IF NOT EXISTS node_tokens (
|
|||
CREATE TABLE IF NOT EXISTS exports (
|
||||
id TEXT PRIMARY KEY,
|
||||
node_id TEXT NOT NULL REFERENCES nodes(id),
|
||||
owner_id TEXT REFERENCES users(id),
|
||||
label TEXT NOT NULL DEFAULT '',
|
||||
path TEXT NOT NULL,
|
||||
mount_path TEXT NOT NULL DEFAULT '',
|
||||
|
|
@ -101,10 +103,40 @@ func newSQLiteStore(dbPath string) (*sqliteStore, error) {
|
|||
db.Close()
|
||||
return nil, fmt.Errorf("initialize database schema: %w", err)
|
||||
}
|
||||
if err := migrateSQLiteSchema(db); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sqliteStore{db: db}, nil
|
||||
}
|
||||
|
||||
func migrateSQLiteSchema(db *sql.DB) error {
|
||||
migrations := []string{
|
||||
"ALTER TABLE nodes ADD COLUMN owner_id TEXT REFERENCES users(id)",
|
||||
"ALTER TABLE exports ADD COLUMN owner_id TEXT REFERENCES users(id)",
|
||||
}
|
||||
for _, statement := range migrations {
|
||||
if _, err := db.Exec(statement); err != nil && !strings.Contains(err.Error(), "duplicate column name") {
|
||||
return fmt.Errorf("run sqlite migration %q: %w", statement, err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := db.Exec(`
|
||||
UPDATE exports
|
||||
SET owner_id = (
|
||||
SELECT owner_id
|
||||
FROM nodes
|
||||
WHERE nodes.id = exports.node_id
|
||||
)
|
||||
WHERE owner_id IS NULL
|
||||
`); err != nil {
|
||||
return fmt.Errorf("backfill export owners: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) nextOrdinal(tx *sql.Tx, name string) (int, error) {
|
||||
var value int
|
||||
err := tx.QueryRow("UPDATE ordinals SET value = value + 1 WHERE name = ? RETURNING value", name).Scan(&value)
|
||||
|
|
@ -128,7 +160,7 @@ func ordinalToExportID(ordinal int) string {
|
|||
return fmt.Sprintf("dev-export-%d", ordinal)
|
||||
}
|
||||
|
||||
func (s *sqliteStore) registerNode(request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
|
||||
func (s *sqliteStore) registerNode(ownerID string, request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return nodeRegistrationResult{}, fmt.Errorf("begin transaction: %w", err)
|
||||
|
|
@ -137,7 +169,8 @@ func (s *sqliteStore) registerNode(request nodeRegistrationRequest, registeredAt
|
|||
|
||||
// Check if machine already registered.
|
||||
var nodeID string
|
||||
err = tx.QueryRow("SELECT id FROM nodes WHERE machine_id = ?", request.MachineID).Scan(&nodeID)
|
||||
var existingOwnerID sql.NullString
|
||||
err = tx.QueryRow("SELECT id, owner_id FROM nodes WHERE machine_id = ?", request.MachineID).Scan(&nodeID, &existingOwnerID)
|
||||
if err == sql.ErrNoRows {
|
||||
ordinal, err := s.nextOrdinal(tx, "node")
|
||||
if err != nil {
|
||||
|
|
@ -146,57 +179,40 @@ func (s *sqliteStore) registerNode(request nodeRegistrationRequest, registeredAt
|
|||
nodeID = ordinalToNodeID(ordinal)
|
||||
} else if err != nil {
|
||||
return nodeRegistrationResult{}, fmt.Errorf("lookup node by machine_id: %w", err)
|
||||
} else if existingOwnerID.Valid && strings.TrimSpace(existingOwnerID.String) != "" && existingOwnerID.String != ownerID {
|
||||
return nodeRegistrationResult{}, errNodeOwnedByAnotherUser
|
||||
}
|
||||
|
||||
// Upsert node.
|
||||
_, err = tx.Exec(`
|
||||
INSERT INTO nodes (id, machine_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address)
|
||||
VALUES (?, ?, ?, ?, 'online', ?, ?, ?)
|
||||
INSERT INTO nodes (id, machine_id, owner_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address)
|
||||
VALUES (?, ?, ?, ?, ?, 'online', ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
owner_id = excluded.owner_id,
|
||||
display_name = excluded.display_name,
|
||||
agent_version = excluded.agent_version,
|
||||
status = 'online',
|
||||
last_seen_at = excluded.last_seen_at,
|
||||
direct_address = excluded.direct_address,
|
||||
relay_address = excluded.relay_address
|
||||
`, nodeID, request.MachineID, request.DisplayName, request.AgentVersion,
|
||||
`, nodeID, request.MachineID, ownerID, request.DisplayName, request.AgentVersion,
|
||||
registeredAt.UTC().Format(time.RFC3339),
|
||||
nullableString(request.DirectAddress), nullableString(request.RelayAddress))
|
||||
if err != nil {
|
||||
return nodeRegistrationResult{}, fmt.Errorf("upsert node: %w", err)
|
||||
}
|
||||
|
||||
// Issue token if none exists.
|
||||
var issuedNodeToken string
|
||||
var existingHash sql.NullString
|
||||
_ = tx.QueryRow("SELECT token_hash FROM node_tokens WHERE node_id = ?", nodeID).Scan(&existingHash)
|
||||
|
||||
if !existingHash.Valid || strings.TrimSpace(existingHash.String) == "" {
|
||||
nodeToken, err := newOpaqueToken()
|
||||
if err != nil {
|
||||
return nodeRegistrationResult{}, err
|
||||
}
|
||||
_, err = tx.Exec(
|
||||
"INSERT OR REPLACE INTO node_tokens (node_id, token_hash) VALUES (?, ?)",
|
||||
nodeID, hashOpaqueToken(nodeToken))
|
||||
if err != nil {
|
||||
return nodeRegistrationResult{}, fmt.Errorf("store node token: %w", err)
|
||||
}
|
||||
issuedNodeToken = nodeToken
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nodeRegistrationResult{}, fmt.Errorf("commit registration: %w", err)
|
||||
}
|
||||
|
||||
node, _ := s.nodeByID(nodeID)
|
||||
return nodeRegistrationResult{
|
||||
Node: node,
|
||||
IssuedNodeToken: issuedNodeToken,
|
||||
Node: node,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) upsertExports(nodeID string, request nodeExportsRequest) ([]storageExport, error) {
|
||||
func (s *sqliteStore) upsertExports(nodeID string, ownerID string, request nodeExportsRequest) ([]storageExport, error) {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("begin transaction: %w", err)
|
||||
|
|
@ -205,7 +221,7 @@ func (s *sqliteStore) upsertExports(nodeID string, request nodeExportsRequest) (
|
|||
|
||||
// Verify node exists.
|
||||
var exists bool
|
||||
err = tx.QueryRow("SELECT 1 FROM nodes WHERE id = ?", nodeID).Scan(&exists)
|
||||
err = tx.QueryRow("SELECT 1 FROM nodes WHERE id = ? AND owner_id = ?", nodeID, ownerID).Scan(&exists)
|
||||
if err != nil {
|
||||
return nil, errNodeNotFound
|
||||
}
|
||||
|
|
@ -238,13 +254,14 @@ func (s *sqliteStore) upsertExports(nodeID string, request nodeExportsRequest) (
|
|||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
INSERT INTO exports (id, node_id, label, path, mount_path, capacity_bytes)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO exports (id, node_id, owner_id, label, path, mount_path, capacity_bytes)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
owner_id = excluded.owner_id,
|
||||
label = excluded.label,
|
||||
mount_path = excluded.mount_path,
|
||||
capacity_bytes = excluded.capacity_bytes
|
||||
`, exportID, nodeID, input.Label, input.Path, input.MountPath, nullableInt64(input.CapacityBytes))
|
||||
`, exportID, nodeID, ownerID, input.Label, input.Path, input.MountPath, nullableInt64(input.CapacityBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upsert export %q: %w", input.Path, err)
|
||||
}
|
||||
|
|
@ -288,10 +305,10 @@ func (s *sqliteStore) upsertExports(nodeID string, request nodeExportsRequest) (
|
|||
return s.listExportsForNode(nodeID), nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) recordHeartbeat(nodeID string, request nodeHeartbeatRequest) error {
|
||||
func (s *sqliteStore) recordHeartbeat(nodeID string, ownerID string, request nodeHeartbeatRequest) error {
|
||||
result, err := s.db.Exec(
|
||||
"UPDATE nodes SET status = ?, last_seen_at = ? WHERE id = ?",
|
||||
request.Status, request.LastSeenAt, nodeID)
|
||||
"UPDATE nodes SET status = ?, last_seen_at = ? WHERE id = ? AND owner_id = ?",
|
||||
request.Status, request.LastSeenAt, nodeID, ownerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update heartbeat: %w", err)
|
||||
}
|
||||
|
|
@ -302,8 +319,8 @@ func (s *sqliteStore) recordHeartbeat(nodeID string, request nodeHeartbeatReques
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) listExports() []storageExport {
|
||||
rows, err := s.db.Query("SELECT id, node_id, label, path, mount_path, capacity_bytes FROM exports ORDER BY id")
|
||||
func (s *sqliteStore) listExports(ownerID string) []storageExport {
|
||||
rows, err := s.db.Query("SELECT id, node_id, owner_id, label, path, mount_path, capacity_bytes FROM exports WHERE owner_id = ? ORDER BY id", ownerID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -330,7 +347,7 @@ func (s *sqliteStore) listExports() []storageExport {
|
|||
}
|
||||
|
||||
func (s *sqliteStore) listExportsForNode(nodeID string) []storageExport {
|
||||
rows, err := s.db.Query("SELECT id, node_id, label, path, mount_path, capacity_bytes FROM exports WHERE node_id = ? ORDER BY id", nodeID)
|
||||
rows, err := s.db.Query("SELECT id, node_id, owner_id, label, path, mount_path, capacity_bytes FROM exports WHERE node_id = ? ORDER BY id", nodeID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -356,15 +373,19 @@ func (s *sqliteStore) listExportsForNode(nodeID string) []storageExport {
|
|||
return exports
|
||||
}
|
||||
|
||||
func (s *sqliteStore) exportContext(exportID string) (exportContext, bool) {
|
||||
func (s *sqliteStore) exportContext(exportID string, ownerID string) (exportContext, bool) {
|
||||
var e storageExport
|
||||
var capacityBytes sql.NullInt64
|
||||
var exportOwnerID sql.NullString
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, node_id, label, path, mount_path, capacity_bytes FROM exports WHERE id = ?",
|
||||
exportID).Scan(&e.ID, &e.NasNodeID, &e.Label, &e.Path, &e.MountPath, &capacityBytes)
|
||||
"SELECT id, node_id, owner_id, label, path, mount_path, capacity_bytes FROM exports WHERE id = ? AND owner_id = ?",
|
||||
exportID, ownerID).Scan(&e.ID, &e.NasNodeID, &exportOwnerID, &e.Label, &e.Path, &e.MountPath, &capacityBytes)
|
||||
if err != nil {
|
||||
return exportContext{}, false
|
||||
}
|
||||
if exportOwnerID.Valid {
|
||||
e.OwnerID = exportOwnerID.String
|
||||
}
|
||||
if capacityBytes.Valid {
|
||||
e.CapacityBytes = &capacityBytes.Int64
|
||||
}
|
||||
|
|
@ -383,12 +404,16 @@ func (s *sqliteStore) nodeByID(nodeID string) (nasNode, bool) {
|
|||
var n nasNode
|
||||
var directAddr, relayAddr sql.NullString
|
||||
var lastSeenAt sql.NullString
|
||||
var ownerID sql.NullString
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, machine_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address FROM nodes WHERE id = ?",
|
||||
nodeID).Scan(&n.ID, &n.MachineID, &n.DisplayName, &n.AgentVersion, &n.Status, &lastSeenAt, &directAddr, &relayAddr)
|
||||
"SELECT id, machine_id, owner_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address FROM nodes WHERE id = ?",
|
||||
nodeID).Scan(&n.ID, &n.MachineID, &ownerID, &n.DisplayName, &n.AgentVersion, &n.Status, &lastSeenAt, &directAddr, &relayAddr)
|
||||
if err != nil {
|
||||
return nasNode{}, false
|
||||
}
|
||||
if ownerID.Valid {
|
||||
n.OwnerID = ownerID.String
|
||||
}
|
||||
if lastSeenAt.Valid {
|
||||
n.LastSeenAt = lastSeenAt.String
|
||||
}
|
||||
|
|
@ -442,9 +467,13 @@ func (s *sqliteStore) nodeAuthByID(nodeID string) (nodeAuthState, bool) {
|
|||
func (s *sqliteStore) scanExport(rows *sql.Rows) storageExport {
|
||||
var e storageExport
|
||||
var capacityBytes sql.NullInt64
|
||||
if err := rows.Scan(&e.ID, &e.NasNodeID, &e.Label, &e.Path, &e.MountPath, &capacityBytes); err != nil {
|
||||
var ownerID sql.NullString
|
||||
if err := rows.Scan(&e.ID, &e.NasNodeID, &ownerID, &e.Label, &e.Path, &e.MountPath, &capacityBytes); err != nil {
|
||||
return storageExport{}
|
||||
}
|
||||
if ownerID.Valid {
|
||||
e.OwnerID = ownerID.String
|
||||
}
|
||||
if capacityBytes.Valid {
|
||||
e.CapacityBytes = &capacityBytes.Int64
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,24 +18,13 @@ func newTestSQLiteApp(t *testing.T, config appConfig) (*app, *httptest.Server) {
|
|||
if config.version == "" {
|
||||
config.version = "test-version"
|
||||
}
|
||||
if config.clientToken == "" {
|
||||
config.clientToken = testClientToken
|
||||
}
|
||||
if config.nodeBootstrapToken == "" {
|
||||
config.nodeBootstrapToken = testNodeBootstrapToken
|
||||
}
|
||||
if config.davAuthSecret == "" {
|
||||
config.davAuthSecret = "test-dav-auth-secret"
|
||||
}
|
||||
if config.davCredentialTTL == 0 {
|
||||
config.davCredentialTTL = time.Hour
|
||||
}
|
||||
|
||||
app, err := newApp(config, testControlPlaneNow)
|
||||
if err != nil {
|
||||
t.Fatalf("new app: %v", err)
|
||||
}
|
||||
app.now = func() time.Time { return testControlPlaneNow }
|
||||
seedDefaultSessionUser(t, app)
|
||||
|
||||
server := httptest.NewServer(app.handler())
|
||||
return app, server
|
||||
|
|
@ -79,7 +68,7 @@ func TestSQLiteRegistrationAndExports(t *testing.T) {
|
|||
RelayAddress: nil,
|
||||
})
|
||||
if registration.NodeToken == "" {
|
||||
t.Fatal("expected node registration to return a node token")
|
||||
t.Fatal("expected node registration to preserve the session token")
|
||||
}
|
||||
if registration.Node.ID != "dev-node" {
|
||||
t.Fatalf("expected node ID %q, got %q", "dev-node", registration.Node.ID)
|
||||
|
|
@ -142,8 +131,8 @@ func TestSQLiteReRegistrationKeepsNodeID(t *testing.T) {
|
|||
if second.Node.ID != first.Node.ID {
|
||||
t.Fatalf("expected re-registration to keep node ID %q, got %q", first.Node.ID, second.Node.ID)
|
||||
}
|
||||
if second.NodeToken != "" {
|
||||
t.Fatalf("expected re-registration to not issue new token, got %q", second.NodeToken)
|
||||
if second.NodeToken != first.NodeToken {
|
||||
t.Fatalf("expected re-registration to keep the existing session token %q, got %q", first.NodeToken, second.NodeToken)
|
||||
}
|
||||
if second.Node.DisplayName != "NAS Updated" {
|
||||
t.Fatalf("expected updated display name, got %q", second.Node.DisplayName)
|
||||
|
|
|
|||
|
|
@ -31,8 +31,7 @@ type memoryStore struct {
|
|||
}
|
||||
|
||||
type nodeRegistrationResult struct {
|
||||
Node nasNode
|
||||
IssuedNodeToken string
|
||||
Node nasNode
|
||||
}
|
||||
|
||||
type nodeAuthState struct {
|
||||
|
|
@ -153,12 +152,12 @@ func cloneStoreState(state storeState) storeState {
|
|||
return cloned
|
||||
}
|
||||
|
||||
func (s *memoryStore) registerNode(request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
|
||||
func (s *memoryStore) registerNode(ownerID string, request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
nextState := cloneStoreState(s.state)
|
||||
result, err := registerNodeInState(&nextState, request, registeredAt)
|
||||
result, err := registerNodeInState(&nextState, ownerID, request, registeredAt)
|
||||
if err != nil {
|
||||
return nodeRegistrationResult{}, err
|
||||
}
|
||||
|
|
@ -170,21 +169,14 @@ func (s *memoryStore) registerNode(request nodeRegistrationRequest, registeredAt
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func registerNodeInState(state *storeState, request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
|
||||
func registerNodeInState(state *storeState, ownerID string, request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
|
||||
nodeID, ok := state.NodeIDByMachineID[request.MachineID]
|
||||
if !ok {
|
||||
nodeID = nextNodeID(state)
|
||||
state.NodeIDByMachineID[request.MachineID] = nodeID
|
||||
}
|
||||
|
||||
issuedNodeToken := ""
|
||||
if stringsTrimmedEmpty(state.NodeTokenHashByID[nodeID]) {
|
||||
nodeToken, err := newOpaqueToken()
|
||||
if err != nil {
|
||||
return nodeRegistrationResult{}, err
|
||||
}
|
||||
state.NodeTokenHashByID[nodeID] = hashOpaqueToken(nodeToken)
|
||||
issuedNodeToken = nodeToken
|
||||
if existingNode, exists := state.NodesByID[nodeID]; exists && existingNode.OwnerID != "" && existingNode.OwnerID != ownerID {
|
||||
return nodeRegistrationResult{}, errNodeOwnedByAnotherUser
|
||||
}
|
||||
|
||||
node := nasNode{
|
||||
|
|
@ -196,21 +188,21 @@ func registerNodeInState(state *storeState, request nodeRegistrationRequest, reg
|
|||
LastSeenAt: registeredAt.UTC().Format(time.RFC3339),
|
||||
DirectAddress: copyStringPointer(request.DirectAddress),
|
||||
RelayAddress: copyStringPointer(request.RelayAddress),
|
||||
OwnerID: ownerID,
|
||||
}
|
||||
|
||||
state.NodesByID[nodeID] = node
|
||||
return nodeRegistrationResult{
|
||||
Node: node,
|
||||
IssuedNodeToken: issuedNodeToken,
|
||||
Node: node,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) upsertExports(nodeID string, request nodeExportsRequest) ([]storageExport, error) {
|
||||
func (s *memoryStore) upsertExports(nodeID string, ownerID string, request nodeExportsRequest) ([]storageExport, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
nextState := cloneStoreState(s.state)
|
||||
exports, err := upsertExportsInState(&nextState, nodeID, request.Exports)
|
||||
exports, err := upsertExportsInState(&nextState, nodeID, ownerID, request.Exports)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -222,8 +214,9 @@ func (s *memoryStore) upsertExports(nodeID string, request nodeExportsRequest) (
|
|||
return exports, nil
|
||||
}
|
||||
|
||||
func upsertExportsInState(state *storeState, nodeID string, exports []storageExportInput) ([]storageExport, error) {
|
||||
if _, ok := state.NodesByID[nodeID]; !ok {
|
||||
func upsertExportsInState(state *storeState, nodeID string, ownerID string, exports []storageExportInput) ([]storageExport, error) {
|
||||
node, ok := state.NodesByID[nodeID]
|
||||
if !ok || node.OwnerID != ownerID {
|
||||
return nil, errNodeNotFound
|
||||
}
|
||||
|
||||
|
|
@ -250,6 +243,7 @@ func upsertExportsInState(state *storeState, nodeID string, exports []storageExp
|
|||
Protocols: copyStringSlice(export.Protocols),
|
||||
CapacityBytes: copyInt64Pointer(export.CapacityBytes),
|
||||
Tags: copyStringSlice(export.Tags),
|
||||
OwnerID: ownerID,
|
||||
}
|
||||
keepPaths[export.Path] = struct{}{}
|
||||
}
|
||||
|
|
@ -278,12 +272,12 @@ func upsertExportsInState(state *storeState, nodeID string, exports []storageExp
|
|||
return nodeExports, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) recordHeartbeat(nodeID string, request nodeHeartbeatRequest) error {
|
||||
func (s *memoryStore) recordHeartbeat(nodeID string, ownerID string, request nodeHeartbeatRequest) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
nextState := cloneStoreState(s.state)
|
||||
if err := recordHeartbeatInState(&nextState, nodeID, request); err != nil {
|
||||
if err := recordHeartbeatInState(&nextState, nodeID, ownerID, request); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.persistLocked(nextState); err != nil {
|
||||
|
|
@ -294,9 +288,9 @@ func (s *memoryStore) recordHeartbeat(nodeID string, request nodeHeartbeatReques
|
|||
return nil
|
||||
}
|
||||
|
||||
func recordHeartbeatInState(state *storeState, nodeID string, request nodeHeartbeatRequest) error {
|
||||
func recordHeartbeatInState(state *storeState, nodeID string, ownerID string, request nodeHeartbeatRequest) error {
|
||||
node, ok := state.NodesByID[nodeID]
|
||||
if !ok {
|
||||
if !ok || node.OwnerID != ownerID {
|
||||
return errNodeNotFound
|
||||
}
|
||||
|
||||
|
|
@ -307,12 +301,15 @@ func recordHeartbeatInState(state *storeState, nodeID string, request nodeHeartb
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) listExports() []storageExport {
|
||||
func (s *memoryStore) listExports(ownerID string) []storageExport {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
exports := make([]storageExport, 0, len(s.state.ExportsByID))
|
||||
for _, export := range s.state.ExportsByID {
|
||||
if export.OwnerID != ownerID {
|
||||
continue
|
||||
}
|
||||
exports = append(exports, copyStorageExport(export))
|
||||
}
|
||||
|
||||
|
|
@ -323,17 +320,17 @@ func (s *memoryStore) listExports() []storageExport {
|
|||
return exports
|
||||
}
|
||||
|
||||
func (s *memoryStore) exportContext(exportID string) (exportContext, bool) {
|
||||
func (s *memoryStore) exportContext(exportID string, ownerID string) (exportContext, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
export, ok := s.state.ExportsByID[exportID]
|
||||
if !ok {
|
||||
if !ok || export.OwnerID != ownerID {
|
||||
return exportContext{}, false
|
||||
}
|
||||
|
||||
node, ok := s.state.NodesByID[export.NasNodeID]
|
||||
if !ok {
|
||||
if !ok || node.OwnerID != ownerID {
|
||||
return exportContext{}, false
|
||||
}
|
||||
|
||||
|
|
@ -468,6 +465,7 @@ func copyNasNode(node nasNode) nasNode {
|
|||
LastSeenAt: node.LastSeenAt,
|
||||
DirectAddress: copyStringPointer(node.DirectAddress),
|
||||
RelayAddress: copyStringPointer(node.RelayAddress),
|
||||
OwnerID: node.OwnerID,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -481,6 +479,7 @@ func copyStorageExport(export storageExport) storageExport {
|
|||
Protocols: copyStringSlice(export.Protocols),
|
||||
CapacityBytes: copyInt64Pointer(export.CapacityBytes),
|
||||
Tags: copyStringSlice(export.Tags),
|
||||
OwnerID: export.OwnerID,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,14 +5,12 @@ import "time"
|
|||
// store defines the persistence interface for the control-plane.
|
||||
type store interface {
|
||||
// Node management
|
||||
registerNode(request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error)
|
||||
upsertExports(nodeID string, request nodeExportsRequest) ([]storageExport, error)
|
||||
recordHeartbeat(nodeID string, request nodeHeartbeatRequest) error
|
||||
listExports() []storageExport
|
||||
exportContext(exportID string) (exportContext, bool)
|
||||
registerNode(ownerID string, request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error)
|
||||
upsertExports(nodeID string, ownerID string, request nodeExportsRequest) ([]storageExport, error)
|
||||
recordHeartbeat(nodeID string, ownerID string, request nodeHeartbeatRequest) error
|
||||
listExports(ownerID string) []storageExport
|
||||
exportContext(exportID string, ownerID string) (exportContext, bool)
|
||||
nodeByID(nodeID string) (nasNode, bool)
|
||||
nodeAuthByMachineID(machineID string) (nodeAuthState, bool)
|
||||
nodeAuthByID(nodeID string) (nodeAuthState, bool)
|
||||
|
||||
// User auth
|
||||
createUser(username string, password string) (user, error)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue