mirror of
https://github.com/harivansh-afk/betterNAS.git
synced 2026-04-18 03:00:38 +00:00
Stabilize the node agent runtime loop.
Keep the NAS-side runtime bounded to the configured export path, make WebDAV and registration behavior env-driven, and add runtime coverage so the first storage loop can be verified locally. Generated with [Devin](https://cli.devin.ai/docs) Co-Authored-By: Devin <158243242+devin-ai-integration[bot]@users.noreply.github.com>
This commit is contained in:
parent
a7f85f4871
commit
273af4b0ab
14 changed files with 3294 additions and 36 deletions
190
apps/node-agent/internal/nodeagent/app.go
Normal file
190
apps/node-agent/internal/nodeagent/app.go
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
package nodeagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/webdav"
|
||||
)
|
||||
|
||||
const davPrefix = "/dav/"
|
||||
|
||||
type App struct {
|
||||
cfg Config
|
||||
davFS *exportFileSystem
|
||||
logger *log.Logger
|
||||
server *http.Server
|
||||
registration *registrationLoop
|
||||
}
|
||||
|
||||
func New(cfg Config, logger *log.Logger) (*App, error) {
|
||||
if logger == nil {
|
||||
logger = log.Default()
|
||||
}
|
||||
|
||||
if err := validateRuntimeConfig(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ensureExportPath(cfg.ExportPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
davFS, err := newExportFileSystem(cfg.ExportPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
app := &App{
|
||||
cfg: cfg,
|
||||
davFS: davFS,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/health", app.handleHealth)
|
||||
mux.HandleFunc("/dav", handleDAVRedirect)
|
||||
mux.Handle(davPrefix, http.Handler(&webdav.Handler{
|
||||
Prefix: davPrefix,
|
||||
FileSystem: app.davFS,
|
||||
LockSystem: webdav.NewMemLS(),
|
||||
}))
|
||||
mux.HandleFunc("/", http.NotFound)
|
||||
|
||||
app.server = &http.Server{
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
if cfg.RegisterEnabled {
|
||||
app.registration = newRegistrationLoop(cfg, logger)
|
||||
}
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
||||
func (a *App) ListenAndServe(ctx context.Context) error {
|
||||
listener, err := net.Listen("tcp", a.cfg.ListenAddress)
|
||||
if err != nil {
|
||||
a.closeDAVFS()
|
||||
return fmt.Errorf("listen on %s: %w", a.cfg.ListenAddress, err)
|
||||
}
|
||||
|
||||
a.logger.Printf("betterNAS node agent serving %s at %s on %s", a.cfg.ExportPath, davPrefix, listener.Addr())
|
||||
if strings.TrimSpace(a.cfg.ListenAddress) == defaultListenAddress(a.cfg.Port) {
|
||||
a.logger.Printf("betterNAS node agent using loopback-only listen address %s by default", a.cfg.ListenAddress)
|
||||
}
|
||||
if a.registration != nil {
|
||||
a.logger.Printf("betterNAS node agent control-plane sync enabled for %s", a.cfg.ControlPlaneURL)
|
||||
if strings.TrimSpace(a.cfg.DirectAddress) == "" {
|
||||
a.logger.Printf("betterNAS node agent is not advertising a direct address; set BETTERNAS_NODE_DIRECT_ADDRESS if clients should mount this listener directly")
|
||||
}
|
||||
}
|
||||
|
||||
return a.Serve(ctx, listener)
|
||||
}
|
||||
|
||||
func (a *App) Serve(ctx context.Context, listener net.Listener) error {
|
||||
defer a.closeDAVFS()
|
||||
|
||||
serverErrors := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
serverErrors <- a.server.Serve(listener)
|
||||
}()
|
||||
|
||||
if a.registration != nil {
|
||||
go a.registration.Run(ctx)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-serverErrors:
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
}
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := a.server.Shutdown(shutdownCtx); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("shutdown node-agent server: %w", err)
|
||||
}
|
||||
|
||||
err := <-serverErrors
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *App) closeDAVFS() {
|
||||
if a.davFS == nil {
|
||||
return
|
||||
}
|
||||
|
||||
davFS := a.davFS
|
||||
a.davFS = nil
|
||||
|
||||
if err := davFS.Close(); err != nil {
|
||||
a.logger.Printf("betterNAS node agent failed to close export root %s: %v", a.cfg.ExportPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||
w.Header().Set("Allow", "GET, HEAD")
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
if r.Method != http.MethodHead {
|
||||
_, _ = io.WriteString(w, "ok\n")
|
||||
}
|
||||
}
|
||||
|
||||
func handleDAVRedirect(w http.ResponseWriter, r *http.Request) {
|
||||
location := davPrefix
|
||||
if rawQuery := strings.TrimSpace(r.URL.RawQuery); rawQuery != "" {
|
||||
location += "?" + rawQuery
|
||||
}
|
||||
|
||||
w.Header().Set("Location", location)
|
||||
w.WriteHeader(http.StatusPermanentRedirect)
|
||||
}
|
||||
|
||||
func ensureExportPath(exportPath string) error {
|
||||
trimmedPath := strings.TrimSpace(exportPath)
|
||||
if trimmedPath == "" {
|
||||
return fmt.Errorf("export path is required")
|
||||
}
|
||||
|
||||
info, err := os.Stat(trimmedPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("export path %s does not exist", trimmedPath)
|
||||
}
|
||||
|
||||
return fmt.Errorf("stat export path %s: %w", trimmedPath, err)
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("export path %s is not a directory", trimmedPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
801
apps/node-agent/internal/nodeagent/app_integration_test.go
Normal file
801
apps/node-agent/internal/nodeagent/app_integration_test.go
Normal file
|
|
@ -0,0 +1,801 @@
|
|||
package nodeagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const testControlPlaneToken = "test-control-plane-token"
|
||||
|
||||
func TestAppServesWebDAVFromConfiguredExportPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exportPath := filepath.Join(t.TempDir(), "export")
|
||||
|
||||
baseURL, stop := startTestApp(t, Config{
|
||||
Port: "0",
|
||||
ExportPath: exportPath,
|
||||
MachineID: "nas-1",
|
||||
DisplayName: "NAS 1",
|
||||
AgentVersion: "test-version",
|
||||
ExportLabel: "integration",
|
||||
ExportTags: []string{"finder"},
|
||||
HeartbeatInterval: time.Second,
|
||||
})
|
||||
defer stop()
|
||||
|
||||
healthResponse, err := http.Get(baseURL + "/health")
|
||||
if err != nil {
|
||||
t.Fatalf("get health: %v", err)
|
||||
}
|
||||
defer healthResponse.Body.Close()
|
||||
|
||||
healthBody, err := io.ReadAll(healthResponse.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read health body: %v", err)
|
||||
}
|
||||
|
||||
if healthResponse.StatusCode != http.StatusOK {
|
||||
t.Fatalf("health status = %d, want 200", healthResponse.StatusCode)
|
||||
}
|
||||
|
||||
if string(healthBody) != "ok\n" {
|
||||
t.Fatalf("health body = %q, want ok", string(healthBody))
|
||||
}
|
||||
|
||||
headRequest, err := http.NewRequest(http.MethodHead, baseURL+"/health", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("build health head request: %v", err)
|
||||
}
|
||||
|
||||
headResponse, err := http.DefaultClient.Do(headRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("head health: %v", err)
|
||||
}
|
||||
defer headResponse.Body.Close()
|
||||
|
||||
headBody, err := io.ReadAll(headResponse.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read head response body: %v", err)
|
||||
}
|
||||
|
||||
if len(headBody) != 0 {
|
||||
t.Fatalf("head body length = %d, want 0", len(headBody))
|
||||
}
|
||||
|
||||
redirectClient := &http.Client{
|
||||
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
redirectResponse, err := redirectClient.Get(baseURL + "/dav?session=abc")
|
||||
if err != nil {
|
||||
t.Fatalf("get /dav: %v", err)
|
||||
}
|
||||
defer redirectResponse.Body.Close()
|
||||
|
||||
if redirectResponse.StatusCode != http.StatusPermanentRedirect {
|
||||
t.Fatalf("redirect status = %d, want 308", redirectResponse.StatusCode)
|
||||
}
|
||||
|
||||
if redirectResponse.Header.Get("Location") != davPrefix+"?session=abc" {
|
||||
t.Fatalf("redirect location = %q, want %q", redirectResponse.Header.Get("Location"), davPrefix+"?session=abc")
|
||||
}
|
||||
|
||||
optionsRequest := mustRequest(t, http.MethodOptions, baseURL+davPrefix, nil)
|
||||
optionsResponse, err := http.DefaultClient.Do(optionsRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("options /dav/: %v", err)
|
||||
}
|
||||
defer optionsResponse.Body.Close()
|
||||
|
||||
if optionsResponse.StatusCode != http.StatusOK {
|
||||
t.Fatalf("options status = %d, want 200", optionsResponse.StatusCode)
|
||||
}
|
||||
|
||||
if !strings.Contains(optionsResponse.Header.Get("Dav"), "1") {
|
||||
t.Fatalf("dav header = %q, want DAV support", optionsResponse.Header.Get("Dav"))
|
||||
}
|
||||
|
||||
putRequest := mustRequest(t, http.MethodPut, baseURL+"/dav/notes.txt", strings.NewReader("hello from webdav"))
|
||||
putResponse, err := http.DefaultClient.Do(putRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("put file: %v", err)
|
||||
}
|
||||
defer putResponse.Body.Close()
|
||||
|
||||
if putResponse.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("put status = %d, want 201", putResponse.StatusCode)
|
||||
}
|
||||
|
||||
savedBytes, err := os.ReadFile(filepath.Join(exportPath, "notes.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("read saved file: %v", err)
|
||||
}
|
||||
|
||||
if string(savedBytes) != "hello from webdav" {
|
||||
t.Fatalf("saved file = %q, want file content", string(savedBytes))
|
||||
}
|
||||
|
||||
mkcolRequest, err := http.NewRequest("MKCOL", baseURL+"/dav/docs", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("build mkcol request: %v", err)
|
||||
}
|
||||
|
||||
mkcolResponse, err := http.DefaultClient.Do(mkcolRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("mkcol docs: %v", err)
|
||||
}
|
||||
defer mkcolResponse.Body.Close()
|
||||
|
||||
if mkcolResponse.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("mkcol status = %d, want 201", mkcolResponse.StatusCode)
|
||||
}
|
||||
|
||||
propfindRequest, err := http.NewRequest("PROPFIND", baseURL+"/dav/docs", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("build propfind request: %v", err)
|
||||
}
|
||||
propfindRequest.Header.Set("Depth", "0")
|
||||
|
||||
propfindResponse, err := http.DefaultClient.Do(propfindRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("propfind docs: %v", err)
|
||||
}
|
||||
defer propfindResponse.Body.Close()
|
||||
|
||||
propfindBody, err := io.ReadAll(propfindResponse.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read propfind body: %v", err)
|
||||
}
|
||||
|
||||
if propfindResponse.StatusCode != http.StatusMultiStatus {
|
||||
t.Fatalf("propfind status = %d, want 207", propfindResponse.StatusCode)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(propfindBody), "<D:href>/dav/docs/</D:href>") {
|
||||
t.Fatalf("propfind body = %q, want docs href", string(propfindBody))
|
||||
}
|
||||
|
||||
getResponse, err := doWebDAVRequest(baseURL, http.MethodGet, "/dav/notes.txt", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("get file: %v", err)
|
||||
}
|
||||
defer getResponse.Body.Close()
|
||||
|
||||
getBody, err := io.ReadAll(getResponse.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read get body: %v", err)
|
||||
}
|
||||
|
||||
if getResponse.StatusCode != http.StatusOK {
|
||||
t.Fatalf("get file status = %d, want 200", getResponse.StatusCode)
|
||||
}
|
||||
|
||||
if string(getBody) != "hello from webdav" {
|
||||
t.Fatalf("get file body = %q, want file content", string(getBody))
|
||||
}
|
||||
|
||||
deleteRequest := mustRequest(t, http.MethodDelete, baseURL+"/dav/notes.txt", nil)
|
||||
deleteResponse, err := http.DefaultClient.Do(deleteRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("delete file: %v", err)
|
||||
}
|
||||
defer deleteResponse.Body.Close()
|
||||
|
||||
if deleteResponse.StatusCode != http.StatusNoContent {
|
||||
t.Fatalf("delete status = %d, want 204", deleteResponse.StatusCode)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(exportPath, "notes.txt")); !os.IsNotExist(err) {
|
||||
t.Fatalf("deleted file still exists or stat failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppServesSymlinksThatStayWithinExportRoot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exportPath := filepath.Join(t.TempDir(), "export")
|
||||
if err := os.MkdirAll(exportPath, 0o755); err != nil {
|
||||
t.Fatalf("create export dir: %v", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filepath.Join(exportPath, "plain.txt"), []byte("inside export"), 0o644); err != nil {
|
||||
t.Fatalf("write export file: %v", err)
|
||||
}
|
||||
|
||||
if err := os.Symlink("plain.txt", filepath.Join(exportPath, "alias.txt")); err != nil {
|
||||
t.Skipf("symlink creation unavailable: %v", err)
|
||||
}
|
||||
|
||||
baseURL, stop := startTestApp(t, Config{
|
||||
Port: "0",
|
||||
ExportPath: exportPath,
|
||||
})
|
||||
defer stop()
|
||||
|
||||
response, err := doWebDAVRequest(baseURL, http.MethodGet, "/dav/alias.txt", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("get symlinked file: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read symlinked file body: %v", err)
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
t.Fatalf("symlink get status = %d, want 200", response.StatusCode)
|
||||
}
|
||||
|
||||
if string(body) != "inside export" {
|
||||
t.Fatalf("symlink get body = %q, want inside export", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppRejectsSymlinksThatEscapeExportRoot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
exportPath := filepath.Join(tempDir, "export")
|
||||
outsidePath := filepath.Join(tempDir, "outside.txt")
|
||||
|
||||
if err := os.MkdirAll(exportPath, 0o755); err != nil {
|
||||
t.Fatalf("create export dir: %v", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(outsidePath, []byte("outside"), 0o644); err != nil {
|
||||
t.Fatalf("write outside file: %v", err)
|
||||
}
|
||||
|
||||
if err := os.Symlink("../outside.txt", filepath.Join(exportPath, "escape.txt")); err != nil {
|
||||
t.Skipf("symlink creation unavailable: %v", err)
|
||||
}
|
||||
|
||||
baseURL, stop := startTestApp(t, Config{
|
||||
Port: "0",
|
||||
ExportPath: exportPath,
|
||||
})
|
||||
defer stop()
|
||||
|
||||
getResponse, err := doWebDAVRequest(baseURL, http.MethodGet, "/dav/escape.txt", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("get escaped symlink: %v", err)
|
||||
}
|
||||
defer getResponse.Body.Close()
|
||||
|
||||
if getResponse.StatusCode < http.StatusBadRequest {
|
||||
t.Fatalf("escaped symlink get status = %d, want 4xx or 5xx", getResponse.StatusCode)
|
||||
}
|
||||
|
||||
putRequest := mustRequest(t, http.MethodPut, baseURL+"/dav/escape.txt", strings.NewReader("should-not-write"))
|
||||
putResponse, err := http.DefaultClient.Do(putRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("put escaped symlink: %v", err)
|
||||
}
|
||||
defer putResponse.Body.Close()
|
||||
|
||||
if putResponse.StatusCode < http.StatusBadRequest {
|
||||
t.Fatalf("escaped symlink put status = %d, want 4xx or 5xx", putResponse.StatusCode)
|
||||
}
|
||||
|
||||
outsideBytes, err := os.ReadFile(outsidePath)
|
||||
if err != nil {
|
||||
t.Fatalf("read outside file: %v", err)
|
||||
}
|
||||
|
||||
if string(outsideBytes) != "outside" {
|
||||
t.Fatalf("outside file = %q, want unchanged content", string(outsideBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppRegistersAndHeartbeatsAgainstControlPlane(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerRequests := make(chan nodeRegistrationRequest, 1)
|
||||
heartbeatRequests := make(chan nodeHeartbeatRequest, 4)
|
||||
|
||||
controlPlane := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer "+testControlPlaneToken {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
t.Errorf("authorization header = %q, want Bearer token", got)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.URL.EscapedPath() {
|
||||
case registerNodeRoute:
|
||||
var request nodeRegistrationRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Errorf("decode register request: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
registerRequests <- request
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"id":"node/123"}`)
|
||||
case heartbeatRoute("node/123"):
|
||||
var request nodeHeartbeatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Errorf("decode heartbeat request: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
heartbeatRequests <- request
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer controlPlane.Close()
|
||||
|
||||
exportPath := filepath.Join(t.TempDir(), "export")
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
Port: "0",
|
||||
ListenAddress: listener.Addr().String(),
|
||||
ExportPath: exportPath,
|
||||
MachineID: "nas-42",
|
||||
DisplayName: "Garage NAS",
|
||||
AgentVersion: "test-version",
|
||||
DirectAddress: "http://" + listener.Addr().String(),
|
||||
ExportLabel: "archive",
|
||||
ExportTags: []string{"photos", "finder"},
|
||||
ControlPlaneURL: controlPlane.URL,
|
||||
ControlPlaneToken: testControlPlaneToken,
|
||||
RegisterEnabled: true,
|
||||
HeartbeatEnabled: true,
|
||||
HeartbeatInterval: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
stop := serveWithListener(t, listener, cfg)
|
||||
defer stop()
|
||||
|
||||
registerRequest := awaitValue(t, registerRequests, 2*time.Second, "register request")
|
||||
if registerRequest.MachineID != cfg.MachineID {
|
||||
t.Fatalf("machine id = %q, want %q", registerRequest.MachineID, cfg.MachineID)
|
||||
}
|
||||
|
||||
if registerRequest.DisplayName != cfg.DisplayName {
|
||||
t.Fatalf("display name = %q, want %q", registerRequest.DisplayName, cfg.DisplayName)
|
||||
}
|
||||
|
||||
if registerRequest.AgentVersion != cfg.AgentVersion {
|
||||
t.Fatalf("agent version = %q, want %q", registerRequest.AgentVersion, cfg.AgentVersion)
|
||||
}
|
||||
|
||||
if registerRequest.DirectAddress == nil || *registerRequest.DirectAddress != cfg.DirectAddress {
|
||||
t.Fatalf("direct address = %#v, want %q", registerRequest.DirectAddress, cfg.DirectAddress)
|
||||
}
|
||||
|
||||
if registerRequest.RelayAddress != nil {
|
||||
t.Fatalf("relay address = %#v, want nil", registerRequest.RelayAddress)
|
||||
}
|
||||
|
||||
if len(registerRequest.Exports) != 1 {
|
||||
t.Fatalf("exports length = %d, want 1", len(registerRequest.Exports))
|
||||
}
|
||||
|
||||
export := registerRequest.Exports[0]
|
||||
if export.Label != cfg.ExportLabel {
|
||||
t.Fatalf("export label = %q, want %q", export.Label, cfg.ExportLabel)
|
||||
}
|
||||
|
||||
if export.Path != cfg.ExportPath {
|
||||
t.Fatalf("export path = %q, want %q", export.Path, cfg.ExportPath)
|
||||
}
|
||||
|
||||
if len(export.Protocols) != 1 || export.Protocols[0] != "webdav" {
|
||||
t.Fatalf("export protocols = %#v, want [webdav]", export.Protocols)
|
||||
}
|
||||
|
||||
if len(export.Tags) != 2 || export.Tags[0] != "photos" || export.Tags[1] != "finder" {
|
||||
t.Fatalf("export tags = %#v, want [photos finder]", export.Tags)
|
||||
}
|
||||
|
||||
heartbeatRequest := awaitValue(t, heartbeatRequests, 2*time.Second, "heartbeat request")
|
||||
if heartbeatRequest.NodeID != "node/123" {
|
||||
t.Fatalf("heartbeat node id = %q, want node/123", heartbeatRequest.NodeID)
|
||||
}
|
||||
|
||||
if heartbeatRequest.Status != "online" {
|
||||
t.Fatalf("heartbeat status = %q, want online", heartbeatRequest.Status)
|
||||
}
|
||||
|
||||
if heartbeatRequest.LastSeenAt == "" {
|
||||
t.Fatal("heartbeat lastSeenAt is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppRegistersWithoutControlPlaneTokenWhenUnset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerRequests := make(chan nodeRegistrationRequest, 1)
|
||||
|
||||
controlPlane := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "" {
|
||||
http.Error(w, "unexpected authorization header", http.StatusBadRequest)
|
||||
t.Errorf("authorization header = %q, want empty", got)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.EscapedPath() != registerNodeRoute {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
var request nodeRegistrationRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Errorf("decode register request: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
registerRequests <- request
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"id":"node/no-token"}`)
|
||||
}))
|
||||
defer controlPlane.Close()
|
||||
|
||||
exportPath := filepath.Join(t.TempDir(), "export")
|
||||
|
||||
_, stop := startTestApp(t, Config{
|
||||
Port: "0",
|
||||
ExportPath: exportPath,
|
||||
MachineID: "nas-no-token",
|
||||
DisplayName: "No Token NAS",
|
||||
AgentVersion: "test-version",
|
||||
ExportLabel: "register-only",
|
||||
ControlPlaneURL: controlPlane.URL,
|
||||
RegisterEnabled: true,
|
||||
})
|
||||
defer stop()
|
||||
|
||||
registerRequest := awaitValue(t, registerRequests, 2*time.Second, "register request")
|
||||
if registerRequest.MachineID != "nas-no-token" {
|
||||
t.Fatalf("machine id = %q, want nas-no-token", registerRequest.MachineID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatRejectedNodeReregistersAndRecovers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerRequests := make(chan nodeRegistrationRequest, 4)
|
||||
heartbeatRequests := make(chan nodeHeartbeatRequest, 4)
|
||||
registerCount := 0
|
||||
|
||||
controlPlane := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer "+testControlPlaneToken {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
t.Errorf("authorization header = %q, want Bearer token", got)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.URL.EscapedPath() {
|
||||
case registerNodeRoute:
|
||||
registerCount++
|
||||
|
||||
var request nodeRegistrationRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Errorf("decode register request: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
registerRequests <- request
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if registerCount == 1 {
|
||||
_, _ = io.WriteString(w, `{"id":"node/stale"}`)
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = io.WriteString(w, `{"id":"node/fresh"}`)
|
||||
case heartbeatRoute("node/stale"), heartbeatRoute("node/fresh"):
|
||||
var request nodeHeartbeatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Errorf("decode heartbeat request: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
heartbeatRequests <- request
|
||||
if r.URL.EscapedPath() == heartbeatRoute("node/stale") {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer controlPlane.Close()
|
||||
|
||||
exportPath := filepath.Join(t.TempDir(), "export")
|
||||
|
||||
baseURL, stop := startTestApp(t, Config{
|
||||
Port: "0",
|
||||
ExportPath: exportPath,
|
||||
MachineID: "nas-stale",
|
||||
DisplayName: "NAS stale",
|
||||
AgentVersion: "test-version",
|
||||
ExportLabel: "resilient",
|
||||
ControlPlaneURL: controlPlane.URL,
|
||||
ControlPlaneToken: testControlPlaneToken,
|
||||
RegisterEnabled: true,
|
||||
HeartbeatEnabled: true,
|
||||
HeartbeatInterval: 50 * time.Millisecond,
|
||||
})
|
||||
defer stop()
|
||||
|
||||
firstRegister := awaitValue(t, registerRequests, 2*time.Second, "first register request")
|
||||
if firstRegister.MachineID != "nas-stale" {
|
||||
t.Fatalf("first register machine id = %q, want nas-stale", firstRegister.MachineID)
|
||||
}
|
||||
|
||||
secondRegister := awaitValue(t, registerRequests, 2*time.Second, "second register request")
|
||||
if secondRegister.MachineID != "nas-stale" {
|
||||
t.Fatalf("second register machine id = %q, want nas-stale", secondRegister.MachineID)
|
||||
}
|
||||
|
||||
firstHeartbeat := awaitValue(t, heartbeatRequests, 2*time.Second, "stale heartbeat request")
|
||||
if firstHeartbeat.NodeID != "node/stale" {
|
||||
t.Fatalf("stale heartbeat node id = %q, want node/stale", firstHeartbeat.NodeID)
|
||||
}
|
||||
|
||||
secondHeartbeat := awaitValue(t, heartbeatRequests, 2*time.Second, "fresh heartbeat request")
|
||||
if secondHeartbeat.NodeID != "node/fresh" {
|
||||
t.Fatalf("fresh heartbeat node id = %q, want node/fresh", secondHeartbeat.NodeID)
|
||||
}
|
||||
|
||||
propfindRequest, err := http.NewRequest("PROPFIND", baseURL+"/dav/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("build WebDAV root propfind after heartbeat recovery: %v", err)
|
||||
}
|
||||
propfindRequest.Header.Set("Depth", "0")
|
||||
|
||||
propfindResponse, err := http.DefaultClient.Do(propfindRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("propfind WebDAV root after heartbeat recovery: %v", err)
|
||||
}
|
||||
defer propfindResponse.Body.Close()
|
||||
|
||||
if propfindResponse.StatusCode != http.StatusMultiStatus {
|
||||
t.Fatalf("propfind WebDAV root status after heartbeat recovery = %d, want 207", propfindResponse.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatRouteUnavailableStopsAfterFreshReregistrationWithoutStoppingWebDAV(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerRequests := make(chan nodeRegistrationRequest, 4)
|
||||
heartbeatAttempts := make(chan nodeHeartbeatRequest, 4)
|
||||
|
||||
controlPlane := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer "+testControlPlaneToken {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
t.Errorf("authorization header = %q, want Bearer token", got)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.URL.EscapedPath() {
|
||||
case registerNodeRoute:
|
||||
var request nodeRegistrationRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Errorf("decode register request: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
registerRequests <- request
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"id":"node/404"}`)
|
||||
case heartbeatRoute("node/404"):
|
||||
var request nodeHeartbeatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Errorf("decode heartbeat request: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
heartbeatAttempts <- request
|
||||
http.NotFound(w, r)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer controlPlane.Close()
|
||||
|
||||
exportPath := filepath.Join(t.TempDir(), "export")
|
||||
|
||||
baseURL, stop := startTestApp(t, Config{
|
||||
Port: "0",
|
||||
ExportPath: exportPath,
|
||||
MachineID: "nas-404",
|
||||
DisplayName: "NAS 404",
|
||||
AgentVersion: "test-version",
|
||||
ExportLabel: "resilient",
|
||||
ControlPlaneURL: controlPlane.URL,
|
||||
ControlPlaneToken: testControlPlaneToken,
|
||||
RegisterEnabled: true,
|
||||
HeartbeatEnabled: true,
|
||||
HeartbeatInterval: 50 * time.Millisecond,
|
||||
})
|
||||
defer stop()
|
||||
|
||||
firstRegister := awaitValue(t, registerRequests, 2*time.Second, "first register request")
|
||||
if firstRegister.MachineID != "nas-404" {
|
||||
t.Fatalf("first register machine id = %q, want nas-404", firstRegister.MachineID)
|
||||
}
|
||||
|
||||
secondRegister := awaitValue(t, registerRequests, 2*time.Second, "second register request")
|
||||
if secondRegister.MachineID != "nas-404" {
|
||||
t.Fatalf("second register machine id = %q, want nas-404", secondRegister.MachineID)
|
||||
}
|
||||
|
||||
firstHeartbeat := awaitValue(t, heartbeatAttempts, 2*time.Second, "first heartbeat attempt")
|
||||
if firstHeartbeat.NodeID != "node/404" {
|
||||
t.Fatalf("first heartbeat node id = %q, want node/404", firstHeartbeat.NodeID)
|
||||
}
|
||||
|
||||
secondHeartbeat := awaitValue(t, heartbeatAttempts, 2*time.Second, "second heartbeat attempt")
|
||||
if secondHeartbeat.NodeID != "node/404" {
|
||||
t.Fatalf("second heartbeat node id = %q, want node/404", secondHeartbeat.NodeID)
|
||||
}
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
if extraAttempts := len(heartbeatAttempts); extraAttempts != 0 {
|
||||
t.Fatalf("heartbeat attempts after unsupported route = %d, want 0", extraAttempts)
|
||||
}
|
||||
if extraRegistrations := len(registerRequests); extraRegistrations != 0 {
|
||||
t.Fatalf("register attempts after unsupported route = %d, want 0", extraRegistrations)
|
||||
}
|
||||
|
||||
putRequest := mustRequest(t, http.MethodPut, baseURL+"/dav/after-heartbeat.txt", strings.NewReader("still-serving"))
|
||||
putResponse, err := http.DefaultClient.Do(putRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("put after heartbeat failure: %v", err)
|
||||
}
|
||||
defer putResponse.Body.Close()
|
||||
|
||||
if putResponse.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("put status after heartbeat failure = %d, want 201", putResponse.StatusCode)
|
||||
}
|
||||
|
||||
savedBytes, err := os.ReadFile(filepath.Join(exportPath, "after-heartbeat.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("read saved file after heartbeat failure: %v", err)
|
||||
}
|
||||
|
||||
if string(savedBytes) != "still-serving" {
|
||||
t.Fatalf("saved bytes after heartbeat failure = %q, want still-serving", string(savedBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func startTestApp(t *testing.T, cfg Config) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
|
||||
if cfg.ListenAddress == "" {
|
||||
cfg.ListenAddress = listener.Addr().String()
|
||||
}
|
||||
if cfg.DirectAddress == "" {
|
||||
cfg.DirectAddress = "http://" + listener.Addr().String()
|
||||
}
|
||||
|
||||
stop := serveWithListener(t, listener, cfg)
|
||||
return "http://" + listener.Addr().String(), stop
|
||||
}
|
||||
|
||||
func serveWithListener(t *testing.T, listener net.Listener, cfg Config) func() {
|
||||
t.Helper()
|
||||
|
||||
if cfg.ListenAddress == "" {
|
||||
cfg.ListenAddress = listener.Addr().String()
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(cfg.ExportPath, 0o755); err != nil {
|
||||
listener.Close()
|
||||
t.Fatalf("create export path: %v", err)
|
||||
}
|
||||
|
||||
app, err := New(cfg, log.New(io.Discard, "", 0))
|
||||
if err != nil {
|
||||
listener.Close()
|
||||
t.Fatalf("new app: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
serverErrors := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
serverErrors <- app.Serve(ctx, listener)
|
||||
}()
|
||||
|
||||
waitForCondition(t, 2*time.Second, "app health", func() bool {
|
||||
response, err := http.Get("http://" + listener.Addr().String() + "/health")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
return response.StatusCode == http.StatusOK
|
||||
})
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
if err := <-serverErrors; err != nil {
|
||||
t.Fatalf("serve app: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doWebDAVRequest(baseURL, method, requestPath string, body io.Reader) (*http.Response, error) {
|
||||
request, err := http.NewRequest(method, baseURL+requestPath, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return http.DefaultClient.Do(request)
|
||||
}
|
||||
|
||||
func mustRequest(t *testing.T, method, target string, body io.Reader) *http.Request {
|
||||
t.Helper()
|
||||
|
||||
request, err := http.NewRequest(method, target, body)
|
||||
if err != nil {
|
||||
t.Fatalf("build request %s %s: %v", method, target, err)
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func awaitValue[T any](t *testing.T, values <-chan T, timeout time.Duration, label string) T {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case value := <-values:
|
||||
return value
|
||||
case <-time.After(timeout):
|
||||
t.Fatalf("timed out waiting for %s", label)
|
||||
var zero T
|
||||
return zero
|
||||
}
|
||||
}
|
||||
|
||||
func waitForCondition(t *testing.T, timeout time.Duration, label string, check func() bool) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if check() {
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("timed out waiting for %s", label)
|
||||
}
|
||||
95
apps/node-agent/internal/nodeagent/app_test.go
Normal file
95
apps/node-agent/internal/nodeagent/app_test.go
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
package nodeagent
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewRejectsMissingExportDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exportPath := filepath.Join(t.TempDir(), "missing-export")
|
||||
|
||||
_, err := New(Config{
|
||||
ExportPath: exportPath,
|
||||
ListenAddress: defaultListenAddress(defaultPort),
|
||||
}, log.New(io.Discard, "", 0))
|
||||
if err == nil {
|
||||
t.Fatal("expected missing export directory to fail")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "does not exist") {
|
||||
t.Fatalf("error = %q, want missing-directory message", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRejectsFileExportPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exportPath := filepath.Join(t.TempDir(), "export.txt")
|
||||
if err := os.WriteFile(exportPath, []byte("not a directory"), 0o644); err != nil {
|
||||
t.Fatalf("write export file: %v", err)
|
||||
}
|
||||
|
||||
_, err := New(Config{
|
||||
ExportPath: exportPath,
|
||||
ListenAddress: defaultListenAddress(defaultPort),
|
||||
}, log.New(io.Discard, "", 0))
|
||||
if err == nil {
|
||||
t.Fatal("expected file export path to fail")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "is not a directory") {
|
||||
t.Fatalf("error = %q, want not-a-directory message", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRejectsInvalidListenAddress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := New(Config{
|
||||
ExportPath: t.TempDir(),
|
||||
ListenAddress: "localhost",
|
||||
}, log.New(io.Discard, "", 0))
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid listen address to fail")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), listenAddressEnvKey) {
|
||||
t.Fatalf("error = %q, want %q", err.Error(), listenAddressEnvKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAcceptsLoopbackListenAddressByDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := New(Config{
|
||||
ExportPath: t.TempDir(),
|
||||
ListenAddress: defaultListenAddress(defaultPort),
|
||||
}, log.New(io.Discard, "", 0))
|
||||
if err != nil {
|
||||
t.Fatalf("new app: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRejectsRegistrationWithoutMachineID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := New(Config{
|
||||
ExportPath: t.TempDir(),
|
||||
ListenAddress: defaultListenAddress(defaultPort),
|
||||
RegisterEnabled: true,
|
||||
ControlPlaneURL: "http://127.0.0.1:8081",
|
||||
}, log.New(io.Discard, "", 0))
|
||||
if err == nil {
|
||||
t.Fatal("expected missing machine id to fail")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "BETTERNAS_NODE_MACHINE_ID") {
|
||||
t.Fatalf("error = %q, want missing-machine-id message", err.Error())
|
||||
}
|
||||
}
|
||||
23
apps/node-agent/internal/nodeagent/capacity_supported.go
Normal file
23
apps/node-agent/internal/nodeagent/capacity_supported.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
//go:build android || darwin || dragonfly || freebsd || illumos || ios || linux || netbsd || openbsd || solaris
|
||||
|
||||
package nodeagent
|
||||
|
||||
import (
|
||||
"math"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func detectCapacityBytes(path string) *int64 {
|
||||
var stats syscall.Statfs_t
|
||||
if err := syscall.Statfs(path, &stats); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
capacity := uint64(stats.Blocks) * uint64(stats.Bsize)
|
||||
if capacity > math.MaxInt64 {
|
||||
return nil
|
||||
}
|
||||
|
||||
value := int64(capacity)
|
||||
return &value
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
//go:build !(android || darwin || dragonfly || freebsd || illumos || ios || linux || netbsd || openbsd || solaris)
|
||||
|
||||
package nodeagent
|
||||
|
||||
func detectCapacityBytes(string) *int64 {
|
||||
return nil
|
||||
}
|
||||
357
apps/node-agent/internal/nodeagent/config.go
Normal file
357
apps/node-agent/internal/nodeagent/config.go
Normal file
|
|
@ -0,0 +1,357 @@
|
|||
package nodeagent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = "8090"
|
||||
defaultAgentVersion = "0.1.0-dev"
|
||||
defaultHeartbeatInterval = 30 * time.Second
|
||||
defaultListenHost = "127.0.0.1"
|
||||
exportPathEnvKey = "BETTERNAS_EXPORT_PATH"
|
||||
listenAddressEnvKey = "BETTERNAS_NODE_LISTEN_ADDRESS"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Port string
|
||||
ListenAddress string
|
||||
ExportPath string
|
||||
MachineID string
|
||||
DisplayName string
|
||||
AgentVersion string
|
||||
DirectAddress string
|
||||
RelayAddress string
|
||||
ExportLabel string
|
||||
ExportTags []string
|
||||
ControlPlaneURL string
|
||||
ControlPlaneToken string
|
||||
RegisterEnabled bool
|
||||
HeartbeatEnabled bool
|
||||
HeartbeatInterval time.Duration
|
||||
}
|
||||
|
||||
type envLookup func(string) (string, bool)
|
||||
|
||||
func LoadConfigFromEnv() (Config, error) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("get working directory: %w", err)
|
||||
}
|
||||
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil || hostname == "" {
|
||||
hostname = "betternas-node"
|
||||
}
|
||||
|
||||
return loadConfig(os.LookupEnv, cwd, hostname)
|
||||
}
|
||||
|
||||
func loadConfig(lookup envLookup, cwd, hostname string) (Config, error) {
|
||||
port := envOrDefault(lookup, "PORT", defaultPort)
|
||||
|
||||
rawExportPath, err := envRequired(lookup, exportPathEnvKey)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
exportPath, err := resolveExportPath(rawExportPath, cwd)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
listenAddress := envOrDefault(lookup, listenAddressEnvKey, defaultListenAddress(port))
|
||||
|
||||
registerEnabled, err := envBool(lookup, "BETTERNAS_NODE_REGISTER_ENABLED", false)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
heartbeatEnabled, err := envBool(lookup, "BETTERNAS_NODE_HEARTBEAT_ENABLED", false)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
heartbeatInterval, err := envDuration(lookup, "BETTERNAS_NODE_HEARTBEAT_INTERVAL", defaultHeartbeatInterval)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
machineID, machineIDProvided := envOptional(lookup, "BETTERNAS_NODE_MACHINE_ID")
|
||||
if !machineIDProvided {
|
||||
machineID = hostname
|
||||
}
|
||||
|
||||
if registerEnabled && !machineIDProvided {
|
||||
return Config{}, fmt.Errorf("BETTERNAS_NODE_MACHINE_ID is required when BETTERNAS_NODE_REGISTER_ENABLED=true")
|
||||
}
|
||||
|
||||
displayName := envOrDefault(lookup, "BETTERNAS_NODE_DISPLAY_NAME", machineID)
|
||||
agentVersion := envOrDefault(lookup, "BETTERNAS_VERSION", defaultAgentVersion)
|
||||
directAddress := envOrDefault(lookup, "BETTERNAS_NODE_DIRECT_ADDRESS", defaultDirectAddress(listenAddress, port))
|
||||
relayAddress := envOrDefault(lookup, "BETTERNAS_NODE_RELAY_ADDRESS", "")
|
||||
exportLabel := envOrDefault(lookup, "BETTERNAS_EXPORT_LABEL", defaultExportLabel(exportPath))
|
||||
exportTags := parseCSVList(envOrDefault(lookup, "BETTERNAS_EXPORT_TAGS", ""))
|
||||
controlPlaneURL := strings.TrimRight(envOrDefault(lookup, "BETTERNAS_CONTROL_PLANE_URL", ""), "/")
|
||||
controlPlaneToken := envOrDefault(lookup, "BETTERNAS_CONTROL_PLANE_AUTH_TOKEN", "")
|
||||
|
||||
cfg := Config{
|
||||
Port: port,
|
||||
ListenAddress: listenAddress,
|
||||
ExportPath: exportPath,
|
||||
MachineID: machineID,
|
||||
DisplayName: displayName,
|
||||
AgentVersion: agentVersion,
|
||||
DirectAddress: directAddress,
|
||||
RelayAddress: relayAddress,
|
||||
ExportLabel: exportLabel,
|
||||
ExportTags: exportTags,
|
||||
ControlPlaneURL: controlPlaneURL,
|
||||
ControlPlaneToken: controlPlaneToken,
|
||||
RegisterEnabled: registerEnabled,
|
||||
HeartbeatEnabled: heartbeatEnabled,
|
||||
HeartbeatInterval: heartbeatInterval,
|
||||
}
|
||||
|
||||
if err := validateRuntimeConfig(cfg); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func resolveExportPath(rawPath, cwd string) (string, error) {
|
||||
exportPath := strings.TrimSpace(rawPath)
|
||||
if exportPath == "" {
|
||||
return "", fmt.Errorf("export path is required")
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(exportPath) {
|
||||
basePath := cwd
|
||||
if workspaceRoot, ok := findWorkspaceRoot(cwd); ok {
|
||||
basePath = workspaceRoot
|
||||
}
|
||||
|
||||
exportPath = filepath.Join(basePath, exportPath)
|
||||
}
|
||||
|
||||
absolutePath, err := filepath.Abs(exportPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolve export path %q: %w", rawPath, err)
|
||||
}
|
||||
|
||||
return filepath.Clean(absolutePath), nil
|
||||
}
|
||||
|
||||
func envRequired(lookup envLookup, key string) (string, error) {
|
||||
value, ok := lookup(key)
|
||||
if !ok || strings.TrimSpace(value) == "" {
|
||||
return "", fmt.Errorf("%s is required", key)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(value), nil
|
||||
}
|
||||
|
||||
func envOptional(lookup envLookup, key string) (string, bool) {
|
||||
value, ok := lookup(key)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return trimmed, true
|
||||
}
|
||||
|
||||
func defaultListenAddress(port string) string {
|
||||
return net.JoinHostPort(defaultListenHost, port)
|
||||
}
|
||||
|
||||
func defaultDirectAddress(listenAddress, fallbackPort string) string {
|
||||
if strings.TrimSpace(listenAddress) == defaultListenAddress(fallbackPort) {
|
||||
return httpURL("localhost", fallbackPort)
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(strings.TrimSpace(listenAddress))
|
||||
if err != nil || strings.TrimSpace(port) == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
host = strings.TrimSpace(host)
|
||||
if isWildcardListenHost(host) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return httpURL(host, port)
|
||||
}
|
||||
|
||||
func isWildcardListenHost(host string) bool {
|
||||
trimmed := strings.TrimSpace(host)
|
||||
if trimmed == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
ip := net.ParseIP(trimmed)
|
||||
return ip != nil && ip.IsUnspecified()
|
||||
}
|
||||
|
||||
func httpURL(host, port string) string {
|
||||
return (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort(host, port),
|
||||
}).String()
|
||||
}
|
||||
|
||||
func findWorkspaceRoot(start string) (string, bool) {
|
||||
current := filepath.Clean(start)
|
||||
|
||||
for {
|
||||
if hasPath(filepath.Join(current, "pnpm-workspace.yaml")) || hasPath(filepath.Join(current, "go.work")) || hasPath(filepath.Join(current, ".git")) {
|
||||
return current, true
|
||||
}
|
||||
|
||||
parent := filepath.Dir(current)
|
||||
if parent == current {
|
||||
return "", false
|
||||
}
|
||||
|
||||
current = parent
|
||||
}
|
||||
}
|
||||
|
||||
func hasPath(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func defaultExportLabel(exportPath string) string {
|
||||
label := filepath.Base(exportPath)
|
||||
if label == "." || label == string(filepath.Separator) || label == "" {
|
||||
return "export"
|
||||
}
|
||||
|
||||
return label
|
||||
}
|
||||
|
||||
func parseCSVList(raw string) []string {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
values := make([]string, 0)
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
for _, part := range strings.Split(raw, ",") {
|
||||
value := strings.TrimSpace(part)
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := seen[value]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
seen[value] = struct{}{}
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
func envOrDefault(lookup envLookup, key, fallback string) string {
|
||||
value, ok := lookup(key)
|
||||
if !ok {
|
||||
return fallback
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return fallback
|
||||
}
|
||||
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func validateRuntimeConfig(cfg Config) error {
|
||||
if err := validateListenAddress(cfg.ListenAddress); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.RegisterEnabled && strings.TrimSpace(cfg.ControlPlaneURL) == "" {
|
||||
return fmt.Errorf("BETTERNAS_CONTROL_PLANE_URL is required when BETTERNAS_NODE_REGISTER_ENABLED=true")
|
||||
}
|
||||
|
||||
if cfg.RegisterEnabled && strings.TrimSpace(cfg.MachineID) == "" {
|
||||
return fmt.Errorf("BETTERNAS_NODE_MACHINE_ID is required when BETTERNAS_NODE_REGISTER_ENABLED=true")
|
||||
}
|
||||
|
||||
if cfg.HeartbeatEnabled && !cfg.RegisterEnabled {
|
||||
return fmt.Errorf("BETTERNAS_NODE_HEARTBEAT_ENABLED requires BETTERNAS_NODE_REGISTER_ENABLED=true")
|
||||
}
|
||||
|
||||
if cfg.HeartbeatEnabled && cfg.HeartbeatInterval <= 0 {
|
||||
return fmt.Errorf("BETTERNAS_NODE_HEARTBEAT_INTERVAL must be greater than zero")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateListenAddress(address string) error {
|
||||
trimmed := strings.TrimSpace(address)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("%s is required", listenAddressEnvKey)
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(trimmed)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse %s: %w", listenAddressEnvKey, err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(port) == "" {
|
||||
return fmt.Errorf("%s must include a port", listenAddressEnvKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func envBool(lookup envLookup, key string, fallback bool) (bool, error) {
|
||||
value, ok := lookup(key)
|
||||
if !ok || strings.TrimSpace(value) == "" {
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
parsed, err := strconv.ParseBool(strings.TrimSpace(value))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parse %s: %w", key, err)
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func envDuration(lookup envLookup, key string, fallback time.Duration) (time.Duration, error) {
|
||||
value, ok := lookup(key)
|
||||
if !ok || strings.TrimSpace(value) == "" {
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
parsed, err := time.ParseDuration(strings.TrimSpace(value))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse %s: %w", key, err)
|
||||
}
|
||||
|
||||
if parsed <= 0 {
|
||||
return 0, fmt.Errorf("%s must be greater than zero", key)
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
404
apps/node-agent/internal/nodeagent/config_test.go
Normal file
404
apps/node-agent/internal/nodeagent/config_test.go
Normal file
|
|
@ -0,0 +1,404 @@
|
|||
package nodeagent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoadConfigResolvesRelativeExportPathFromWorkspaceRoot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repoRoot := t.TempDir()
|
||||
agentDir := filepath.Join(repoRoot, "apps", "node-agent")
|
||||
|
||||
if err := os.MkdirAll(agentDir, 0o755); err != nil {
|
||||
t.Fatalf("create agent dir: %v", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filepath.Join(repoRoot, "pnpm-workspace.yaml"), []byte("packages:\n - apps/*\n"), 0o644); err != nil {
|
||||
t.Fatalf("write workspace file: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := loadConfig(
|
||||
mapLookup(map[string]string{
|
||||
exportPathEnvKey: ".state/nas/export",
|
||||
"BETTERNAS_NODE_MACHINE_ID": "nas-machine-id",
|
||||
"BETTERNAS_EXPORT_TAGS": "finder, photos, finder",
|
||||
"BETTERNAS_NODE_REGISTER_ENABLED": "true",
|
||||
"BETTERNAS_NODE_HEARTBEAT_ENABLED": "true",
|
||||
"BETTERNAS_CONTROL_PLANE_URL": "http://127.0.0.1:8081/",
|
||||
"BETTERNAS_CONTROL_PLANE_AUTH_TOKEN": "node-auth-token",
|
||||
"BETTERNAS_NODE_HEARTBEAT_INTERVAL": "45s",
|
||||
}),
|
||||
agentDir,
|
||||
"nas-box",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
expectedExportPath := filepath.Join(repoRoot, ".state", "nas", "export")
|
||||
if cfg.ExportPath != expectedExportPath {
|
||||
t.Fatalf("export path = %q, want %q", cfg.ExportPath, expectedExportPath)
|
||||
}
|
||||
|
||||
if cfg.ListenAddress != defaultListenAddress(defaultPort) {
|
||||
t.Fatalf("listen address = %q, want %q", cfg.ListenAddress, defaultListenAddress(defaultPort))
|
||||
}
|
||||
|
||||
if cfg.MachineID != "nas-machine-id" {
|
||||
t.Fatalf("machine id = %q, want nas-machine-id", cfg.MachineID)
|
||||
}
|
||||
|
||||
if cfg.DisplayName != "nas-machine-id" {
|
||||
t.Fatalf("display name = %q, want nas-machine-id", cfg.DisplayName)
|
||||
}
|
||||
|
||||
if cfg.DirectAddress != "http://localhost:8090" {
|
||||
t.Fatalf("direct address = %q, want loopback default", cfg.DirectAddress)
|
||||
}
|
||||
|
||||
if cfg.ExportLabel != "export" {
|
||||
t.Fatalf("export label = %q, want export", cfg.ExportLabel)
|
||||
}
|
||||
|
||||
if len(cfg.ExportTags) != 2 || cfg.ExportTags[0] != "finder" || cfg.ExportTags[1] != "photos" {
|
||||
t.Fatalf("export tags = %#v, want [finder photos]", cfg.ExportTags)
|
||||
}
|
||||
|
||||
if !cfg.RegisterEnabled {
|
||||
t.Fatalf("register enabled = false, want true")
|
||||
}
|
||||
|
||||
if !cfg.HeartbeatEnabled {
|
||||
t.Fatalf("heartbeat enabled = false, want true")
|
||||
}
|
||||
|
||||
if cfg.HeartbeatInterval != 45*time.Second {
|
||||
t.Fatalf("heartbeat interval = %s, want 45s", cfg.HeartbeatInterval)
|
||||
}
|
||||
|
||||
if cfg.ControlPlaneURL != "http://127.0.0.1:8081" {
|
||||
t.Fatalf("control plane url = %q, want trimmed url", cfg.ControlPlaneURL)
|
||||
}
|
||||
|
||||
if cfg.ControlPlaneToken != "node-auth-token" {
|
||||
t.Fatalf("control plane token = %q, want node-auth-token", cfg.ControlPlaneToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigDefaultsRegistrationToDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg, err := loadConfig(
|
||||
mapLookup(map[string]string{
|
||||
exportPathEnvKey: ".state/nas/export",
|
||||
}),
|
||||
t.TempDir(),
|
||||
"nas-box",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.RegisterEnabled {
|
||||
t.Fatal("register enabled = true, want false")
|
||||
}
|
||||
|
||||
if cfg.HeartbeatEnabled {
|
||||
t.Fatal("heartbeat enabled = true, want false")
|
||||
}
|
||||
|
||||
if cfg.ControlPlaneURL != "" {
|
||||
t.Fatalf("control plane url = %q, want empty", cfg.ControlPlaneURL)
|
||||
}
|
||||
|
||||
if cfg.MachineID != "nas-box" {
|
||||
t.Fatalf("machine id = %q, want nas-box", cfg.MachineID)
|
||||
}
|
||||
|
||||
if cfg.ListenAddress != defaultListenAddress(defaultPort) {
|
||||
t.Fatalf("listen address = %q, want %q", cfg.ListenAddress, defaultListenAddress(defaultPort))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigDefaultsHeartbeatToDisabledEvenWhenRegistrationEnabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg, err := loadConfig(
|
||||
mapLookup(map[string]string{
|
||||
exportPathEnvKey: ".state/nas/export",
|
||||
"BETTERNAS_NODE_MACHINE_ID": "nas-machine-id",
|
||||
"BETTERNAS_NODE_REGISTER_ENABLED": "true",
|
||||
"BETTERNAS_CONTROL_PLANE_URL": "http://127.0.0.1:8081",
|
||||
}),
|
||||
t.TempDir(),
|
||||
"nas-box",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
if !cfg.RegisterEnabled {
|
||||
t.Fatal("register enabled = false, want true")
|
||||
}
|
||||
|
||||
if cfg.HeartbeatEnabled {
|
||||
t.Fatal("heartbeat enabled = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigRejectsHeartbeatWithoutRegistration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := loadConfig(
|
||||
mapLookup(map[string]string{
|
||||
exportPathEnvKey: ".state/nas/export",
|
||||
"BETTERNAS_NODE_HEARTBEAT_ENABLED": "true",
|
||||
}),
|
||||
t.TempDir(),
|
||||
"nas-box",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected heartbeat-only config to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigRequiresExportPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
values map[string]string
|
||||
}{
|
||||
{
|
||||
name: "missing",
|
||||
values: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "blank",
|
||||
values: map[string]string{
|
||||
exportPathEnvKey: " ",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := loadConfig(mapLookup(testCase.values), t.TempDir(), "nas-box")
|
||||
if err == nil {
|
||||
t.Fatal("expected missing export path to fail")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), exportPathEnvKey) {
|
||||
t.Fatalf("error = %q, want %q", err.Error(), exportPathEnvKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigDefaultsListenAddressToLoopback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg, err := loadConfig(
|
||||
mapLookup(map[string]string{
|
||||
exportPathEnvKey: ".state/nas/export",
|
||||
"PORT": "9100",
|
||||
}),
|
||||
t.TempDir(),
|
||||
"nas-box",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.ListenAddress != "127.0.0.1:9100" {
|
||||
t.Fatalf("listen address = %q, want 127.0.0.1:9100", cfg.ListenAddress)
|
||||
}
|
||||
|
||||
if cfg.DirectAddress != "http://localhost:9100" {
|
||||
t.Fatalf("direct address = %q, want http://localhost:9100", cfg.DirectAddress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigUsesExplicitWildcardListenAddress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg, err := loadConfig(
|
||||
mapLookup(map[string]string{
|
||||
exportPathEnvKey: ".state/nas/export",
|
||||
listenAddressEnvKey: ":9090",
|
||||
}),
|
||||
t.TempDir(),
|
||||
"nas-box",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.ListenAddress != ":9090" {
|
||||
t.Fatalf("listen address = %q, want :9090", cfg.ListenAddress)
|
||||
}
|
||||
|
||||
if cfg.DirectAddress != "" {
|
||||
t.Fatalf("direct address = %q, want empty for wildcard listener", cfg.DirectAddress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigDerivesDirectAddressFromExplicitHostListenAddress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg, err := loadConfig(
|
||||
mapLookup(map[string]string{
|
||||
exportPathEnvKey: ".state/nas/export",
|
||||
listenAddressEnvKey: "192.0.2.10:9443",
|
||||
}),
|
||||
t.TempDir(),
|
||||
"nas-box",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.DirectAddress != "http://192.0.2.10:9443" {
|
||||
t.Fatalf("direct address = %q, want http://192.0.2.10:9443", cfg.DirectAddress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigDoesNotDeriveDirectAddressFromWildcardHostListenAddress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
listenAddress string
|
||||
}{
|
||||
{
|
||||
name: "ipv4 wildcard",
|
||||
listenAddress: "0.0.0.0:9443",
|
||||
},
|
||||
{
|
||||
name: "ipv6 wildcard",
|
||||
listenAddress: "[::]:9443",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg, err := loadConfig(
|
||||
mapLookup(map[string]string{
|
||||
exportPathEnvKey: ".state/nas/export",
|
||||
listenAddressEnvKey: testCase.listenAddress,
|
||||
}),
|
||||
t.TempDir(),
|
||||
"nas-box",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.DirectAddress != "" {
|
||||
t.Fatalf("direct address = %q, want empty for %q", cfg.DirectAddress, testCase.listenAddress)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigRejectsInvalidListenAddress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := loadConfig(
|
||||
mapLookup(map[string]string{
|
||||
exportPathEnvKey: ".state/nas/export",
|
||||
listenAddressEnvKey: "localhost",
|
||||
}),
|
||||
t.TempDir(),
|
||||
"nas-box",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid listen address to fail")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), listenAddressEnvKey) {
|
||||
t.Fatalf("error = %q, want %q", err.Error(), listenAddressEnvKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigAllowsRegistrationWithoutControlPlaneToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg, err := loadConfig(
|
||||
mapLookup(map[string]string{
|
||||
exportPathEnvKey: ".state/nas/export",
|
||||
"BETTERNAS_NODE_MACHINE_ID": "nas-machine-id",
|
||||
"BETTERNAS_NODE_REGISTER_ENABLED": "true",
|
||||
"BETTERNAS_CONTROL_PLANE_URL": "http://127.0.0.1:8081",
|
||||
}),
|
||||
t.TempDir(),
|
||||
"nas-box",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.ControlPlaneToken != "" {
|
||||
t.Fatalf("control-plane token = %q, want empty", cfg.ControlPlaneToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigRejectsRegistrationWithoutMachineID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := loadConfig(
|
||||
mapLookup(map[string]string{
|
||||
exportPathEnvKey: ".state/nas/export",
|
||||
"BETTERNAS_NODE_REGISTER_ENABLED": "true",
|
||||
"BETTERNAS_CONTROL_PLANE_URL": "http://127.0.0.1:8081",
|
||||
}),
|
||||
t.TempDir(),
|
||||
"nas-box",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected missing machine id to fail")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "BETTERNAS_NODE_MACHINE_ID") {
|
||||
t.Fatalf("error = %q, want missing-machine-id message", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigRejectsRegistrationWithoutControlPlaneURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := loadConfig(
|
||||
mapLookup(map[string]string{
|
||||
exportPathEnvKey: ".state/nas/export",
|
||||
"BETTERNAS_NODE_MACHINE_ID": "nas-machine-id",
|
||||
"BETTERNAS_NODE_REGISTER_ENABLED": "true",
|
||||
}),
|
||||
t.TempDir(),
|
||||
"nas-box",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected missing control-plane url to fail")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "BETTERNAS_CONTROL_PLANE_URL") {
|
||||
t.Fatalf("error = %q, want missing-control-plane-url message", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func mapLookup(values map[string]string) envLookup {
|
||||
return func(key string) (string, bool) {
|
||||
value, ok := values[key]
|
||||
return value, ok
|
||||
}
|
||||
}
|
||||
132
apps/node-agent/internal/nodeagent/filesystem.go
Normal file
132
apps/node-agent/internal/nodeagent/filesystem.go
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
package nodeagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/webdav"
|
||||
)
|
||||
|
||||
type exportFileSystem struct {
|
||||
root *os.Root
|
||||
}
|
||||
|
||||
var _ webdav.FileSystem = (*exportFileSystem)(nil)
|
||||
|
||||
func newExportFileSystem(rootPath string) (*exportFileSystem, error) {
|
||||
root, err := os.OpenRoot(rootPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open export root %s: %w", rootPath, err)
|
||||
}
|
||||
|
||||
return &exportFileSystem{
|
||||
root: root,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *exportFileSystem) Close() error {
|
||||
if f.root == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := f.root.Close()
|
||||
f.root = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (f *exportFileSystem) Mkdir(_ context.Context, name string, perm os.FileMode) error {
|
||||
resolvedName, err := resolveExportName(name)
|
||||
if err != nil {
|
||||
return pathError("mkdir", name, err)
|
||||
}
|
||||
|
||||
if resolvedName == "." {
|
||||
return pathError("mkdir", name, os.ErrInvalid)
|
||||
}
|
||||
|
||||
return f.root.Mkdir(resolvedName, perm)
|
||||
}
|
||||
|
||||
func (f *exportFileSystem) OpenFile(_ context.Context, name string, flag int, perm os.FileMode) (webdav.File, error) {
|
||||
resolvedName, err := resolveExportName(name)
|
||||
if err != nil {
|
||||
return nil, pathError("open", name, err)
|
||||
}
|
||||
|
||||
file, err := f.root.OpenFile(resolvedName, flag, perm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (f *exportFileSystem) RemoveAll(_ context.Context, name string) error {
|
||||
resolvedName, err := resolveExportName(name)
|
||||
if err != nil {
|
||||
return pathError("removeall", name, err)
|
||||
}
|
||||
|
||||
if resolvedName == "." {
|
||||
return pathError("removeall", name, os.ErrInvalid)
|
||||
}
|
||||
|
||||
return f.root.RemoveAll(resolvedName)
|
||||
}
|
||||
|
||||
func (f *exportFileSystem) Rename(_ context.Context, oldName, newName string) error {
|
||||
resolvedOldName, err := resolveExportName(oldName)
|
||||
if err != nil {
|
||||
return pathError("rename", oldName, err)
|
||||
}
|
||||
|
||||
resolvedNewName, err := resolveExportName(newName)
|
||||
if err != nil {
|
||||
return pathError("rename", newName, err)
|
||||
}
|
||||
|
||||
if resolvedOldName == "." || resolvedNewName == "." {
|
||||
return pathError("rename", oldName, os.ErrInvalid)
|
||||
}
|
||||
|
||||
return f.root.Rename(resolvedOldName, resolvedNewName)
|
||||
}
|
||||
|
||||
func (f *exportFileSystem) Stat(_ context.Context, name string) (os.FileInfo, error) {
|
||||
resolvedName, err := resolveExportName(name)
|
||||
if err != nil {
|
||||
return nil, pathError("stat", name, err)
|
||||
}
|
||||
|
||||
return f.root.Stat(resolvedName)
|
||||
}
|
||||
|
||||
func resolveExportName(name string) (string, error) {
|
||||
if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
|
||||
if strings.Contains(name, "\x00") {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
|
||||
cleanedName := path.Clean("/" + name)
|
||||
cleanedName = strings.TrimPrefix(cleanedName, "/")
|
||||
if cleanedName == "" {
|
||||
return ".", nil
|
||||
}
|
||||
|
||||
return filepath.FromSlash(cleanedName), nil
|
||||
}
|
||||
|
||||
func pathError(op, path string, err error) error {
|
||||
return &os.PathError{
|
||||
Op: op,
|
||||
Path: path,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
309
apps/node-agent/internal/nodeagent/registration.go
Normal file
309
apps/node-agent/internal/nodeagent/registration.go
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
package nodeagent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
registerNodeRoute = "/api/v1/nodes/register"
|
||||
controlPlaneTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type registrationLoop struct {
|
||||
cfg Config
|
||||
logger *log.Logger
|
||||
client *http.Client
|
||||
nodeID string
|
||||
heartbeatUnsupported bool
|
||||
}
|
||||
|
||||
type nodeRegistrationRequest struct {
|
||||
MachineID string `json:"machineId"`
|
||||
DisplayName string `json:"displayName"`
|
||||
AgentVersion string `json:"agentVersion"`
|
||||
DirectAddress *string `json:"directAddress"`
|
||||
RelayAddress *string `json:"relayAddress"`
|
||||
Exports []storageExportInput `json:"exports"`
|
||||
}
|
||||
|
||||
type storageExportInput struct {
|
||||
Label string `json:"label"`
|
||||
Path string `json:"path"`
|
||||
Protocols []string `json:"protocols"`
|
||||
CapacityBytes *int64 `json:"capacityBytes"`
|
||||
Tags []string `json:"tags"`
|
||||
}
|
||||
|
||||
type nodeRegistrationResponse struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
type nodeHeartbeatRequest struct {
|
||||
NodeID string `json:"nodeId"`
|
||||
Status string `json:"status"`
|
||||
LastSeenAt string `json:"lastSeenAt"`
|
||||
}
|
||||
|
||||
type responseStatusError struct {
|
||||
route string
|
||||
statusCode int
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *responseStatusError) Error() string {
|
||||
return fmt.Sprintf("%s returned %d: %s", e.route, e.statusCode, e.message)
|
||||
}
|
||||
|
||||
func newRegistrationLoop(cfg Config, logger *log.Logger) *registrationLoop {
|
||||
return ®istrationLoop{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
client: &http.Client{Timeout: controlPlaneTimeout},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *registrationLoop) Run(ctx context.Context) {
|
||||
timer := time.NewTimer(0)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
r.syncOnce(ctx)
|
||||
if r.nodeID != "" && (!r.cfg.HeartbeatEnabled || r.heartbeatUnsupported) {
|
||||
return
|
||||
}
|
||||
|
||||
timer.Reset(r.cfg.HeartbeatInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *registrationLoop) syncOnce(ctx context.Context) {
|
||||
if r.nodeID == "" {
|
||||
if err := r.registerAndStore(ctx, "betterNAS node agent registered as %s"); err != nil {
|
||||
r.logger.Printf("betterNAS node agent registration failed: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !r.cfg.HeartbeatEnabled {
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.sendHeartbeat(ctx); err != nil {
|
||||
if heartbeatRouteUnsupported(err) {
|
||||
r.heartbeatUnsupported = true
|
||||
r.logger.Printf("betterNAS node agent heartbeat route is unavailable; stopping heartbeats: %v", err)
|
||||
return
|
||||
}
|
||||
if heartbeatRequiresRegistrationRefresh(err) {
|
||||
if err := r.recoverFromRejectedHeartbeat(ctx, err); err != nil {
|
||||
r.logger.Printf("betterNAS node agent %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
r.logger.Printf("betterNAS node agent heartbeat failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *registrationLoop) registerAndStore(ctx context.Context, message string) error {
|
||||
nodeID, err := r.register(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.nodeID = nodeID
|
||||
if strings.TrimSpace(message) != "" {
|
||||
r.logger.Printf(message, r.nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registrationLoop) recoverFromRejectedHeartbeat(ctx context.Context, heartbeatErr error) error {
|
||||
rejectedNodeID := r.nodeID
|
||||
r.logger.Printf("betterNAS node agent heartbeat was rejected for %s; re-registering: %v", rejectedNodeID, heartbeatErr)
|
||||
r.nodeID = ""
|
||||
|
||||
if err := r.registerAndStore(ctx, "betterNAS node agent re-registered as %s after heartbeat rejection"); err != nil {
|
||||
return fmt.Errorf("failed to re-register after heartbeat rejection: %w", err)
|
||||
}
|
||||
|
||||
if err := r.sendHeartbeat(ctx); err != nil {
|
||||
if heartbeatRouteUnsupported(err) || heartbeatRequiresRegistrationRefresh(err) {
|
||||
r.heartbeatUnsupported = true
|
||||
return fmt.Errorf("heartbeat route did not accept the freshly registered node; stopping heartbeats: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("heartbeat failed after re-registration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registrationLoop) register(ctx context.Context) (string, error) {
|
||||
request := r.registrationRequest()
|
||||
|
||||
var response nodeRegistrationResponse
|
||||
if err := r.postJSON(ctx, registerNodeRoute, request, http.StatusOK, &response); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if strings.TrimSpace(response.ID) == "" {
|
||||
return "", fmt.Errorf("register response did not include a node id")
|
||||
}
|
||||
|
||||
return response.ID, nil
|
||||
}
|
||||
|
||||
func (r *registrationLoop) registrationRequest() nodeRegistrationRequest {
|
||||
machineID := strings.TrimSpace(r.cfg.MachineID)
|
||||
displayName := strings.TrimSpace(r.cfg.DisplayName)
|
||||
if displayName == "" {
|
||||
displayName = machineID
|
||||
}
|
||||
|
||||
agentVersion := strings.TrimSpace(r.cfg.AgentVersion)
|
||||
if agentVersion == "" {
|
||||
agentVersion = defaultAgentVersion
|
||||
}
|
||||
|
||||
exportLabel := strings.TrimSpace(r.cfg.ExportLabel)
|
||||
if exportLabel == "" {
|
||||
exportLabel = defaultExportLabel(r.cfg.ExportPath)
|
||||
}
|
||||
|
||||
return nodeRegistrationRequest{
|
||||
MachineID: machineID,
|
||||
DisplayName: displayName,
|
||||
AgentVersion: agentVersion,
|
||||
DirectAddress: optionalString(r.cfg.DirectAddress),
|
||||
RelayAddress: optionalString(r.cfg.RelayAddress),
|
||||
Exports: []storageExportInput{
|
||||
{
|
||||
Label: exportLabel,
|
||||
Path: r.cfg.ExportPath,
|
||||
Protocols: []string{"webdav"},
|
||||
CapacityBytes: detectCapacityBytes(r.cfg.ExportPath),
|
||||
Tags: cloneStringSlice(r.cfg.ExportTags),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *registrationLoop) sendHeartbeat(ctx context.Context) error {
|
||||
request := nodeHeartbeatRequest{
|
||||
NodeID: r.nodeID,
|
||||
Status: "online",
|
||||
LastSeenAt: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
return r.postJSON(ctx, heartbeatRoute(r.nodeID), request, http.StatusNoContent, nil)
|
||||
}
|
||||
|
||||
func heartbeatRoute(nodeID string) string {
|
||||
return "/api/v1/nodes/" + url.PathEscape(nodeID) + "/heartbeat"
|
||||
}
|
||||
|
||||
func (r *registrationLoop) postJSON(ctx context.Context, route string, payload any, wantStatus int, out any) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal %s payload: %w", route, err)
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, r.cfg.ControlPlaneURL+route, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create %s request: %w", route, err)
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if token := strings.TrimSpace(r.cfg.ControlPlaneToken); token != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
response, err := r.client.Do(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("post %s: %w", route, err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != wantStatus {
|
||||
message, readErr := io.ReadAll(io.LimitReader(response.Body, 4*1024))
|
||||
if readErr != nil {
|
||||
return fmt.Errorf("%s returned %d and body read failed: %w", route, response.StatusCode, readErr)
|
||||
}
|
||||
|
||||
return &responseStatusError{
|
||||
route: route,
|
||||
statusCode: response.StatusCode,
|
||||
message: strings.TrimSpace(string(message)),
|
||||
}
|
||||
}
|
||||
|
||||
if out == nil {
|
||||
_, _ = io.Copy(io.Discard, response.Body)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(response.Body).Decode(out); err != nil {
|
||||
return fmt.Errorf("decode %s response: %w", route, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func optionalString(value string) *string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &trimmed
|
||||
}
|
||||
|
||||
func cloneStringSlice(values []string) []string {
|
||||
return append([]string{}, values...)
|
||||
}
|
||||
|
||||
func heartbeatRouteUnsupported(err error) bool {
|
||||
var statusErr *responseStatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch statusErr.statusCode {
|
||||
case http.StatusMethodNotAllowed, http.StatusNotImplemented:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func heartbeatRequiresRegistrationRefresh(err error) bool {
|
||||
var statusErr *responseStatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch statusErr.statusCode {
|
||||
case http.StatusNotFound, http.StatusGone:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
210
apps/node-agent/internal/nodeagent/registration_test.go
Normal file
210
apps/node-agent/internal/nodeagent/registration_test.go
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
package nodeagent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRegistrationRequestUsesEmptyTagsArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
loop := newRegistrationLoop(Config{
|
||||
MachineID: "nas-1",
|
||||
DisplayName: "NAS 1",
|
||||
AgentVersion: "test-version",
|
||||
ExportPath: t.TempDir(),
|
||||
ExportLabel: "archive",
|
||||
}, log.New(io.Discard, "", 0))
|
||||
|
||||
request := loop.registrationRequest()
|
||||
if request.Exports[0].Tags == nil {
|
||||
t.Fatal("tags slice = nil, want empty slice")
|
||||
}
|
||||
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal registration request: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Contains(body, []byte(`"tags":[]`)) {
|
||||
t.Fatalf("registration json = %s, want empty tags array", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatRouteEscapesOpaqueNodeID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := heartbeatRoute("node/123")
|
||||
want := "/api/v1/nodes/node%2F123/heartbeat"
|
||||
if got != want {
|
||||
t.Fatalf("heartbeatRoute returned %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatRouteUnsupportedDetectsDefinitiveUnsupportedRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
err error
|
||||
wantUnsupported bool
|
||||
}{
|
||||
{
|
||||
name: "not found",
|
||||
err: &responseStatusError{
|
||||
route: heartbeatRoute("node/123"),
|
||||
statusCode: http.StatusNotFound,
|
||||
message: "missing",
|
||||
},
|
||||
wantUnsupported: false,
|
||||
},
|
||||
{
|
||||
name: "method not allowed",
|
||||
err: &responseStatusError{
|
||||
route: heartbeatRoute("node/123"),
|
||||
statusCode: http.StatusMethodNotAllowed,
|
||||
message: "method not allowed",
|
||||
},
|
||||
wantUnsupported: true,
|
||||
},
|
||||
{
|
||||
name: "not implemented",
|
||||
err: &responseStatusError{
|
||||
route: heartbeatRoute("node/123"),
|
||||
statusCode: http.StatusNotImplemented,
|
||||
message: "not implemented",
|
||||
},
|
||||
wantUnsupported: true,
|
||||
},
|
||||
{
|
||||
name: "temporary failure",
|
||||
err: &responseStatusError{
|
||||
route: heartbeatRoute("node/123"),
|
||||
statusCode: http.StatusBadGateway,
|
||||
message: "bad gateway",
|
||||
},
|
||||
wantUnsupported: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := heartbeatRouteUnsupported(testCase.err)
|
||||
if got != testCase.wantUnsupported {
|
||||
t.Fatalf("heartbeatRouteUnsupported(%v) = %t, want %t", testCase.err, got, testCase.wantUnsupported)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatRequiresRegistrationRefreshDetectsRejectedNode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
err error
|
||||
wantRefresh bool
|
||||
}{
|
||||
{
|
||||
name: "not found",
|
||||
err: &responseStatusError{
|
||||
route: heartbeatRoute("node/123"),
|
||||
statusCode: http.StatusNotFound,
|
||||
message: "missing",
|
||||
},
|
||||
wantRefresh: true,
|
||||
},
|
||||
{
|
||||
name: "gone",
|
||||
err: &responseStatusError{
|
||||
route: heartbeatRoute("node/123"),
|
||||
statusCode: http.StatusGone,
|
||||
message: "gone",
|
||||
},
|
||||
wantRefresh: true,
|
||||
},
|
||||
{
|
||||
name: "temporary failure",
|
||||
err: &responseStatusError{
|
||||
route: heartbeatRoute("node/123"),
|
||||
statusCode: http.StatusBadGateway,
|
||||
message: "bad gateway",
|
||||
},
|
||||
wantRefresh: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := heartbeatRequiresRegistrationRefresh(testCase.err)
|
||||
if got != testCase.wantRefresh {
|
||||
t.Fatalf("heartbeatRequiresRegistrationRefresh(%v) = %t, want %t", testCase.err, got, testCase.wantRefresh)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostJSONAddsBearerAuthorization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
requestHeaders := make(chan http.Header, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestHeaders <- r.Header.Clone()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"id":"node-1"}`)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
loop := newRegistrationLoop(Config{
|
||||
ControlPlaneURL: server.URL,
|
||||
ControlPlaneToken: "node-auth-token",
|
||||
}, log.New(io.Discard, "", 0))
|
||||
|
||||
var response nodeRegistrationResponse
|
||||
if err := loop.postJSON(context.Background(), registerNodeRoute, nodeRegistrationRequest{}, http.StatusOK, &response); err != nil {
|
||||
t.Fatalf("post json: %v", err)
|
||||
}
|
||||
|
||||
headers := <-requestHeaders
|
||||
if got := headers.Get("Authorization"); got != "Bearer node-auth-token" {
|
||||
t.Fatalf("authorization header = %q, want Bearer token", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostJSONOmitsBearerAuthorizationWhenTokenUnset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
requestHeaders := make(chan http.Header, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestHeaders <- r.Header.Clone()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"id":"node-1"}`)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
loop := newRegistrationLoop(Config{
|
||||
ControlPlaneURL: server.URL,
|
||||
}, log.New(io.Discard, "", 0))
|
||||
|
||||
var response nodeRegistrationResponse
|
||||
if err := loop.postJSON(context.Background(), registerNodeRoute, nodeRegistrationRequest{}, http.StatusOK, &response); err != nil {
|
||||
t.Fatalf("post json: %v", err)
|
||||
}
|
||||
|
||||
headers := <-requestHeaders
|
||||
if got := headers.Get("Authorization"); got != "" {
|
||||
t.Fatalf("authorization header = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
714
apps/node-agent/internal/nodeagent/runtime_integration_test.go
Normal file
714
apps/node-agent/internal/nodeagent/runtime_integration_test.go
Normal file
|
|
@ -0,0 +1,714 @@
|
|||
package nodeagent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type startedProcess struct {
|
||||
cmd *exec.Cmd
|
||||
output *lockedBuffer
|
||||
}
|
||||
|
||||
type lockedBuffer struct {
|
||||
mu sync.Mutex
|
||||
b bytes.Buffer
|
||||
}
|
||||
|
||||
func (b *lockedBuffer) Write(p []byte) (int, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
return b.b.Write(p)
|
||||
}
|
||||
|
||||
func (b *lockedBuffer) String() string {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
return b.b.String()
|
||||
}
|
||||
|
||||
func TestRuntimeBinaryBindsToLoopbackByDefault(t *testing.T) {
|
||||
repoRoot := testRepoRoot(t)
|
||||
nodeAgentBin := buildNodeAgentBinary(t, repoRoot)
|
||||
nodeAgentPort := freePort(t)
|
||||
exportPath := filepath.Join(t.TempDir(), "export")
|
||||
if err := os.MkdirAll(exportPath, 0o755); err != nil {
|
||||
t.Fatalf("create export path: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(exportPath, "seed.txt"), []byte("seed"), 0o644); err != nil {
|
||||
t.Fatalf("write seed file: %v", err)
|
||||
}
|
||||
|
||||
nodeAgentURL := "http://127.0.0.1:" + strconv.Itoa(nodeAgentPort)
|
||||
nodeAgent := startBinaryProcess(t, repoRoot, nodeAgentBin, []string{
|
||||
"PORT=" + strconv.Itoa(nodeAgentPort),
|
||||
"BETTERNAS_EXPORT_PATH=" + exportPath,
|
||||
})
|
||||
defer nodeAgent.stop(t)
|
||||
|
||||
waitForHTTPStatus(t, nodeAgentURL+"/health", http.StatusOK)
|
||||
|
||||
propfindRequest, err := http.NewRequest("PROPFIND", nodeAgentURL+"/dav/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("build propfind request: %v", err)
|
||||
}
|
||||
propfindRequest.Header.Set("Depth", "0")
|
||||
|
||||
propfindResponse, err := http.DefaultClient.Do(propfindRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("propfind WebDAV root: %v", err)
|
||||
}
|
||||
defer propfindResponse.Body.Close()
|
||||
|
||||
if propfindResponse.StatusCode != http.StatusMultiStatus {
|
||||
t.Fatalf("propfind status = %d, want 207", propfindResponse.StatusCode)
|
||||
}
|
||||
|
||||
getResponse, err := http.Get(nodeAgentURL + "/dav/seed.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("get WebDAV file: %v", err)
|
||||
}
|
||||
defer getResponse.Body.Close()
|
||||
|
||||
getBody, err := io.ReadAll(getResponse.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read WebDAV body: %v", err)
|
||||
}
|
||||
|
||||
if getResponse.StatusCode != http.StatusOK {
|
||||
t.Fatalf("get status = %d, want 200", getResponse.StatusCode)
|
||||
}
|
||||
if string(getBody) != "seed" {
|
||||
t.Fatalf("get body = %q, want seed", string(getBody))
|
||||
}
|
||||
|
||||
host, ok := firstNonLoopbackIPv4()
|
||||
if !ok {
|
||||
t.Skip("no non-loopback IPv4 address available to verify loopback-only binding")
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 500 * time.Millisecond}
|
||||
_, err = client.Get("http://" + host + ":" + strconv.Itoa(nodeAgentPort) + "/health")
|
||||
if err == nil {
|
||||
t.Fatalf("expected loopback-only listener to reject non-loopback host %s", host)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeBinaryServesWebDAVWithExplicitListenAddress(t *testing.T) {
|
||||
repoRoot := testRepoRoot(t)
|
||||
nodeAgentBin := buildNodeAgentBinary(t, repoRoot)
|
||||
nodeAgentPort := freePort(t)
|
||||
exportPath := filepath.Join(t.TempDir(), "export")
|
||||
if err := os.MkdirAll(exportPath, 0o755); err != nil {
|
||||
t.Fatalf("create export path: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(exportPath, "seed.txt"), []byte("seed"), 0o644); err != nil {
|
||||
t.Fatalf("write seed file: %v", err)
|
||||
}
|
||||
|
||||
nodeAgentURL := "http://127.0.0.1:" + strconv.Itoa(nodeAgentPort)
|
||||
nodeAgent := startBinaryProcess(t, repoRoot, nodeAgentBin, []string{
|
||||
"PORT=" + strconv.Itoa(nodeAgentPort),
|
||||
"BETTERNAS_EXPORT_PATH=" + exportPath,
|
||||
listenAddressEnvKey + "=:" + strconv.Itoa(nodeAgentPort),
|
||||
"BETTERNAS_NODE_DIRECT_ADDRESS=" + nodeAgentURL,
|
||||
})
|
||||
defer nodeAgent.stop(t)
|
||||
|
||||
waitForHTTPStatus(t, nodeAgentURL+"/health", http.StatusOK)
|
||||
|
||||
propfindRequest, err := http.NewRequest("PROPFIND", nodeAgentURL+"/dav/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("build propfind request: %v", err)
|
||||
}
|
||||
propfindRequest.Header.Set("Depth", "0")
|
||||
|
||||
propfindResponse, err := http.DefaultClient.Do(propfindRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("propfind WebDAV root: %v", err)
|
||||
}
|
||||
defer propfindResponse.Body.Close()
|
||||
|
||||
if propfindResponse.StatusCode != http.StatusMultiStatus {
|
||||
t.Fatalf("propfind status = %d, want 207", propfindResponse.StatusCode)
|
||||
}
|
||||
|
||||
getResponse, err := doRuntimeWebDAVRequest(nodeAgentURL, http.MethodGet, "/dav/seed.txt", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("get WebDAV file: %v", err)
|
||||
}
|
||||
defer getResponse.Body.Close()
|
||||
|
||||
getBody, err := io.ReadAll(getResponse.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read WebDAV body: %v", err)
|
||||
}
|
||||
|
||||
if getResponse.StatusCode != http.StatusOK {
|
||||
t.Fatalf("get status = %d, want 200", getResponse.StatusCode)
|
||||
}
|
||||
if string(getBody) != "seed" {
|
||||
t.Fatalf("get body = %q, want seed", string(getBody))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeBinaryOmitsDirectAddressForWildcardListenAddress(t *testing.T) {
|
||||
repoRoot := testRepoRoot(t)
|
||||
nodeAgentBin := buildNodeAgentBinary(t, repoRoot)
|
||||
nodeAgentPort := freePort(t)
|
||||
exportPath := filepath.Join(t.TempDir(), "export")
|
||||
if err := os.MkdirAll(exportPath, 0o755); err != nil {
|
||||
t.Fatalf("create export path: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(exportPath, "seed.txt"), []byte("seed"), 0o644); err != nil {
|
||||
t.Fatalf("write seed file: %v", err)
|
||||
}
|
||||
|
||||
registerRequests := make(chan nodeRegistrationRequest, 1)
|
||||
controlPlane := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.EscapedPath() != registerNodeRoute {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
var request nodeRegistrationRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Errorf("decode register request: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
registerRequests <- request
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"id":"runtime-node"}`)
|
||||
}))
|
||||
defer controlPlane.Close()
|
||||
|
||||
nodeAgentURL := "http://127.0.0.1:" + strconv.Itoa(nodeAgentPort)
|
||||
nodeAgent := startBinaryProcess(t, repoRoot, nodeAgentBin, []string{
|
||||
"PORT=" + strconv.Itoa(nodeAgentPort),
|
||||
"BETTERNAS_EXPORT_PATH=" + exportPath,
|
||||
"BETTERNAS_NODE_MACHINE_ID=runtime-machine",
|
||||
"BETTERNAS_CONTROL_PLANE_URL=" + controlPlane.URL,
|
||||
"BETTERNAS_NODE_REGISTER_ENABLED=true",
|
||||
listenAddressEnvKey + "=:" + strconv.Itoa(nodeAgentPort),
|
||||
})
|
||||
defer nodeAgent.stop(t)
|
||||
|
||||
waitForHTTPStatus(t, nodeAgentURL+"/health", http.StatusOK)
|
||||
|
||||
registerRequest := awaitValue(t, registerRequests, 5*time.Second, "register request")
|
||||
if registerRequest.DirectAddress != nil {
|
||||
t.Fatalf("direct address = %#v, want nil for wildcard listener", registerRequest.DirectAddress)
|
||||
}
|
||||
|
||||
getResponse, err := doRuntimeWebDAVRequest(nodeAgentURL, http.MethodGet, "/dav/seed.txt", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("get WebDAV file: %v", err)
|
||||
}
|
||||
defer getResponse.Body.Close()
|
||||
|
||||
getBody, err := io.ReadAll(getResponse.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read WebDAV body: %v", err)
|
||||
}
|
||||
|
||||
if getResponse.StatusCode != http.StatusOK {
|
||||
t.Fatalf("get status = %d, want 200", getResponse.StatusCode)
|
||||
}
|
||||
if string(getBody) != "seed" {
|
||||
t.Fatalf("get body = %q, want seed", string(getBody))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeBinaryRejectsInvalidListenAddress(t *testing.T) {
|
||||
repoRoot := testRepoRoot(t)
|
||||
nodeAgentBin := buildNodeAgentBinary(t, repoRoot)
|
||||
nodeAgentPort := freePort(t)
|
||||
exportPath := filepath.Join(t.TempDir(), "export")
|
||||
if err := os.MkdirAll(exportPath, 0o755); err != nil {
|
||||
t.Fatalf("create export path: %v", err)
|
||||
}
|
||||
|
||||
command := exec.Command(nodeAgentBin)
|
||||
command.Dir = repoRoot
|
||||
command.Env = mergedEnv([]string{
|
||||
"PORT=" + strconv.Itoa(nodeAgentPort),
|
||||
"BETTERNAS_EXPORT_PATH=" + exportPath,
|
||||
listenAddressEnvKey + "=localhost",
|
||||
})
|
||||
output, err := command.CombinedOutput()
|
||||
if err == nil {
|
||||
t.Fatal("expected node-agent to reject invalid listen address")
|
||||
}
|
||||
|
||||
if !strings.Contains(string(output), listenAddressEnvKey) {
|
||||
t.Fatalf("output = %q, want %q guidance", string(output), listenAddressEnvKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeBinaryUsesOptionalControlPlaneSync(t *testing.T) {
|
||||
repoRoot := testRepoRoot(t)
|
||||
nodeAgentBin := buildNodeAgentBinary(t, repoRoot)
|
||||
nodeAgentPort := freePort(t)
|
||||
exportPath := filepath.Join(t.TempDir(), "export")
|
||||
if err := os.MkdirAll(exportPath, 0o755); err != nil {
|
||||
t.Fatalf("create export path: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(exportPath, "seed.txt"), []byte("seed"), 0o644); err != nil {
|
||||
t.Fatalf("write seed file: %v", err)
|
||||
}
|
||||
|
||||
const (
|
||||
machineID = "runtime-machine"
|
||||
controlPlaneToken = "runtime-control-plane-token"
|
||||
)
|
||||
|
||||
registerRequests := make(chan nodeRegistrationRequest, 1)
|
||||
heartbeatRequests := make(chan nodeHeartbeatRequest, 4)
|
||||
controlPlane := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer "+controlPlaneToken {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
t.Errorf("authorization header = %q, want Bearer token", got)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.URL.EscapedPath() {
|
||||
case registerNodeRoute:
|
||||
var request nodeRegistrationRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Errorf("decode register request: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
registerRequests <- request
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"id":"runtime-node"}`)
|
||||
case heartbeatRoute("runtime-node"):
|
||||
var request nodeHeartbeatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Errorf("decode heartbeat request: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
heartbeatRequests <- request
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer controlPlane.Close()
|
||||
|
||||
nodeAgentURL := "http://127.0.0.1:" + strconv.Itoa(nodeAgentPort)
|
||||
nodeAgent := startBinaryProcess(t, repoRoot, nodeAgentBin, []string{
|
||||
"PORT=" + strconv.Itoa(nodeAgentPort),
|
||||
"BETTERNAS_EXPORT_PATH=" + exportPath,
|
||||
"BETTERNAS_VERSION=test-version",
|
||||
"BETTERNAS_NODE_MACHINE_ID=" + machineID,
|
||||
"BETTERNAS_NODE_DISPLAY_NAME=Runtime NAS",
|
||||
"BETTERNAS_EXPORT_LABEL=runtime-export",
|
||||
"BETTERNAS_EXPORT_TAGS=runtime,finder",
|
||||
"BETTERNAS_NODE_DIRECT_ADDRESS=" + nodeAgentURL,
|
||||
"BETTERNAS_CONTROL_PLANE_URL=" + controlPlane.URL,
|
||||
"BETTERNAS_CONTROL_PLANE_AUTH_TOKEN=" + controlPlaneToken,
|
||||
"BETTERNAS_NODE_REGISTER_ENABLED=true",
|
||||
"BETTERNAS_NODE_HEARTBEAT_ENABLED=true",
|
||||
"BETTERNAS_NODE_HEARTBEAT_INTERVAL=100ms",
|
||||
})
|
||||
defer nodeAgent.stop(t)
|
||||
|
||||
waitForHTTPStatus(t, nodeAgentURL+"/health", http.StatusOK)
|
||||
|
||||
getResponse, err := doRuntimeWebDAVRequest(nodeAgentURL, http.MethodGet, "/dav/seed.txt", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("get WebDAV file: %v", err)
|
||||
}
|
||||
defer getResponse.Body.Close()
|
||||
|
||||
getBody, err := io.ReadAll(getResponse.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read WebDAV body: %v", err)
|
||||
}
|
||||
if getResponse.StatusCode != http.StatusOK {
|
||||
t.Fatalf("get WebDAV status = %d, want 200", getResponse.StatusCode)
|
||||
}
|
||||
if string(getBody) != "seed" {
|
||||
t.Fatalf("get WebDAV body = %q, want seed", string(getBody))
|
||||
}
|
||||
|
||||
registerRequest := awaitValue(t, registerRequests, 5*time.Second, "register request")
|
||||
if registerRequest.MachineID != machineID {
|
||||
t.Fatalf("machine id = %q, want %q", registerRequest.MachineID, machineID)
|
||||
}
|
||||
if registerRequest.DisplayName != "Runtime NAS" {
|
||||
t.Fatalf("display name = %q, want Runtime NAS", registerRequest.DisplayName)
|
||||
}
|
||||
if registerRequest.AgentVersion != "test-version" {
|
||||
t.Fatalf("agent version = %q, want test-version", registerRequest.AgentVersion)
|
||||
}
|
||||
if registerRequest.DirectAddress == nil || *registerRequest.DirectAddress != nodeAgentURL {
|
||||
t.Fatalf("direct address = %#v, want %q", registerRequest.DirectAddress, nodeAgentURL)
|
||||
}
|
||||
if registerRequest.RelayAddress != nil {
|
||||
t.Fatalf("relay address = %#v, want nil", registerRequest.RelayAddress)
|
||||
}
|
||||
if len(registerRequest.Exports) != 1 {
|
||||
t.Fatalf("exports length = %d, want 1", len(registerRequest.Exports))
|
||||
}
|
||||
if registerRequest.Exports[0].Label != "runtime-export" {
|
||||
t.Fatalf("export label = %q, want runtime-export", registerRequest.Exports[0].Label)
|
||||
}
|
||||
if registerRequest.Exports[0].Path != exportPath {
|
||||
t.Fatalf("export path = %q, want %q", registerRequest.Exports[0].Path, exportPath)
|
||||
}
|
||||
if len(registerRequest.Exports[0].Protocols) != 1 || registerRequest.Exports[0].Protocols[0] != "webdav" {
|
||||
t.Fatalf("export protocols = %#v, want [webdav]", registerRequest.Exports[0].Protocols)
|
||||
}
|
||||
if len(registerRequest.Exports[0].Tags) != 2 || registerRequest.Exports[0].Tags[0] != "runtime" || registerRequest.Exports[0].Tags[1] != "finder" {
|
||||
t.Fatalf("export tags = %#v, want [runtime finder]", registerRequest.Exports[0].Tags)
|
||||
}
|
||||
|
||||
heartbeatRequest := awaitValue(t, heartbeatRequests, 5*time.Second, "heartbeat request")
|
||||
if heartbeatRequest.NodeID != "runtime-node" {
|
||||
t.Fatalf("heartbeat node id = %q, want runtime-node", heartbeatRequest.NodeID)
|
||||
}
|
||||
if heartbeatRequest.Status != "online" {
|
||||
t.Fatalf("heartbeat status = %q, want online", heartbeatRequest.Status)
|
||||
}
|
||||
if _, err := time.Parse(time.RFC3339, heartbeatRequest.LastSeenAt); err != nil {
|
||||
t.Fatalf("heartbeat lastSeenAt parse: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeBinaryIntegratesWithRealControlPlane(t *testing.T) {
|
||||
repoRoot := testRepoRoot(t)
|
||||
nodeAgentBin := buildNodeAgentBinary(t, repoRoot)
|
||||
controlPlaneBin := buildControlPlaneBinary(t, repoRoot)
|
||||
nodeAgentPort := freePort(t)
|
||||
controlPlanePort := freePort(t)
|
||||
exportPath := filepath.Join(t.TempDir(), "export")
|
||||
if err := os.MkdirAll(exportPath, 0o755); err != nil {
|
||||
t.Fatalf("create export path: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(exportPath, "seed.txt"), []byte("seed"), 0o644); err != nil {
|
||||
t.Fatalf("write seed file: %v", err)
|
||||
}
|
||||
|
||||
controlPlaneURL := "http://127.0.0.1:" + strconv.Itoa(controlPlanePort)
|
||||
nodeAgentURL := "http://127.0.0.1:" + strconv.Itoa(nodeAgentPort)
|
||||
controlPlane := startBinaryProcess(t, repoRoot, controlPlaneBin, []string{
|
||||
"PORT=" + strconv.Itoa(controlPlanePort),
|
||||
"BETTERNAS_VERSION=test-version",
|
||||
"BETTERNAS_EXAMPLE_MOUNT_URL=" + nodeAgentURL + "/dav/",
|
||||
"BETTERNAS_NODE_DIRECT_ADDRESS=" + nodeAgentURL,
|
||||
})
|
||||
defer controlPlane.stop(t)
|
||||
|
||||
waitForHTTPStatus(t, controlPlaneURL+"/health", http.StatusOK)
|
||||
|
||||
nodeAgent := startBinaryProcess(t, repoRoot, nodeAgentBin, []string{
|
||||
"PORT=" + strconv.Itoa(nodeAgentPort),
|
||||
"BETTERNAS_EXPORT_PATH=" + exportPath,
|
||||
"BETTERNAS_NODE_MACHINE_ID=runtime-machine",
|
||||
"BETTERNAS_NODE_DISPLAY_NAME=Runtime NAS",
|
||||
"BETTERNAS_EXPORT_LABEL=runtime-export",
|
||||
"BETTERNAS_NODE_DIRECT_ADDRESS=" + nodeAgentURL,
|
||||
"BETTERNAS_CONTROL_PLANE_URL=" + controlPlaneURL,
|
||||
"BETTERNAS_NODE_REGISTER_ENABLED=true",
|
||||
"BETTERNAS_NODE_HEARTBEAT_ENABLED=true",
|
||||
"BETTERNAS_NODE_HEARTBEAT_INTERVAL=100ms",
|
||||
})
|
||||
defer nodeAgent.stop(t)
|
||||
|
||||
waitForHTTPStatus(t, nodeAgentURL+"/health", http.StatusOK)
|
||||
waitForProcessOutput(t, nodeAgent, 5*time.Second, "registered as dev-node")
|
||||
waitForProcessOutput(t, nodeAgent, 5*time.Second, "stopping heartbeats")
|
||||
|
||||
mountProfileRequest, err := http.NewRequest(http.MethodPost, controlPlaneURL+"/api/v1/mount-profiles/issue", strings.NewReader(`{"userId":"integration-user","deviceId":"integration-device","exportId":"dev-export"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("build mount profile request: %v", err)
|
||||
}
|
||||
mountProfileRequest.Header.Set("Content-Type", "application/json")
|
||||
|
||||
mountProfileResponse, err := http.DefaultClient.Do(mountProfileRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("issue mount profile: %v", err)
|
||||
}
|
||||
defer mountProfileResponse.Body.Close()
|
||||
|
||||
if mountProfileResponse.StatusCode != http.StatusOK {
|
||||
t.Fatalf("mount profile status = %d, want 200", mountProfileResponse.StatusCode)
|
||||
}
|
||||
|
||||
var mountProfile struct {
|
||||
Protocol string `json:"protocol"`
|
||||
MountURL string `json:"mountUrl"`
|
||||
}
|
||||
if err := json.NewDecoder(mountProfileResponse.Body).Decode(&mountProfile); err != nil {
|
||||
t.Fatalf("decode mount profile: %v", err)
|
||||
}
|
||||
|
||||
if mountProfile.Protocol != "webdav" {
|
||||
t.Fatalf("mount profile protocol = %q, want webdav", mountProfile.Protocol)
|
||||
}
|
||||
if mountProfile.MountURL != nodeAgentURL+"/dav/" {
|
||||
t.Fatalf("mount profile url = %q, want %q", mountProfile.MountURL, nodeAgentURL+"/dav/")
|
||||
}
|
||||
|
||||
propfindRequest, err := http.NewRequest("PROPFIND", mountProfile.MountURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("build mount-url propfind request: %v", err)
|
||||
}
|
||||
propfindRequest.Header.Set("Depth", "0")
|
||||
|
||||
propfindResponse, err := http.DefaultClient.Do(propfindRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("propfind mount profile url: %v", err)
|
||||
}
|
||||
defer propfindResponse.Body.Close()
|
||||
|
||||
if propfindResponse.StatusCode != http.StatusMultiStatus {
|
||||
t.Fatalf("propfind status = %d, want 207", propfindResponse.StatusCode)
|
||||
}
|
||||
|
||||
getResponse, err := doRuntimeWebDAVRequest(nodeAgentURL, http.MethodGet, "/dav/seed.txt", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("get WebDAV file after control-plane sync: %v", err)
|
||||
}
|
||||
defer getResponse.Body.Close()
|
||||
|
||||
getBody, err := io.ReadAll(getResponse.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read WebDAV body after control-plane sync: %v", err)
|
||||
}
|
||||
|
||||
if getResponse.StatusCode != http.StatusOK {
|
||||
t.Fatalf("get status after control-plane sync = %d, want 200", getResponse.StatusCode)
|
||||
}
|
||||
if string(getBody) != "seed" {
|
||||
t.Fatalf("get body after control-plane sync = %q, want seed", string(getBody))
|
||||
}
|
||||
}
|
||||
|
||||
func testRepoRoot(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
t.Fatal("resolve runtime integration test filename")
|
||||
}
|
||||
|
||||
return filepath.Clean(filepath.Join(filepath.Dir(filename), "..", "..", "..", ".."))
|
||||
}
|
||||
|
||||
func buildNodeAgentBinary(t *testing.T, repoRoot string) string {
|
||||
t.Helper()
|
||||
|
||||
binDir := t.TempDir()
|
||||
nodeAgentBin := filepath.Join(binDir, binaryName("node-agent"))
|
||||
buildBinary(t, repoRoot, "./apps/node-agent/cmd/node-agent", nodeAgentBin)
|
||||
return nodeAgentBin
|
||||
}
|
||||
|
||||
func buildControlPlaneBinary(t *testing.T, repoRoot string) string {
|
||||
t.Helper()
|
||||
|
||||
binDir := t.TempDir()
|
||||
controlPlaneBin := filepath.Join(binDir, binaryName("control-plane"))
|
||||
buildBinary(t, repoRoot, "./apps/control-plane/cmd/control-plane", controlPlaneBin)
|
||||
return controlPlaneBin
|
||||
}
|
||||
|
||||
func binaryName(base string) string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return base + ".exe"
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
func buildBinary(t *testing.T, repoRoot, packagePath, outputPath string) {
|
||||
t.Helper()
|
||||
|
||||
command := exec.Command("go", "build", "-o", outputPath, packagePath)
|
||||
command.Dir = repoRoot
|
||||
command.Env = mergedEnv([]string{"CGO_ENABLED=0"})
|
||||
output, err := command.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("build %s: %v\n%s", packagePath, err, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
func startBinaryProcess(t *testing.T, repoRoot, binaryPath string, env []string) *startedProcess {
|
||||
t.Helper()
|
||||
|
||||
output := &lockedBuffer{}
|
||||
command := exec.Command(binaryPath)
|
||||
command.Dir = repoRoot
|
||||
command.Env = mergedEnv(env)
|
||||
command.Stdout = output
|
||||
command.Stderr = output
|
||||
|
||||
if err := command.Start(); err != nil {
|
||||
t.Fatalf("start %s: %v", binaryPath, err)
|
||||
}
|
||||
|
||||
return &startedProcess{
|
||||
cmd: command,
|
||||
output: output,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *startedProcess) stop(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
if p == nil || p.cmd == nil || p.cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
_ = p.cmd.Process.Kill()
|
||||
|
||||
waitDone := make(chan error, 1)
|
||||
go func() {
|
||||
waitDone <- p.cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-waitDone:
|
||||
if err != nil && !strings.Contains(err.Error(), "signal: killed") {
|
||||
t.Fatalf("wait for %s: %v\n%s", p.cmd.Path, err, p.output.String())
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
_ = p.cmd.Process.Kill()
|
||||
err := <-waitDone
|
||||
if err != nil && !strings.Contains(err.Error(), "signal: killed") {
|
||||
t.Fatalf("kill %s: %v\n%s", p.cmd.Path, err, p.output.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func freePort(t *testing.T) int {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen for free port: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
return listener.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
|
||||
func waitForHTTPStatus(t *testing.T, target string, wantStatus int) {
|
||||
t.Helper()
|
||||
|
||||
waitForCondition(t, 10*time.Second, target, func() bool {
|
||||
response, err := http.Get(target)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
return response.StatusCode == wantStatus
|
||||
})
|
||||
}
|
||||
|
||||
func waitForProcessOutput(t *testing.T, process *startedProcess, timeout time.Duration, fragment string) {
|
||||
t.Helper()
|
||||
|
||||
waitForCondition(t, timeout, "process output "+fragment, func() bool {
|
||||
return strings.Contains(process.output.String(), fragment)
|
||||
})
|
||||
}
|
||||
|
||||
func doRuntimeWebDAVRequest(baseURL, method, requestPath string, body io.Reader) (*http.Response, error) {
|
||||
request, err := http.NewRequest(method, baseURL+requestPath, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return http.DefaultClient.Do(request)
|
||||
}
|
||||
|
||||
func firstNonLoopbackIPv4() (string, bool) {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch value := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = value.IP
|
||||
case *net.IPAddr:
|
||||
ip = value.IP
|
||||
}
|
||||
|
||||
if ip == nil || ip.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
return ipv4.String(), true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func mergedEnv(overrides []string) []string {
|
||||
values := make(map[string]string)
|
||||
|
||||
for _, entry := range os.Environ() {
|
||||
key, value, ok := strings.Cut(entry, "=")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
values[key] = value
|
||||
}
|
||||
|
||||
for _, entry := range overrides {
|
||||
key, value, ok := strings.Cut(entry, "=")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
values[key] = value
|
||||
}
|
||||
|
||||
merged := make([]string, 0, len(values))
|
||||
for key, value := range values {
|
||||
merged = append(merged, key+"="+value)
|
||||
}
|
||||
sort.Strings(merged)
|
||||
return merged
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue