diff --git a/internal/daemon/create.go b/internal/daemon/create.go index 8907495..9aaeeaa 100644 --- a/internal/daemon/create.go +++ b/internal/daemon/create.go @@ -23,6 +23,10 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi if err := validateGuestConfig(req.GuestConfig); err != nil { return nil, err } + guestConfig, err := d.mergedGuestConfig(req.GuestConfig) + if err != nil { + return nil, err + } unlock := d.lockMachine(req.MachineID) defer unlock() @@ -73,7 +77,7 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi _ = os.Remove(systemVolumePath) _ = os.RemoveAll(filepath.Dir(systemVolumePath)) }() - if err := injectGuestConfig(ctx, systemVolumePath, req.GuestConfig); err != nil { + if err := injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil { return nil, err } diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 80aa399..f955306 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -56,13 +56,17 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err return nil, fmt.Errorf("create daemon dir %q: %w", dir, err) } } - return &Daemon{ + daemon := &Daemon{ config: cfg, store: store, runtime: runtime, machineLocks: make(map[contracthost.MachineID]*sync.Mutex), artifactLocks: make(map[string]*sync.Mutex), - }, nil + } + if err := daemon.ensureBackendSSHKeyPair(); err != nil { + return nil, err + } + return daemon, nil } func (d *Daemon) Health(ctx context.Context) (*contracthost.HealthResponse, error) { diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index d28abe0..384f149 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -75,7 +75,14 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { } kernelPayload := []byte("kernel-image") - rootFSPayload := []byte("rootfs-image") + rootFSImagePath := filepath.Join(root, "guest-rootfs.ext4") + if err := buildTestExt4Image(root, rootFSImagePath); err != nil { + t.Fatalf("build ext4 image: %v", err) + } + rootFSPayload, err := os.ReadFile(rootFSImagePath) + if err != nil { + t.Fatalf("read ext4 image: %v", err) + } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/kernel": @@ -94,6 +101,15 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { KernelImageURL: server.URL + "/kernel", RootFSURL: server.URL + "/rootfs", }, + GuestConfig: &contracthost.GuestConfig{ + AuthorizedKeys: []string{ + "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAITestOverrideKey daemon-test", + }, + LoginWebhook: &contracthost.GuestLoginWebhook{ + URL: "https://example.com/login", + BearerToken: "token", + }, + }, }) if err != nil { t.Fatalf("create machine: %v", err) @@ -123,6 +139,20 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { if _, err := os.Stat(runtime.lastSpec.RootFSPath); err != nil { t.Fatalf("system disk not staged: %v", err) } + hostAuthorizedKeyBytes, err := os.ReadFile(hostDaemon.backendSSHPublicKeyPath()) + if err != nil { + t.Fatalf("read backend ssh public key: %v", err) + } + authorizedKeys, err := readExt4File(runtime.lastSpec.RootFSPath, "/etc/microagent/authorized_keys") + if err != nil { + t.Fatalf("read injected authorized_keys: %v", err) + } + if !strings.Contains(authorizedKeys, strings.TrimSpace(string(hostAuthorizedKeyBytes))) { + t.Fatalf("authorized_keys missing backend ssh key: %q", authorizedKeys) + } + if !strings.Contains(authorizedKeys, "daemon-test") { + t.Fatalf("authorized_keys missing request override key: %q", authorizedKeys) + } artifact, err := fileStore.GetArtifact(context.Background(), response.Machine.Artifact) if err != nil { @@ -154,6 +184,31 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { } } +func TestNewEnsuresBackendSSHKeyPair(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + hostDaemon, err := New(cfg, fileStore, &fakeRuntime{}) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + + if _, err := os.Stat(hostDaemon.backendSSHPrivateKeyPath()); err != nil { + t.Fatalf("stat backend ssh private key: %v", err) + } + publicKeyPayload, err := os.ReadFile(hostDaemon.backendSSHPublicKeyPath()) + if err != nil { + t.Fatalf("read backend ssh public key: %v", err) + } + if !strings.HasPrefix(string(publicKeyPayload), "ssh-ed25519 ") { + t.Fatalf("unexpected backend ssh public key: %q", string(publicKeyPayload)) + } +} + func TestCreateMachineRejectsNonHTTPArtifactURLs(t *testing.T) { t.Parallel() diff --git a/internal/daemon/files.go b/internal/daemon/files.go index a5c5f35..1d3f164 100644 --- a/internal/daemon/files.go +++ b/internal/daemon/files.go @@ -31,6 +31,14 @@ func (d *Daemon) machineRuntimeBaseDir(machineID contracthost.MachineID) string return filepath.Join(d.config.RuntimeDir, "machines", string(machineID)) } +func (d *Daemon) backendSSHPrivateKeyPath() string { + return filepath.Join(d.config.RootDir, "state", "ssh", "backend_ed25519") +} + +func (d *Daemon) backendSSHPublicKeyPath() string { + return d.backendSSHPrivateKeyPath() + ".pub" +} + func artifactKey(ref contracthost.ArtifactRef) string { sum := sha256.Sum256([]byte(ref.KernelImageURL + "\n" + ref.RootFSURL)) return hex.EncodeToString(sum[:]) @@ -181,6 +189,69 @@ func defaultMachinePorts() []contracthost.MachinePort { } } +func (d *Daemon) ensureBackendSSHKeyPair() error { + privateKeyPath := d.backendSSHPrivateKeyPath() + publicKeyPath := d.backendSSHPublicKeyPath() + if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0o700); err != nil { + return fmt.Errorf("create backend ssh dir: %w", err) + } + privateExists := fileExists(privateKeyPath) + publicExists := fileExists(publicKeyPath) + switch { + case privateExists && publicExists: + return nil + case privateExists && !publicExists: + return d.writeBackendSSHPublicKey(privateKeyPath, publicKeyPath) + case !privateExists && publicExists: + return fmt.Errorf("backend ssh private key %q is missing while public key exists", privateKeyPath) + } + + command := exec.Command("ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", privateKeyPath) + output, err := command.CombinedOutput() + if err != nil { + return fmt.Errorf("generate backend ssh keypair: %w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func (d *Daemon) writeBackendSSHPublicKey(privateKeyPath string, publicKeyPath string) error { + command := exec.Command("ssh-keygen", "-y", "-f", privateKeyPath) + output, err := command.CombinedOutput() + if err != nil { + return fmt.Errorf("derive backend ssh public key: %w: %s", err, strings.TrimSpace(string(output))) + } + payload := strings.TrimSpace(string(output)) + "\n" + if err := os.WriteFile(publicKeyPath, []byte(payload), 0o644); err != nil { + return fmt.Errorf("write backend ssh public key %q: %w", publicKeyPath, err) + } + return nil +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func (d *Daemon) mergedGuestConfig(config *contracthost.GuestConfig) (*contracthost.GuestConfig, error) { + hostAuthorizedKey, err := os.ReadFile(d.backendSSHPublicKeyPath()) + if err != nil { + return nil, fmt.Errorf("read backend ssh public key: %w", err) + } + authorizedKeys := []string{strings.TrimSpace(string(hostAuthorizedKey))} + if config != nil { + authorizedKeys = append(authorizedKeys, config.AuthorizedKeys...) + } + + merged := &contracthost.GuestConfig{ + AuthorizedKeys: authorizedKeys, + } + if config != nil && config.LoginWebhook != nil { + loginWebhook := *config.LoginWebhook + merged.LoginWebhook = &loginWebhook + } + return merged, nil +} + func hasGuestConfig(config *contracthost.GuestConfig) bool { if config == nil { return false