mirror of
https://github.com/harivansh-afk/betterNAS.git
synced 2026-04-17 03:03:44 +00:00
Add SQLite store and user auth for production deployment
Replace the in-memory JSON-backed store with a SQLite option using modernc.org/sqlite (pure Go, no CGo). Add user authentication with bcrypt password hashing and random session tokens. SQLite store: - Schema covers nodes, exports, tokens, ordinals, users, sessions - WAL mode and foreign keys enabled - Set BETTERNAS_CONTROL_PLANE_DB_PATH to activate (falls back to memory store when empty) User auth: - POST /api/v1/auth/register, login, logout, GET /me - bcrypt (cost 10) for passwords, 32-byte hex session tokens - Sessions stored in SQLite with configurable TTL - Client endpoints accept session tokens or static client token - CORS middleware via BETTERNAS_CORS_ORIGIN env var New env vars: BETTERNAS_CONTROL_PLANE_DB_PATH, BETTERNAS_SESSION_TTL, BETTERNAS_REGISTRATION_ENABLED, BETTERNAS_CORS_ORIGIN 24 tests pass including 7 SQLite store tests and 7 auth tests. Builds clean with CGO_ENABLED=0.
This commit is contained in:
parent
1bb065ade0
commit
c499e46a4d
12 changed files with 2138 additions and 28 deletions
|
|
@ -7,27 +7,28 @@ import (
|
|||
)
|
||||
|
||||
type appConfig struct {
|
||||
version string
|
||||
nextcloudBaseURL string
|
||||
statePath string
|
||||
clientToken string
|
||||
nodeBootstrapToken string
|
||||
davAuthSecret string
|
||||
davCredentialTTL time.Duration
|
||||
version string
|
||||
nextcloudBaseURL string
|
||||
statePath string
|
||||
dbPath string
|
||||
clientToken string
|
||||
nodeBootstrapToken string
|
||||
davAuthSecret string
|
||||
davCredentialTTL time.Duration
|
||||
sessionTTL time.Duration
|
||||
registrationEnabled bool
|
||||
corsOrigin string
|
||||
}
|
||||
|
||||
type app struct {
|
||||
startedAt time.Time
|
||||
now func() time.Time
|
||||
config appConfig
|
||||
store *memoryStore
|
||||
store store
|
||||
}
|
||||
|
||||
func newApp(config appConfig, startedAt time.Time) (*app, error) {
|
||||
config.clientToken = strings.TrimSpace(config.clientToken)
|
||||
if config.clientToken == "" {
|
||||
return nil, errors.New("client token is required")
|
||||
}
|
||||
|
||||
config.nodeBootstrapToken = strings.TrimSpace(config.nodeBootstrapToken)
|
||||
if config.nodeBootstrapToken == "" {
|
||||
|
|
@ -42,7 +43,13 @@ func newApp(config appConfig, startedAt time.Time) (*app, error) {
|
|||
return nil, errors.New("dav credential ttl must be greater than 0")
|
||||
}
|
||||
|
||||
store, err := newMemoryStore(config.statePath)
|
||||
var s store
|
||||
var err error
|
||||
if config.dbPath != "" {
|
||||
s, err = newSQLiteStore(config.dbPath)
|
||||
} else {
|
||||
s, err = newMemoryStore(config.statePath)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -51,7 +58,7 @@ func newApp(config appConfig, startedAt time.Time) (*app, error) {
|
|||
startedAt: startedAt,
|
||||
now: time.Now,
|
||||
config: config,
|
||||
store: store,
|
||||
store: s,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -164,6 +171,12 @@ type exportContext struct {
|
|||
node nasNode
|
||||
}
|
||||
|
||||
type user struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
}
|
||||
|
||||
func copyStringPointer(value *string) *string {
|
||||
if value == nil {
|
||||
return nil
|
||||
|
|
|
|||
216
apps/control-plane/cmd/control-plane/auth_test.go
Normal file
216
apps/control-plane/cmd/control-plane/auth_test.go
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func postJSONAuthCreated[T any](t *testing.T, client *http.Client, token string, endpoint string, payload any) T {
|
||||
t.Helper()
|
||||
|
||||
response := postJSONAuthResponse(t, client, token, endpoint, payload)
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusCreated {
|
||||
responseBody, _ := io.ReadAll(response.Body)
|
||||
t.Fatalf("post %s: expected status 201, got %d: %s", endpoint, response.StatusCode, responseBody)
|
||||
}
|
||||
|
||||
var decoded T
|
||||
if err := json.NewDecoder(response.Body).Decode(&decoded); err != nil {
|
||||
t.Fatalf("decode %s response: %v", endpoint, err)
|
||||
}
|
||||
|
||||
return decoded
|
||||
}
|
||||
|
||||
func TestAuthRegisterLoginLogoutMe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{
|
||||
version: "test-version",
|
||||
registrationEnabled: true,
|
||||
sessionTTL: time.Hour,
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
// Register.
|
||||
reg := postJSONAuthCreated[authLoginResponse](t, server.Client(), "", server.URL+"/api/v1/auth/register", authRegisterRequest{
|
||||
Username: "testuser",
|
||||
Password: "password123",
|
||||
})
|
||||
if reg.Token == "" {
|
||||
t.Fatal("expected session token from registration")
|
||||
}
|
||||
if reg.User.Username != "testuser" {
|
||||
t.Fatalf("expected username %q, got %q", "testuser", reg.User.Username)
|
||||
}
|
||||
if reg.User.ID == "" {
|
||||
t.Fatal("expected user ID")
|
||||
}
|
||||
|
||||
// /me with the registration token.
|
||||
me := getJSONAuth[user](t, server.Client(), reg.Token, server.URL+"/api/v1/auth/me")
|
||||
if me.Username != "testuser" {
|
||||
t.Fatalf("expected username %q from /me, got %q", "testuser", me.Username)
|
||||
}
|
||||
|
||||
// Use session to list exports (client auth).
|
||||
exports := getJSONAuth[[]storageExport](t, server.Client(), reg.Token, server.URL+"/api/v1/exports")
|
||||
if len(exports) != 0 {
|
||||
t.Fatalf("expected 0 exports, got %d", len(exports))
|
||||
}
|
||||
|
||||
// Login with same credentials.
|
||||
login := postJSONAuth[authLoginResponse](t, server.Client(), "", server.URL+"/api/v1/auth/login", authLoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password123",
|
||||
})
|
||||
if login.Token == "" {
|
||||
t.Fatal("expected session token from login")
|
||||
}
|
||||
if login.Token == reg.Token {
|
||||
t.Fatal("expected login to issue a different token than registration")
|
||||
}
|
||||
|
||||
// Logout the registration token.
|
||||
postJSONAuthStatus(t, server.Client(), reg.Token, server.URL+"/api/v1/auth/logout", nil, http.StatusNoContent)
|
||||
|
||||
// Old token should be invalid now.
|
||||
getStatusWithAuth(t, server.Client(), reg.Token, server.URL+"/api/v1/auth/me", http.StatusUnauthorized)
|
||||
|
||||
// Login token still works.
|
||||
me = getJSONAuth[user](t, server.Client(), login.Token, server.URL+"/api/v1/auth/me")
|
||||
if me.Username != "testuser" {
|
||||
t.Fatalf("expected username %q, got %q", "testuser", me.Username)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthDuplicateUsername(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{
|
||||
version: "test-version",
|
||||
registrationEnabled: true,
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
postJSONAuthCreated[authLoginResponse](t, server.Client(), "", server.URL+"/api/v1/auth/register", authRegisterRequest{
|
||||
Username: "taken",
|
||||
Password: "password123",
|
||||
})
|
||||
|
||||
postJSONAuthStatus(t, server.Client(), "", server.URL+"/api/v1/auth/register", authRegisterRequest{
|
||||
Username: "taken",
|
||||
Password: "different456",
|
||||
}, http.StatusConflict)
|
||||
}
|
||||
|
||||
func TestAuthBadCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{
|
||||
version: "test-version",
|
||||
registrationEnabled: true,
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
postJSONAuthCreated[authLoginResponse](t, server.Client(), "", server.URL+"/api/v1/auth/register", authRegisterRequest{
|
||||
Username: "realuser",
|
||||
Password: "correctpassword",
|
||||
})
|
||||
|
||||
postJSONAuthStatus(t, server.Client(), "", server.URL+"/api/v1/auth/login", authLoginRequest{
|
||||
Username: "realuser",
|
||||
Password: "wrongpassword",
|
||||
}, http.StatusUnauthorized)
|
||||
|
||||
postJSONAuthStatus(t, server.Client(), "", server.URL+"/api/v1/auth/login", authLoginRequest{
|
||||
Username: "nosuchuser",
|
||||
Password: "anything",
|
||||
}, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func TestAuthRegistrationDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{
|
||||
version: "test-version",
|
||||
registrationEnabled: false,
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
postJSONAuthStatus(t, server.Client(), "", server.URL+"/api/v1/auth/register", authRegisterRequest{
|
||||
Username: "blocked",
|
||||
Password: "password123",
|
||||
}, http.StatusForbidden)
|
||||
}
|
||||
|
||||
func TestAuthValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{
|
||||
version: "test-version",
|
||||
registrationEnabled: true,
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
// Username too short.
|
||||
postJSONAuthStatus(t, server.Client(), "", server.URL+"/api/v1/auth/register", authRegisterRequest{
|
||||
Username: "ab",
|
||||
Password: "password123",
|
||||
}, http.StatusBadRequest)
|
||||
|
||||
// Password too short.
|
||||
postJSONAuthStatus(t, server.Client(), "", server.URL+"/api/v1/auth/register", authRegisterRequest{
|
||||
Username: "validuser",
|
||||
Password: "short",
|
||||
}, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func TestAuthSessionUsedForClientEndpoints(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{
|
||||
version: "test-version",
|
||||
registrationEnabled: true,
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
// Without auth, exports should fail.
|
||||
getStatusWithAuth(t, server.Client(), "", server.URL+"/api/v1/exports", http.StatusUnauthorized)
|
||||
|
||||
// Register and get session.
|
||||
reg := postJSONAuthCreated[authLoginResponse](t, server.Client(), "", server.URL+"/api/v1/auth/register", authRegisterRequest{
|
||||
Username: "admin",
|
||||
Password: "password123",
|
||||
})
|
||||
|
||||
// Session should work for client endpoints.
|
||||
exports := getJSONAuth[[]storageExport](t, server.Client(), reg.Token, server.URL+"/api/v1/exports")
|
||||
if exports == nil {
|
||||
t.Fatal("expected exports list, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthStaticTokenFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{
|
||||
version: "test-version",
|
||||
clientToken: "static-fallback-token",
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
// Static token should work for client endpoints.
|
||||
exports := getJSONAuth[[]storageExport](t, server.Client(), "static-fallback-token", server.URL+"/api/v1/exports")
|
||||
if exports == nil {
|
||||
t.Fatal("expected exports list, got nil")
|
||||
}
|
||||
|
||||
// Wrong token should fail.
|
||||
getStatusWithAuth(t, server.Client(), "wrong", server.URL+"/api/v1/exports", http.StatusUnauthorized)
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -24,11 +25,6 @@ func main() {
|
|||
}
|
||||
|
||||
func newAppFromEnv(startedAt time.Time) (*app, error) {
|
||||
clientToken, err := requiredEnv("BETTERNAS_CONTROL_PLANE_CLIENT_TOKEN")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodeBootstrapToken, err := requiredEnv("BETTERNAS_CONTROL_PLANE_NODE_BOOTSTRAP_TOKEN")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -44,15 +40,28 @@ func newAppFromEnv(startedAt time.Time) (*app, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var sessionTTL time.Duration
|
||||
rawSessionTTL := strings.TrimSpace(env("BETTERNAS_SESSION_TTL", "720h"))
|
||||
if rawSessionTTL != "" {
|
||||
sessionTTL, err = time.ParseDuration(rawSessionTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return newApp(
|
||||
appConfig{
|
||||
version: env("BETTERNAS_VERSION", "0.1.0-dev"),
|
||||
nextcloudBaseURL: env("NEXTCLOUD_BASE_URL", ""),
|
||||
statePath: env("BETTERNAS_CONTROL_PLANE_STATE_PATH", ".state/control-plane/state.json"),
|
||||
clientToken: clientToken,
|
||||
nodeBootstrapToken: nodeBootstrapToken,
|
||||
davAuthSecret: davAuthSecret,
|
||||
davCredentialTTL: davCredentialTTL,
|
||||
version: env("BETTERNAS_VERSION", "0.1.0-dev"),
|
||||
nextcloudBaseURL: env("NEXTCLOUD_BASE_URL", ""),
|
||||
statePath: env("BETTERNAS_CONTROL_PLANE_STATE_PATH", ".state/control-plane/state.json"),
|
||||
dbPath: env("BETTERNAS_CONTROL_PLANE_DB_PATH", ""),
|
||||
clientToken: env("BETTERNAS_CONTROL_PLANE_CLIENT_TOKEN", ""),
|
||||
nodeBootstrapToken: nodeBootstrapToken,
|
||||
davAuthSecret: davAuthSecret,
|
||||
davCredentialTTL: davCredentialTTL,
|
||||
sessionTTL: sessionTTL,
|
||||
registrationEnabled: env("BETTERNAS_REGISTRATION_ENABLED", "true") == "true",
|
||||
corsOrigin: env("BETTERNAS_CORS_ORIGIN", ""),
|
||||
},
|
||||
startedAt,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -33,6 +33,10 @@ func (a *app) handler() http.Handler {
|
|||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /health", a.handleHealth)
|
||||
mux.HandleFunc("GET /version", a.handleVersion)
|
||||
mux.HandleFunc("POST /api/v1/auth/register", a.handleAuthRegister)
|
||||
mux.HandleFunc("POST /api/v1/auth/login", a.handleAuthLogin)
|
||||
mux.HandleFunc("POST /api/v1/auth/logout", a.handleAuthLogout)
|
||||
mux.HandleFunc("GET /api/v1/auth/me", a.handleAuthMe)
|
||||
mux.HandleFunc("POST /api/v1/nodes/register", a.handleNodeRegister)
|
||||
mux.HandleFunc("POST /api/v1/nodes/{nodeId}/heartbeat", a.handleNodeHeartbeat)
|
||||
mux.HandleFunc("PUT /api/v1/nodes/{nodeId}/exports", a.handleNodeExports)
|
||||
|
|
@ -40,7 +44,12 @@ func (a *app) handler() http.Handler {
|
|||
mux.HandleFunc("POST /api/v1/mount-profiles/issue", a.handleMountProfileIssue)
|
||||
mux.HandleFunc("POST /api/v1/cloud-profiles/issue", a.handleCloudProfileIssue)
|
||||
|
||||
return mux
|
||||
var handler http.Handler = mux
|
||||
if a.config.corsOrigin != "" {
|
||||
handler = corsMiddleware(a.config.corsOrigin, handler)
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
func (a *app) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
||||
|
|
@ -891,14 +900,161 @@ func writeJSON(w http.ResponseWriter, statusCode int, payload any) {
|
|||
}
|
||||
}
|
||||
|
||||
// --- auth handlers ---
|
||||
|
||||
type authRegisterRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type authLoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type authLoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
User user `json:"user"`
|
||||
}
|
||||
|
||||
func (a *app) handleAuthRegister(w http.ResponseWriter, r *http.Request) {
|
||||
if !a.config.registrationEnabled {
|
||||
http.Error(w, "registration is disabled", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
var request authRegisterRequest
|
||||
if err := decodeJSON(w, r, &request); err != nil {
|
||||
writeDecodeError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
username := strings.TrimSpace(request.Username)
|
||||
if len(username) < 3 || len(username) > 64 {
|
||||
http.Error(w, "username must be between 3 and 64 characters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(request.Password) < 8 {
|
||||
http.Error(w, "password must be at least 8 characters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
u, err := a.store.createUser(username, request.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, errUsernameTaken) {
|
||||
http.Error(w, err.Error(), http.StatusConflict)
|
||||
return
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sessionTTL := a.config.sessionTTL
|
||||
if sessionTTL <= 0 {
|
||||
sessionTTL = 720 * time.Hour
|
||||
}
|
||||
token, err := a.store.createSession(u.ID, sessionTTL)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, authLoginResponse{Token: token, User: u})
|
||||
}
|
||||
|
||||
func (a *app) handleAuthLogin(w http.ResponseWriter, r *http.Request) {
|
||||
var request authLoginRequest
|
||||
if err := decodeJSON(w, r, &request); err != nil {
|
||||
writeDecodeError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
u, err := a.store.authenticateUser(strings.TrimSpace(request.Username), request.Password)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid username or password", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
sessionTTL := a.config.sessionTTL
|
||||
if sessionTTL <= 0 {
|
||||
sessionTTL = 720 * time.Hour
|
||||
}
|
||||
token, err := a.store.createSession(u.ID, sessionTTL)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, authLoginResponse{Token: token, User: u})
|
||||
}
|
||||
|
||||
func (a *app) handleAuthLogout(w http.ResponseWriter, r *http.Request) {
|
||||
token, ok := bearerToken(r)
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
_ = a.store.deleteSession(token)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (a *app) handleAuthMe(w http.ResponseWriter, r *http.Request) {
|
||||
token, ok := bearerToken(r)
|
||||
if !ok {
|
||||
writeUnauthorized(w)
|
||||
return
|
||||
}
|
||||
|
||||
u, err := a.store.validateSession(token)
|
||||
if err != nil {
|
||||
writeUnauthorized(w)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, u)
|
||||
}
|
||||
|
||||
// --- CORS ---
|
||||
|
||||
func corsMiddleware(allowedOrigin string, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// --- client auth ---
|
||||
|
||||
func (a *app) requireClientAuth(w http.ResponseWriter, r *http.Request) bool {
|
||||
presentedToken, ok := bearerToken(r)
|
||||
if !ok || !secureStringEquals(a.config.clientToken, presentedToken) {
|
||||
if !ok {
|
||||
writeUnauthorized(w)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
// Session-based auth (SQLite).
|
||||
if _, err := a.store.validateSession(presentedToken); err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fall back to static client token for backwards compatibility.
|
||||
if a.config.clientToken != "" && secureStringEquals(a.config.clientToken, presentedToken) {
|
||||
return true
|
||||
}
|
||||
|
||||
writeUnauthorized(w)
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *app) authorizeNodeRegistration(w http.ResponseWriter, r *http.Request, machineID string) bool {
|
||||
|
|
|
|||
598
apps/control-plane/cmd/control-plane/sqlite_store.go
Normal file
598
apps/control-plane/cmd/control-plane/sqlite_store.go
Normal file
|
|
@ -0,0 +1,598 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
var (
|
||||
errUsernameTaken = errors.New("username already taken")
|
||||
errInvalidLogin = errors.New("invalid username or password")
|
||||
errSessionExpired = errors.New("session expired or invalid")
|
||||
)
|
||||
|
||||
const sqliteSchema = `
|
||||
CREATE TABLE IF NOT EXISTS ordinals (
|
||||
name TEXT PRIMARY KEY,
|
||||
value INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
INSERT OR IGNORE INTO ordinals (name, value) VALUES ('node', 0), ('export', 0);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS nodes (
|
||||
id TEXT PRIMARY KEY,
|
||||
machine_id TEXT NOT NULL UNIQUE,
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
agent_version TEXT NOT NULL DEFAULT '',
|
||||
status TEXT NOT NULL DEFAULT 'online',
|
||||
last_seen_at TEXT,
|
||||
direct_address TEXT,
|
||||
relay_address TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_tokens (
|
||||
node_id TEXT PRIMARY KEY REFERENCES nodes(id),
|
||||
token_hash TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS exports (
|
||||
id TEXT PRIMARY KEY,
|
||||
node_id TEXT NOT NULL REFERENCES nodes(id),
|
||||
label TEXT NOT NULL DEFAULT '',
|
||||
path TEXT NOT NULL,
|
||||
mount_path TEXT NOT NULL DEFAULT '',
|
||||
capacity_bytes INTEGER,
|
||||
UNIQUE(node_id, path)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS export_protocols (
|
||||
export_id TEXT NOT NULL REFERENCES exports(id) ON DELETE CASCADE,
|
||||
protocol TEXT NOT NULL,
|
||||
PRIMARY KEY (export_id, protocol)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS export_tags (
|
||||
export_id TEXT NOT NULL REFERENCES exports(id) ON DELETE CASCADE,
|
||||
tag TEXT NOT NULL,
|
||||
PRIMARY KEY (export_id, tag)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE COLLATE NOCASE,
|
||||
password_hash TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
token TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
|
||||
expires_at TEXT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
type sqliteStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func newSQLiteStore(dbPath string) (*sqliteStore, error) {
|
||||
dir := filepath.Dir(dbPath)
|
||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
||||
return nil, fmt.Errorf("create database directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite", dbPath+"?_pragma=journal_mode(wal)&_pragma=foreign_keys(1)")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open database %s: %w", dbPath, err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(sqliteSchema); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("initialize database schema: %w", err)
|
||||
}
|
||||
|
||||
return &sqliteStore{db: db}, nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) nextOrdinal(tx *sql.Tx, name string) (int, error) {
|
||||
var value int
|
||||
err := tx.QueryRow("UPDATE ordinals SET value = value + 1 WHERE name = ? RETURNING value", name).Scan(&value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("next ordinal %q: %w", name, err)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func ordinalToNodeID(ordinal int) string {
|
||||
if ordinal == 1 {
|
||||
return "dev-node"
|
||||
}
|
||||
return fmt.Sprintf("dev-node-%d", ordinal)
|
||||
}
|
||||
|
||||
func ordinalToExportID(ordinal int) string {
|
||||
if ordinal == 1 {
|
||||
return "dev-export"
|
||||
}
|
||||
return fmt.Sprintf("dev-export-%d", ordinal)
|
||||
}
|
||||
|
||||
func (s *sqliteStore) registerNode(request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return nodeRegistrationResult{}, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Check if machine already registered.
|
||||
var nodeID string
|
||||
err = tx.QueryRow("SELECT id FROM nodes WHERE machine_id = ?", request.MachineID).Scan(&nodeID)
|
||||
if err == sql.ErrNoRows {
|
||||
ordinal, err := s.nextOrdinal(tx, "node")
|
||||
if err != nil {
|
||||
return nodeRegistrationResult{}, err
|
||||
}
|
||||
nodeID = ordinalToNodeID(ordinal)
|
||||
} else if err != nil {
|
||||
return nodeRegistrationResult{}, fmt.Errorf("lookup node by machine_id: %w", err)
|
||||
}
|
||||
|
||||
// Upsert node.
|
||||
_, err = tx.Exec(`
|
||||
INSERT INTO nodes (id, machine_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address)
|
||||
VALUES (?, ?, ?, ?, 'online', ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
display_name = excluded.display_name,
|
||||
agent_version = excluded.agent_version,
|
||||
status = 'online',
|
||||
last_seen_at = excluded.last_seen_at,
|
||||
direct_address = excluded.direct_address,
|
||||
relay_address = excluded.relay_address
|
||||
`, nodeID, request.MachineID, request.DisplayName, request.AgentVersion,
|
||||
registeredAt.UTC().Format(time.RFC3339),
|
||||
nullableString(request.DirectAddress), nullableString(request.RelayAddress))
|
||||
if err != nil {
|
||||
return nodeRegistrationResult{}, fmt.Errorf("upsert node: %w", err)
|
||||
}
|
||||
|
||||
// Issue token if none exists.
|
||||
var issuedNodeToken string
|
||||
var existingHash sql.NullString
|
||||
_ = tx.QueryRow("SELECT token_hash FROM node_tokens WHERE node_id = ?", nodeID).Scan(&existingHash)
|
||||
|
||||
if !existingHash.Valid || strings.TrimSpace(existingHash.String) == "" {
|
||||
nodeToken, err := newOpaqueToken()
|
||||
if err != nil {
|
||||
return nodeRegistrationResult{}, err
|
||||
}
|
||||
_, err = tx.Exec(
|
||||
"INSERT OR REPLACE INTO node_tokens (node_id, token_hash) VALUES (?, ?)",
|
||||
nodeID, hashOpaqueToken(nodeToken))
|
||||
if err != nil {
|
||||
return nodeRegistrationResult{}, fmt.Errorf("store node token: %w", err)
|
||||
}
|
||||
issuedNodeToken = nodeToken
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nodeRegistrationResult{}, fmt.Errorf("commit registration: %w", err)
|
||||
}
|
||||
|
||||
node, _ := s.nodeByID(nodeID)
|
||||
return nodeRegistrationResult{
|
||||
Node: node,
|
||||
IssuedNodeToken: issuedNodeToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) upsertExports(nodeID string, request nodeExportsRequest) ([]storageExport, error) {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Verify node exists.
|
||||
var exists bool
|
||||
err = tx.QueryRow("SELECT 1 FROM nodes WHERE id = ?", nodeID).Scan(&exists)
|
||||
if err != nil {
|
||||
return nil, errNodeNotFound
|
||||
}
|
||||
|
||||
// Collect current export IDs for this node (by path).
|
||||
currentExports := make(map[string]string) // path -> exportID
|
||||
rows, err := tx.Query("SELECT id, path FROM exports WHERE node_id = ?", nodeID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query current exports: %w", err)
|
||||
}
|
||||
for rows.Next() {
|
||||
var id, path string
|
||||
if err := rows.Scan(&id, &path); err != nil {
|
||||
rows.Close()
|
||||
return nil, fmt.Errorf("scan current export: %w", err)
|
||||
}
|
||||
currentExports[path] = id
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
keepPaths := make(map[string]struct{}, len(request.Exports))
|
||||
for _, input := range request.Exports {
|
||||
exportID, exists := currentExports[input.Path]
|
||||
if !exists {
|
||||
ordinal, err := s.nextOrdinal(tx, "export")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exportID = ordinalToExportID(ordinal)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
INSERT INTO exports (id, node_id, label, path, mount_path, capacity_bytes)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
label = excluded.label,
|
||||
mount_path = excluded.mount_path,
|
||||
capacity_bytes = excluded.capacity_bytes
|
||||
`, exportID, nodeID, input.Label, input.Path, input.MountPath, nullableInt64(input.CapacityBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upsert export %q: %w", input.Path, err)
|
||||
}
|
||||
|
||||
// Replace protocols.
|
||||
if _, err := tx.Exec("DELETE FROM export_protocols WHERE export_id = ?", exportID); err != nil {
|
||||
return nil, fmt.Errorf("clear export protocols: %w", err)
|
||||
}
|
||||
for _, protocol := range input.Protocols {
|
||||
if _, err := tx.Exec("INSERT INTO export_protocols (export_id, protocol) VALUES (?, ?)", exportID, protocol); err != nil {
|
||||
return nil, fmt.Errorf("insert export protocol: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Replace tags.
|
||||
if _, err := tx.Exec("DELETE FROM export_tags WHERE export_id = ?", exportID); err != nil {
|
||||
return nil, fmt.Errorf("clear export tags: %w", err)
|
||||
}
|
||||
for _, tag := range input.Tags {
|
||||
if _, err := tx.Exec("INSERT INTO export_tags (export_id, tag) VALUES (?, ?)", exportID, tag); err != nil {
|
||||
return nil, fmt.Errorf("insert export tag: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
keepPaths[input.Path] = struct{}{}
|
||||
}
|
||||
|
||||
// Remove exports not in the input.
|
||||
for path, exportID := range currentExports {
|
||||
if _, keep := keepPaths[path]; !keep {
|
||||
if _, err := tx.Exec("DELETE FROM exports WHERE id = ?", exportID); err != nil {
|
||||
return nil, fmt.Errorf("delete stale export %q: %w", exportID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("commit exports: %w", err)
|
||||
}
|
||||
|
||||
return s.listExportsForNode(nodeID), nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) recordHeartbeat(nodeID string, request nodeHeartbeatRequest) error {
|
||||
result, err := s.db.Exec(
|
||||
"UPDATE nodes SET status = ?, last_seen_at = ? WHERE id = ?",
|
||||
request.Status, request.LastSeenAt, nodeID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update heartbeat: %w", err)
|
||||
}
|
||||
affected, _ := result.RowsAffected()
|
||||
if affected == 0 {
|
||||
return errNodeNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) listExports() []storageExport {
|
||||
rows, err := s.db.Query("SELECT id, node_id, label, path, mount_path, capacity_bytes FROM exports ORDER BY id")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var exports []storageExport
|
||||
for rows.Next() {
|
||||
e := s.scanExport(rows)
|
||||
if e.ID != "" {
|
||||
exports = append(exports, e)
|
||||
}
|
||||
}
|
||||
if exports == nil {
|
||||
exports = []storageExport{}
|
||||
}
|
||||
|
||||
// Load protocols and tags for each export.
|
||||
for i := range exports {
|
||||
exports[i].Protocols = s.loadExportProtocols(exports[i].ID)
|
||||
exports[i].Tags = s.loadExportTags(exports[i].ID)
|
||||
}
|
||||
|
||||
return exports
|
||||
}
|
||||
|
||||
func (s *sqliteStore) listExportsForNode(nodeID string) []storageExport {
|
||||
rows, err := s.db.Query("SELECT id, node_id, label, path, mount_path, capacity_bytes FROM exports WHERE node_id = ? ORDER BY id", nodeID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var exports []storageExport
|
||||
for rows.Next() {
|
||||
e := s.scanExport(rows)
|
||||
if e.ID != "" {
|
||||
exports = append(exports, e)
|
||||
}
|
||||
}
|
||||
if exports == nil {
|
||||
exports = []storageExport{}
|
||||
}
|
||||
|
||||
for i := range exports {
|
||||
exports[i].Protocols = s.loadExportProtocols(exports[i].ID)
|
||||
exports[i].Tags = s.loadExportTags(exports[i].ID)
|
||||
}
|
||||
|
||||
sort.Slice(exports, func(i, j int) bool { return exports[i].ID < exports[j].ID })
|
||||
return exports
|
||||
}
|
||||
|
||||
func (s *sqliteStore) exportContext(exportID string) (exportContext, bool) {
|
||||
var e storageExport
|
||||
var capacityBytes sql.NullInt64
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, node_id, label, path, mount_path, capacity_bytes FROM exports WHERE id = ?",
|
||||
exportID).Scan(&e.ID, &e.NasNodeID, &e.Label, &e.Path, &e.MountPath, &capacityBytes)
|
||||
if err != nil {
|
||||
return exportContext{}, false
|
||||
}
|
||||
if capacityBytes.Valid {
|
||||
e.CapacityBytes = &capacityBytes.Int64
|
||||
}
|
||||
e.Protocols = s.loadExportProtocols(e.ID)
|
||||
e.Tags = s.loadExportTags(e.ID)
|
||||
|
||||
node, ok := s.nodeByID(e.NasNodeID)
|
||||
if !ok {
|
||||
return exportContext{}, false
|
||||
}
|
||||
|
||||
return exportContext{export: e, node: node}, true
|
||||
}
|
||||
|
||||
func (s *sqliteStore) nodeByID(nodeID string) (nasNode, bool) {
|
||||
var n nasNode
|
||||
var directAddr, relayAddr sql.NullString
|
||||
var lastSeenAt sql.NullString
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, machine_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address FROM nodes WHERE id = ?",
|
||||
nodeID).Scan(&n.ID, &n.MachineID, &n.DisplayName, &n.AgentVersion, &n.Status, &lastSeenAt, &directAddr, &relayAddr)
|
||||
if err != nil {
|
||||
return nasNode{}, false
|
||||
}
|
||||
if lastSeenAt.Valid {
|
||||
n.LastSeenAt = lastSeenAt.String
|
||||
}
|
||||
if directAddr.Valid {
|
||||
n.DirectAddress = &directAddr.String
|
||||
}
|
||||
if relayAddr.Valid {
|
||||
n.RelayAddress = &relayAddr.String
|
||||
}
|
||||
return n, true
|
||||
}
|
||||
|
||||
func (s *sqliteStore) nodeAuthByMachineID(machineID string) (nodeAuthState, bool) {
|
||||
var state nodeAuthState
|
||||
var tokenHash sql.NullString
|
||||
err := s.db.QueryRow(`
|
||||
SELECT n.id, nt.token_hash
|
||||
FROM nodes n
|
||||
LEFT JOIN node_tokens nt ON nt.node_id = n.id
|
||||
WHERE n.machine_id = ?
|
||||
`, machineID).Scan(&state.NodeID, &tokenHash)
|
||||
if err != nil {
|
||||
return nodeAuthState{}, false
|
||||
}
|
||||
if tokenHash.Valid {
|
||||
state.TokenHash = tokenHash.String
|
||||
}
|
||||
return state, true
|
||||
}
|
||||
|
||||
func (s *sqliteStore) nodeAuthByID(nodeID string) (nodeAuthState, bool) {
|
||||
var state nodeAuthState
|
||||
var tokenHash sql.NullString
|
||||
err := s.db.QueryRow(`
|
||||
SELECT n.id, nt.token_hash
|
||||
FROM nodes n
|
||||
LEFT JOIN node_tokens nt ON nt.node_id = n.id
|
||||
WHERE n.id = ?
|
||||
`, nodeID).Scan(&state.NodeID, &tokenHash)
|
||||
if err != nil {
|
||||
return nodeAuthState{}, false
|
||||
}
|
||||
if tokenHash.Valid {
|
||||
state.TokenHash = tokenHash.String
|
||||
}
|
||||
return state, true
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func (s *sqliteStore) scanExport(rows *sql.Rows) storageExport {
|
||||
var e storageExport
|
||||
var capacityBytes sql.NullInt64
|
||||
if err := rows.Scan(&e.ID, &e.NasNodeID, &e.Label, &e.Path, &e.MountPath, &capacityBytes); err != nil {
|
||||
return storageExport{}
|
||||
}
|
||||
if capacityBytes.Valid {
|
||||
e.CapacityBytes = &capacityBytes.Int64
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func (s *sqliteStore) loadExportProtocols(exportID string) []string {
|
||||
rows, err := s.db.Query("SELECT protocol FROM export_protocols WHERE export_id = ? ORDER BY protocol", exportID)
|
||||
if err != nil {
|
||||
return []string{}
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var protocols []string
|
||||
for rows.Next() {
|
||||
var p string
|
||||
if err := rows.Scan(&p); err == nil {
|
||||
protocols = append(protocols, p)
|
||||
}
|
||||
}
|
||||
if protocols == nil {
|
||||
return []string{}
|
||||
}
|
||||
return protocols
|
||||
}
|
||||
|
||||
func (s *sqliteStore) loadExportTags(exportID string) []string {
|
||||
rows, err := s.db.Query("SELECT tag FROM export_tags WHERE export_id = ? ORDER BY tag", exportID)
|
||||
if err != nil {
|
||||
return []string{}
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tags []string
|
||||
for rows.Next() {
|
||||
var t string
|
||||
if err := rows.Scan(&t); err == nil {
|
||||
tags = append(tags, t)
|
||||
}
|
||||
}
|
||||
if tags == nil {
|
||||
return []string{}
|
||||
}
|
||||
return tags
|
||||
}
|
||||
|
||||
func nullableString(p *string) sql.NullString {
|
||||
if p == nil {
|
||||
return sql.NullString{}
|
||||
}
|
||||
return sql.NullString{String: *p, Valid: true}
|
||||
}
|
||||
|
||||
func nullableInt64(p *int64) sql.NullInt64 {
|
||||
if p == nil {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
return sql.NullInt64{Int64: *p, Valid: true}
|
||||
}
|
||||
|
||||
// --- user auth ---
|
||||
|
||||
func (s *sqliteStore) createUser(username string, password string) (user, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return user{}, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
id, err := newSessionToken()
|
||||
if err != nil {
|
||||
return user{}, err
|
||||
}
|
||||
|
||||
var u user
|
||||
err = s.db.QueryRow(`
|
||||
INSERT INTO users (id, username, password_hash) VALUES (?, ?, ?)
|
||||
RETURNING id, username, created_at
|
||||
`, id, username, string(hash)).Scan(&u.ID, &u.Username, &u.CreatedAt)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint") {
|
||||
return user{}, errUsernameTaken
|
||||
}
|
||||
return user{}, fmt.Errorf("create user: %w", err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) authenticateUser(username string, password string) (user, error) {
|
||||
var u user
|
||||
var passwordHash string
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, username, password_hash, created_at FROM users WHERE username = ?",
|
||||
username).Scan(&u.ID, &u.Username, &passwordHash, &u.CreatedAt)
|
||||
if err != nil {
|
||||
return user{}, errInvalidLogin
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil {
|
||||
return user{}, errInvalidLogin
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) createSession(userID string, ttl time.Duration) (string, error) {
|
||||
token, err := newSessionToken()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
expiresAt := time.Now().UTC().Add(ttl).Format(time.RFC3339)
|
||||
_, err = s.db.Exec(
|
||||
"INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)",
|
||||
token, userID, expiresAt)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
|
||||
// Clean up expired sessions opportunistically.
|
||||
_, _ = s.db.Exec("DELETE FROM sessions WHERE expires_at < ?", time.Now().UTC().Format(time.RFC3339))
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) validateSession(token string) (user, error) {
|
||||
var u user
|
||||
err := s.db.QueryRow(`
|
||||
SELECT u.id, u.username, u.created_at
|
||||
FROM sessions s
|
||||
JOIN users u ON u.id = s.user_id
|
||||
WHERE s.token = ? AND s.expires_at > ?
|
||||
`, token, time.Now().UTC().Format(time.RFC3339)).Scan(&u.ID, &u.Username, &u.CreatedAt)
|
||||
if err != nil {
|
||||
return user{}, errSessionExpired
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) deleteSession(token string) error {
|
||||
_, err := s.db.Exec("DELETE FROM sessions WHERE token = ?", token)
|
||||
return err
|
||||
}
|
||||
|
||||
func newSessionToken() (string, error) {
|
||||
raw := make([]byte, 32)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
return "", fmt.Errorf("generate session token: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(raw), nil
|
||||
}
|
||||
297
apps/control-plane/cmd/control-plane/sqlite_store_test.go
Normal file
297
apps/control-plane/cmd/control-plane/sqlite_store_test.go
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newTestSQLiteApp(t *testing.T, config appConfig) (*app, *httptest.Server) {
|
||||
t.Helper()
|
||||
|
||||
if config.dbPath == "" {
|
||||
config.dbPath = filepath.Join(t.TempDir(), "test.db")
|
||||
}
|
||||
|
||||
if config.version == "" {
|
||||
config.version = "test-version"
|
||||
}
|
||||
if config.clientToken == "" {
|
||||
config.clientToken = testClientToken
|
||||
}
|
||||
if config.nodeBootstrapToken == "" {
|
||||
config.nodeBootstrapToken = testNodeBootstrapToken
|
||||
}
|
||||
if config.davAuthSecret == "" {
|
||||
config.davAuthSecret = "test-dav-auth-secret"
|
||||
}
|
||||
if config.davCredentialTTL == 0 {
|
||||
config.davCredentialTTL = time.Hour
|
||||
}
|
||||
|
||||
app, err := newApp(config, testControlPlaneNow)
|
||||
if err != nil {
|
||||
t.Fatalf("new app: %v", err)
|
||||
}
|
||||
app.now = func() time.Time { return testControlPlaneNow }
|
||||
|
||||
server := httptest.NewServer(app.handler())
|
||||
return app, server
|
||||
}
|
||||
|
||||
func TestSQLiteHealthAndVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{
|
||||
version: "test-version",
|
||||
nextcloudBaseURL: "http://nextcloud.test",
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
health := getJSON[controlPlaneHealthResponse](t, server.Client(), server.URL+"/health")
|
||||
if health.Status != "ok" {
|
||||
t.Fatalf("expected status ok, got %q", health.Status)
|
||||
}
|
||||
|
||||
exports := getJSONAuth[[]storageExport](t, server.Client(), testClientToken, server.URL+"/api/v1/exports")
|
||||
if len(exports) != 0 {
|
||||
t.Fatalf("expected no exports, got %d", len(exports))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLiteRegistrationAndExports(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{
|
||||
version: "test-version",
|
||||
nextcloudBaseURL: "http://nextcloud.test",
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
directAddress := "http://nas.local:8090"
|
||||
registration := registerNode(t, server.Client(), server.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
|
||||
MachineID: "machine-1",
|
||||
DisplayName: "Primary NAS",
|
||||
AgentVersion: "1.2.3",
|
||||
DirectAddress: &directAddress,
|
||||
RelayAddress: nil,
|
||||
})
|
||||
if registration.NodeToken == "" {
|
||||
t.Fatal("expected node registration to return a node token")
|
||||
}
|
||||
if registration.Node.ID != "dev-node" {
|
||||
t.Fatalf("expected node ID %q, got %q", "dev-node", registration.Node.ID)
|
||||
}
|
||||
|
||||
syncedExports := syncNodeExports(t, server.Client(), registration.NodeToken, server.URL+"/api/v1/nodes/"+registration.Node.ID+"/exports", nodeExportsRequest{
|
||||
Exports: []storageExportInput{
|
||||
{
|
||||
Label: "Docs",
|
||||
Path: "/srv/docs",
|
||||
MountPath: "/dav/docs/",
|
||||
Protocols: []string{"webdav"},
|
||||
CapacityBytes: nil,
|
||||
Tags: []string{"work"},
|
||||
},
|
||||
},
|
||||
})
|
||||
if len(syncedExports) != 1 {
|
||||
t.Fatalf("expected 1 export, got %d", len(syncedExports))
|
||||
}
|
||||
if syncedExports[0].ID != "dev-export" {
|
||||
t.Fatalf("expected export ID %q, got %q", "dev-export", syncedExports[0].ID)
|
||||
}
|
||||
if syncedExports[0].Label != "Docs" {
|
||||
t.Fatalf("expected label %q, got %q", "Docs", syncedExports[0].Label)
|
||||
}
|
||||
|
||||
allExports := getJSONAuth[[]storageExport](t, server.Client(), testClientToken, server.URL+"/api/v1/exports")
|
||||
if len(allExports) != 1 {
|
||||
t.Fatalf("expected 1 export in list, got %d", len(allExports))
|
||||
}
|
||||
|
||||
mount := postJSONAuth[mountProfile](t, server.Client(), testClientToken, server.URL+"/api/v1/mount-profiles/issue", mountProfileRequest{ExportID: "dev-export"})
|
||||
if mount.MountURL != "http://nas.local:8090/dav/docs/" {
|
||||
t.Fatalf("expected mount URL %q, got %q", "http://nas.local:8090/dav/docs/", mount.MountURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLiteReRegistrationKeepsNodeID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{version: "test-version"})
|
||||
defer server.Close()
|
||||
|
||||
directAddress := "http://nas.local:8090"
|
||||
first := registerNode(t, server.Client(), server.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
|
||||
MachineID: "machine-1",
|
||||
DisplayName: "NAS",
|
||||
AgentVersion: "1.0.0",
|
||||
DirectAddress: &directAddress,
|
||||
})
|
||||
|
||||
second := registerNode(t, server.Client(), server.URL+"/api/v1/nodes/register", first.NodeToken, nodeRegistrationRequest{
|
||||
MachineID: "machine-1",
|
||||
DisplayName: "NAS Updated",
|
||||
AgentVersion: "1.0.1",
|
||||
DirectAddress: &directAddress,
|
||||
})
|
||||
|
||||
if second.Node.ID != first.Node.ID {
|
||||
t.Fatalf("expected re-registration to keep node ID %q, got %q", first.Node.ID, second.Node.ID)
|
||||
}
|
||||
if second.NodeToken != "" {
|
||||
t.Fatalf("expected re-registration to not issue new token, got %q", second.NodeToken)
|
||||
}
|
||||
if second.Node.DisplayName != "NAS Updated" {
|
||||
t.Fatalf("expected updated display name, got %q", second.Node.DisplayName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLiteExportSyncRemovesStaleExports(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{version: "test-version"})
|
||||
defer server.Close()
|
||||
|
||||
directAddress := "http://nas.local:8090"
|
||||
reg := registerNode(t, server.Client(), server.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
|
||||
MachineID: "machine-stale",
|
||||
DisplayName: "NAS",
|
||||
AgentVersion: "1.0.0",
|
||||
DirectAddress: &directAddress,
|
||||
})
|
||||
|
||||
syncNodeExports(t, server.Client(), reg.NodeToken, server.URL+"/api/v1/nodes/"+reg.Node.ID+"/exports", nodeExportsRequest{
|
||||
Exports: []storageExportInput{
|
||||
{Label: "A", Path: "/a", MountPath: "/dav/a/", Protocols: []string{"webdav"}, Tags: []string{}},
|
||||
{Label: "B", Path: "/b", MountPath: "/dav/b/", Protocols: []string{"webdav"}, Tags: []string{}},
|
||||
},
|
||||
})
|
||||
|
||||
exports := getJSONAuth[[]storageExport](t, server.Client(), testClientToken, server.URL+"/api/v1/exports")
|
||||
if len(exports) != 2 {
|
||||
t.Fatalf("expected 2 exports, got %d", len(exports))
|
||||
}
|
||||
|
||||
// Sync with only A - B should be removed.
|
||||
syncNodeExports(t, server.Client(), reg.NodeToken, server.URL+"/api/v1/nodes/"+reg.Node.ID+"/exports", nodeExportsRequest{
|
||||
Exports: []storageExportInput{
|
||||
{Label: "A Updated", Path: "/a", MountPath: "/dav/a/", Protocols: []string{"webdav"}, Tags: []string{}},
|
||||
},
|
||||
})
|
||||
|
||||
exports = getJSONAuth[[]storageExport](t, server.Client(), testClientToken, server.URL+"/api/v1/exports")
|
||||
if len(exports) != 1 {
|
||||
t.Fatalf("expected 1 export after stale removal, got %d", len(exports))
|
||||
}
|
||||
if exports[0].Label != "A Updated" {
|
||||
t.Fatalf("expected updated label, got %q", exports[0].Label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLiteHeartbeat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, server := newTestSQLiteApp(t, appConfig{version: "test-version"})
|
||||
defer server.Close()
|
||||
_ = app
|
||||
|
||||
directAddress := "http://nas.local:8090"
|
||||
reg := registerNode(t, server.Client(), server.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
|
||||
MachineID: "machine-hb",
|
||||
DisplayName: "NAS",
|
||||
AgentVersion: "1.0.0",
|
||||
DirectAddress: &directAddress,
|
||||
})
|
||||
|
||||
postJSONAuthStatus(t, server.Client(), reg.NodeToken, server.URL+"/api/v1/nodes/"+reg.Node.ID+"/heartbeat", nodeHeartbeatRequest{
|
||||
NodeID: reg.Node.ID,
|
||||
Status: "online",
|
||||
LastSeenAt: "2025-06-01T12:00:00Z",
|
||||
}, http.StatusNoContent)
|
||||
|
||||
node, ok := app.store.nodeByID(reg.Node.ID)
|
||||
if !ok {
|
||||
t.Fatal("expected node to exist after heartbeat")
|
||||
}
|
||||
if node.LastSeenAt != "2025-06-01T12:00:00Z" {
|
||||
t.Fatalf("expected updated lastSeenAt, got %q", node.LastSeenAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLitePersistsAcrossRestart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "persist.db")
|
||||
directAddress := "http://nas.local:8090"
|
||||
|
||||
_, firstServer := newTestSQLiteApp(t, appConfig{
|
||||
version: "test-version",
|
||||
dbPath: dbPath,
|
||||
})
|
||||
registration := registerNode(t, firstServer.Client(), firstServer.URL+"/api/v1/nodes/register", testNodeBootstrapToken, nodeRegistrationRequest{
|
||||
MachineID: "machine-persist",
|
||||
DisplayName: "Persisted NAS",
|
||||
AgentVersion: "1.2.3",
|
||||
DirectAddress: &directAddress,
|
||||
})
|
||||
syncNodeExports(t, firstServer.Client(), registration.NodeToken, firstServer.URL+"/api/v1/nodes/"+registration.Node.ID+"/exports", nodeExportsRequest{
|
||||
Exports: []storageExportInput{{
|
||||
Label: "Docs",
|
||||
Path: "/srv/docs",
|
||||
MountPath: "/dav/persisted/",
|
||||
Protocols: []string{"webdav"},
|
||||
Tags: []string{"work"},
|
||||
}},
|
||||
})
|
||||
firstServer.Close()
|
||||
|
||||
// Restart with same DB path.
|
||||
_, secondServer := newTestSQLiteApp(t, appConfig{
|
||||
version: "test-version",
|
||||
dbPath: dbPath,
|
||||
})
|
||||
defer secondServer.Close()
|
||||
|
||||
exports := getJSONAuth[[]storageExport](t, secondServer.Client(), testClientToken, secondServer.URL+"/api/v1/exports")
|
||||
if len(exports) != 1 {
|
||||
t.Fatalf("expected persisted export after restart, got %d", len(exports))
|
||||
}
|
||||
if exports[0].ID != "dev-export" {
|
||||
t.Fatalf("expected persisted export ID %q, got %q", "dev-export", exports[0].ID)
|
||||
}
|
||||
if exports[0].MountPath != "/dav/persisted/" {
|
||||
t.Fatalf("expected persisted mountPath %q, got %q", "/dav/persisted/", exports[0].MountPath)
|
||||
}
|
||||
if len(exports[0].Tags) != 1 || exports[0].Tags[0] != "work" {
|
||||
t.Fatalf("expected persisted tags [work], got %v", exports[0].Tags)
|
||||
}
|
||||
|
||||
// Re-register with the original node token.
|
||||
reReg := registerNode(t, secondServer.Client(), secondServer.URL+"/api/v1/nodes/register", registration.NodeToken, nodeRegistrationRequest{
|
||||
MachineID: "machine-persist",
|
||||
DisplayName: "Persisted NAS Updated",
|
||||
AgentVersion: "1.2.4",
|
||||
DirectAddress: &directAddress,
|
||||
})
|
||||
if reReg.Node.ID != registration.Node.ID {
|
||||
t.Fatalf("expected persisted node ID %q, got %q", registration.Node.ID, reReg.Node.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLiteAuthEnforcement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, server := newTestSQLiteApp(t, appConfig{version: "test-version"})
|
||||
defer server.Close()
|
||||
|
||||
getStatusWithAuth(t, server.Client(), "", server.URL+"/api/v1/exports", http.StatusUnauthorized)
|
||||
getStatusWithAuth(t, server.Client(), "wrong-token", server.URL+"/api/v1/exports", http.StatusUnauthorized)
|
||||
|
||||
postJSONAuthStatus(t, server.Client(), testClientToken, server.URL+"/api/v1/mount-profiles/issue", mountProfileRequest{
|
||||
ExportID: "missing-export",
|
||||
}, http.StatusNotFound)
|
||||
}
|
||||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
|
@ -483,6 +484,30 @@ func copyStorageExport(export storageExport) storageExport {
|
|||
}
|
||||
}
|
||||
|
||||
// --- user auth stubs (memory store does not support user auth) ---
|
||||
|
||||
var errAuthNotSupported = errors.New("user auth requires SQLite database (set BETTERNAS_CONTROL_PLANE_DB_PATH)")
|
||||
|
||||
func (s *memoryStore) createUser(_ string, _ string) (user, error) {
|
||||
return user{}, errAuthNotSupported
|
||||
}
|
||||
|
||||
func (s *memoryStore) authenticateUser(_ string, _ string) (user, error) {
|
||||
return user{}, errAuthNotSupported
|
||||
}
|
||||
|
||||
func (s *memoryStore) createSession(_ string, _ time.Duration) (string, error) {
|
||||
return "", errAuthNotSupported
|
||||
}
|
||||
|
||||
func (s *memoryStore) validateSession(_ string) (user, error) {
|
||||
return user{}, errAuthNotSupported
|
||||
}
|
||||
|
||||
func (s *memoryStore) deleteSession(_ string) error {
|
||||
return errAuthNotSupported
|
||||
}
|
||||
|
||||
func newOpaqueToken() (string, error) {
|
||||
raw := make([]byte, 32)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
|
|
|
|||
23
apps/control-plane/cmd/control-plane/store_iface.go
Normal file
23
apps/control-plane/cmd/control-plane/store_iface.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package main
|
||||
|
||||
import "time"
|
||||
|
||||
// store defines the persistence interface for the control-plane.
|
||||
type store interface {
|
||||
// Node management
|
||||
registerNode(request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error)
|
||||
upsertExports(nodeID string, request nodeExportsRequest) ([]storageExport, error)
|
||||
recordHeartbeat(nodeID string, request nodeHeartbeatRequest) error
|
||||
listExports() []storageExport
|
||||
exportContext(exportID string) (exportContext, bool)
|
||||
nodeByID(nodeID string) (nasNode, bool)
|
||||
nodeAuthByMachineID(machineID string) (nodeAuthState, bool)
|
||||
nodeAuthByID(nodeID string) (nodeAuthState, bool)
|
||||
|
||||
// User auth
|
||||
createUser(username string, password string) (user, error)
|
||||
authenticateUser(username string, password string) (user, error)
|
||||
createSession(userID string, ttl time.Duration) (string, error)
|
||||
validateSession(token string) (user, error)
|
||||
deleteSession(token string) error
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue