From 500354cd9bd1ca356fc44f003b41dd9535bcebd4 Mon Sep 17 00:00:00 2001 From: Harivansh Rathi Date: Thu, 9 Apr 2026 17:52:14 +0000 Subject: [PATCH] fix: address gateway review findings --- contract/machines.go | 26 +-- contract/networking.go | 1 + internal/config/config.go | 2 + internal/daemon/create.go | 55 ++++-- internal/daemon/daemon.go | 7 + internal/daemon/daemon_test.go | 10 ++ internal/daemon/files.go | 56 +++++-- internal/daemon/guest_hostkey.go | 51 ++++++ internal/daemon/lifecycle.go | 25 +++ internal/daemon/machine_relays.go | 184 +++++++++++++++++++++ internal/daemon/published_ports.go | 8 +- internal/daemon/review_regressions_test.go | 6 + internal/daemon/snapshot.go | 49 ++++-- internal/model/types.go | 27 +-- 14 files changed, 441 insertions(+), 66 deletions(-) create mode 100644 internal/daemon/guest_hostkey.go create mode 100644 internal/daemon/machine_relays.go diff --git a/contract/machines.go b/contract/machines.go index 09d1650..43c7432 100644 --- a/contract/machines.go +++ b/contract/machines.go @@ -3,21 +3,23 @@ package host import "time" type Machine struct { - ID MachineID `json:"id"` - Artifact ArtifactRef `json:"artifact"` - SystemVolumeID VolumeID `json:"system_volume_id,omitempty"` - UserVolumeIDs []VolumeID `json:"user_volume_ids,omitempty"` - RuntimeHost string `json:"runtime_host,omitempty"` - Ports []MachinePort `json:"ports,omitempty"` - Phase MachinePhase `json:"phase"` - Error string `json:"error,omitempty"` - CreatedAt time.Time `json:"created_at"` - StartedAt *time.Time `json:"started_at,omitempty"` + ID MachineID `json:"id"` + Artifact ArtifactRef `json:"artifact"` + SystemVolumeID VolumeID `json:"system_volume_id,omitempty"` + UserVolumeIDs []VolumeID `json:"user_volume_ids,omitempty"` + RuntimeHost string `json:"runtime_host,omitempty"` + Ports []MachinePort `json:"ports,omitempty"` + GuestSSHPublicKey string `json:"guest_ssh_host_public_key,omitempty"` + Phase MachinePhase `json:"phase"` + Error string `json:"error,omitempty"` + CreatedAt time.Time `json:"created_at"` + StartedAt *time.Time `json:"started_at,omitempty"` } type GuestConfig struct { - AuthorizedKeys []string `json:"authorized_keys,omitempty"` - LoginWebhook *GuestLoginWebhook `json:"login_webhook,omitempty"` + AuthorizedKeys []string `json:"authorized_keys,omitempty"` + TrustedUserCAKeys []string `json:"trusted_user_ca_keys,omitempty"` + LoginWebhook *GuestLoginWebhook `json:"login_webhook,omitempty"` } type GuestLoginWebhook struct { diff --git a/contract/networking.go b/contract/networking.go index b347beb..3750b1a 100644 --- a/contract/networking.go +++ b/contract/networking.go @@ -16,5 +16,6 @@ const ( type MachinePort struct { Name MachinePortName `json:"name"` Port uint16 `json:"port"` + HostPort uint16 `json:"host_port,omitempty"` Protocol PortProtocol `json:"protocol"` } diff --git a/internal/config/config.go b/internal/config/config.go index 7d64a09..b76e8ad 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -25,6 +25,7 @@ type Config struct { EgressInterface string FirecrackerBinaryPath string JailerBinaryPath string + GuestLoginCAPublicKey string } // Load loads and validates the firecracker-host daemon configuration from the environment. @@ -43,6 +44,7 @@ func Load() (Config, error) { EgressInterface: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_EGRESS_INTERFACE")), FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")), JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")), + GuestLoginCAPublicKey: strings.TrimSpace(os.Getenv("GUEST_LOGIN_CA_PUBLIC_KEY")), } if err := cfg.Validate(); err != nil { return Config{}, err diff --git a/internal/daemon/create.go b/internal/daemon/create.go index 2a5b537..33e89c5 100644 --- a/internal/daemon/create.go +++ b/internal/daemon/create.go @@ -103,6 +103,11 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi _ = d.runtime.Delete(context.Background(), *state) return nil, err } + guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost) + if err != nil { + _ = d.runtime.Delete(context.Background(), *state) + return nil, err + } now := time.Now().UTC() systemVolumeRecord := model.VolumeRecord{ @@ -138,19 +143,44 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi } record := model.MachineRecord{ - ID: req.MachineID, - Artifact: req.Artifact, - SystemVolumeID: systemVolumeRecord.ID, - UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...), - RuntimeHost: state.RuntimeHost, - TapDevice: state.TapName, - Ports: ports, - Phase: contracthost.MachinePhaseRunning, - PID: state.PID, - SocketPath: state.SocketPath, - CreatedAt: now, - StartedAt: state.StartedAt, + ID: req.MachineID, + Artifact: req.Artifact, + SystemVolumeID: systemVolumeRecord.ID, + UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...), + RuntimeHost: state.RuntimeHost, + TapDevice: state.TapName, + Ports: ports, + GuestSSHPublicKey: guestSSHPublicKey, + Phase: contracthost.MachinePhaseRunning, + PID: state.PID, + SocketPath: state.SocketPath, + CreatedAt: now, + StartedAt: state.StartedAt, } + d.relayAllocMu.Lock() + sshRelayPort, err := d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameSSH, record.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort) + var vncRelayPort uint16 + if err == nil { + vncRelayPort, err = d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameVNC, record.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort) + } + d.relayAllocMu.Unlock() + if err != nil { + d.stopMachineRelays(record.ID) + for _, volume := range userVolumes { + volume.AttachedMachineID = nil + _ = d.store.UpdateVolume(context.Background(), volume) + } + _ = d.store.DeleteVolume(context.Background(), systemVolumeRecord.ID) + _ = d.runtime.Delete(context.Background(), *state) + return nil, err + } + record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort) + startedRelays := true + defer func() { + if startedRelays { + d.stopMachineRelays(record.ID) + } + }() if err := d.store.CreateMachine(ctx, record); err != nil { for _, volume := range userVolumes { volume.AttachedMachineID = nil @@ -162,6 +192,7 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi } removeSystemVolumeOnFailure = false + startedRelays = false clearOperation = true return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil } diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index a420ef4..11c7273 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -42,11 +42,15 @@ type Daemon struct { runtime Runtime reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID) error + readGuestSSHPublicKey func(context.Context, string) (string, error) locksMu sync.Mutex machineLocks map[contracthost.MachineID]*sync.Mutex artifactLocks map[string]*sync.Mutex + relayAllocMu sync.Mutex + machineRelaysMu sync.Mutex + machineRelayListeners map[string]net.Listener publishedPortAllocMu sync.Mutex publishedPortsMu sync.Mutex publishedPortListeners map[contracthost.PublishedPortID]net.Listener @@ -72,11 +76,14 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err store: store, runtime: runtime, reconfigureGuestIdentity: nil, + readGuestSSHPublicKey: nil, machineLocks: make(map[contracthost.MachineID]*sync.Mutex), artifactLocks: make(map[string]*sync.Mutex), + machineRelayListeners: make(map[string]net.Listener), publishedPortListeners: make(map[contracthost.PublishedPortID]net.Listener), } daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH + daemon.readGuestSSHPublicKey = readGuestSSHPublicKey if err := daemon.ensureBackendSSHKeyPair(); err != nil { return nil, err } diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 3d33345..6b26694 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -98,6 +98,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { if err != nil { t.Fatalf("create daemon: %v", err) } + stubGuestSSHPublicKeyReader(hostDaemon) kernelPayload := []byte("kernel-image") rootFSImagePath := filepath.Join(root, "guest-rootfs.ext4") @@ -261,6 +262,7 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) { if err != nil { t.Fatalf("create daemon: %v", err) } + stubGuestSSHPublicKeyReader(hostDaemon) hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID) error { return nil } artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"} @@ -367,6 +369,7 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) { if err != nil { t.Fatalf("create daemon: %v", err) } + stubGuestSSHPublicKeyReader(hostDaemon) var reconfiguredHost string var reconfiguredMachine contracthost.MachineID hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID) error { @@ -508,6 +511,7 @@ func TestDeleteMachineMissingIsNoOp(t *testing.T) { if err != nil { t.Fatalf("create daemon: %v", err) } + stubGuestSSHPublicKeyReader(hostDaemon) if err := hostDaemon.DeleteMachine(context.Background(), "missing"); err != nil { t.Fatalf("delete missing machine: %v", err) @@ -533,6 +537,12 @@ func testConfig(root string) appconfig.Config { } } +func stubGuestSSHPublicKeyReader(hostDaemon *Daemon) { + hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) { + return "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3", nil + } +} + func listenTestPort(t *testing.T, port int) net.Listener { t.Helper() diff --git a/internal/daemon/files.go b/internal/daemon/files.go index b518f54..9c6b9e5 100644 --- a/internal/daemon/files.go +++ b/internal/daemon/files.go @@ -187,9 +187,13 @@ func isZeroChunk(chunk []byte) bool { } func defaultMachinePorts() []contracthost.MachinePort { + return buildMachinePorts(0, 0) +} + +func buildMachinePorts(sshRelayPort, vncRelayPort uint16) []contracthost.MachinePort { return []contracthost.MachinePort{ - {Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, Protocol: contracthost.PortProtocolTCP}, - {Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, Protocol: contracthost.PortProtocolTCP}, + {Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, HostPort: sshRelayPort, Protocol: contracthost.PortProtocolTCP}, + {Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, HostPort: vncRelayPort, Protocol: contracthost.PortProtocolTCP}, } } @@ -247,7 +251,14 @@ func (d *Daemon) mergedGuestConfig(config *contracthost.GuestConfig) (*contracth } merged := &contracthost.GuestConfig{ - AuthorizedKeys: authorizedKeys, + AuthorizedKeys: authorizedKeys, + TrustedUserCAKeys: nil, + } + if strings.TrimSpace(d.config.GuestLoginCAPublicKey) != "" { + merged.TrustedUserCAKeys = append(merged.TrustedUserCAKeys, d.config.GuestLoginCAPublicKey) + } + if config != nil { + merged.TrustedUserCAKeys = append(merged.TrustedUserCAKeys, config.TrustedUserCAKeys...) } if config != nil && config.LoginWebhook != nil { loginWebhook := *config.LoginWebhook @@ -260,7 +271,7 @@ func hasGuestConfig(config *contracthost.GuestConfig) bool { if config == nil { return false } - return len(config.AuthorizedKeys) > 0 || config.LoginWebhook != nil + return len(config.AuthorizedKeys) > 0 || len(config.TrustedUserCAKeys) > 0 || config.LoginWebhook != nil } func injectGuestConfig(ctx context.Context, imagePath string, config *contracthost.GuestConfig) error { @@ -286,6 +297,17 @@ func injectGuestConfig(ctx context.Context, imagePath string, config *contractho } } + if len(config.TrustedUserCAKeys) > 0 { + trustedCAPath := filepath.Join(stagingDir, "trusted_user_ca_keys") + payload := []byte(strings.Join(config.TrustedUserCAKeys, "\n") + "\n") + if err := os.WriteFile(trustedCAPath, payload, 0o644); err != nil { + return fmt.Errorf("write trusted_user_ca_keys staging file: %w", err) + } + if err := replaceExt4File(ctx, imagePath, trustedCAPath, "/etc/microagent/trusted_user_ca_keys"); err != nil { + return err + } + } + if config.LoginWebhook != nil { guestConfigPath := filepath.Join(stagingDir, "guest-config.json") payload, err := json.Marshal(config) @@ -363,16 +385,17 @@ func machineIDPtr(machineID contracthost.MachineID) *contracthost.MachineID { func machineToContract(record model.MachineRecord) contracthost.Machine { return contracthost.Machine{ - ID: record.ID, - Artifact: record.Artifact, - SystemVolumeID: record.SystemVolumeID, - UserVolumeIDs: append([]contracthost.VolumeID(nil), record.UserVolumeIDs...), - RuntimeHost: record.RuntimeHost, - Ports: append([]contracthost.MachinePort(nil), record.Ports...), - Phase: record.Phase, - Error: record.Error, - CreatedAt: record.CreatedAt, - StartedAt: record.StartedAt, + ID: record.ID, + Artifact: record.Artifact, + SystemVolumeID: record.SystemVolumeID, + UserVolumeIDs: append([]contracthost.VolumeID(nil), record.UserVolumeIDs...), + RuntimeHost: record.RuntimeHost, + Ports: append([]contracthost.MachinePort(nil), record.Ports...), + GuestSSHPublicKey: record.GuestSSHPublicKey, + Phase: record.Phase, + Error: record.Error, + CreatedAt: record.CreatedAt, + StartedAt: record.StartedAt, } } @@ -427,6 +450,11 @@ func validateGuestConfig(config *contracthost.GuestConfig) error { return fmt.Errorf("guest_config.authorized_keys[%d] is required", i) } } + for i, key := range config.TrustedUserCAKeys { + if strings.TrimSpace(key) == "" { + return fmt.Errorf("guest_config.trusted_user_ca_keys[%d] is required", i) + } + } if config.LoginWebhook != nil { if err := validateDownloadURL("guest_config.login_webhook.url", config.LoginWebhook.URL); err != nil { return err diff --git a/internal/daemon/guest_hostkey.go b/internal/daemon/guest_hostkey.go new file mode 100644 index 0000000..7bb6ac5 --- /dev/null +++ b/internal/daemon/guest_hostkey.go @@ -0,0 +1,51 @@ +package daemon + +import ( + "context" + "fmt" + "net" + "strconv" + "strings" + + "golang.org/x/crypto/ssh" +) + +func readGuestSSHPublicKey(ctx context.Context, runtimeHost string) (string, error) { + host := strings.TrimSpace(runtimeHost) + if host == "" { + return "", fmt.Errorf("runtime host is required") + } + + probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout) + defer cancel() + + targetAddr := net.JoinHostPort(host, strconv.Itoa(int(defaultSSHPort))) + netConn, err := (&net.Dialer{}).DialContext(probeCtx, "tcp", targetAddr) + if err != nil { + return "", fmt.Errorf("dial guest ssh for host key: %w", err) + } + defer func() { + _ = netConn.Close() + }() + + var captured ssh.PublicKey + clientConfig := &ssh.ClientConfig{ + User: "host-key-probe", + Auth: []ssh.AuthMethod{ssh.Password("invalid")}, + HostKeyAlgorithms: []string{ssh.KeyAlgoED25519}, + HostKeyCallback: func(_ string, _ net.Addr, key ssh.PublicKey) error { + captured = key + return fmt.Errorf("guest ssh host key captured") + }, + Timeout: defaultGuestDialTimeout, + ClientVersion: "SSH-2.0-agentcomputer-firecracker-host", + } + _, _, _, err = ssh.NewClientConn(netConn, targetAddr, clientConfig) + if captured == nil { + if err == nil { + return "", fmt.Errorf("guest ssh host key probe returned without a host key") + } + return "", fmt.Errorf("handshake guest ssh for host key: %w", err) + } + return strings.TrimSpace(string(ssh.MarshalAuthorizedKey(captured))), nil +} diff --git a/internal/daemon/lifecycle.go b/internal/daemon/lifecycle.go index e55e56a..dcfffb7 100644 --- a/internal/daemon/lifecycle.go +++ b/internal/daemon/lifecycle.go @@ -99,10 +99,16 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (* _ = d.runtime.Delete(context.Background(), *state) return nil, err } + guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost) + if err != nil { + _ = d.runtime.Delete(context.Background(), *state) + return nil, err + } record.RuntimeHost = state.RuntimeHost record.TapDevice = state.TapName record.Ports = ports + record.GuestSSHPublicKey = guestSSHPublicKey record.Phase = contracthost.MachinePhaseRunning record.Error = "" record.PID = state.PID @@ -112,7 +118,13 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (* _ = d.runtime.Delete(context.Background(), *state) return nil, err } + if err := d.ensureMachineRelays(ctx, record); err != nil { + d.stopMachineRelays(id) + _ = d.runtime.Delete(context.Background(), *state) + return nil, err + } if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil { + d.stopMachineRelays(id) d.stopPublishedPortsForMachine(id) _ = d.runtime.Delete(context.Background(), *state) return nil, err @@ -238,10 +250,14 @@ func (d *Daemon) Reconcile(ctx context.Context) error { return err } if reconciled.Phase == contracthost.MachinePhaseRunning { + if err := d.ensureMachineRelays(ctx, reconciled); err != nil { + return err + } if err := d.ensurePublishedPortsForMachine(ctx, *reconciled); err != nil { return err } } else { + d.stopMachineRelays(reconciled.ID) d.stopPublishedPortsForMachine(reconciled.ID) } } @@ -322,6 +338,9 @@ func (d *Daemon) reconcileStart(ctx context.Context, machineID contracthost.Mach return err } if record.Phase == contracthost.MachinePhaseRunning { + if err := d.ensureMachineRelays(ctx, record); err != nil { + return err + } if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil { return err } @@ -375,12 +394,16 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma return nil, err } if state.Phase == firecracker.PhaseRunning { + if err := d.ensureMachineRelays(ctx, record); err != nil { + return nil, err + } return record, nil } if err := d.runtime.Delete(ctx, *state); err != nil { return nil, err } + d.stopMachineRelays(record.ID) d.stopPublishedPortsForMachine(record.ID) record.Phase = contracthost.MachinePhaseFailed record.Error = state.Error @@ -396,6 +419,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma } func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error { + d.stopMachineRelays(record.ID) d.stopPublishedPortsForMachine(record.ID) if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil { return err @@ -426,6 +450,7 @@ func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineR } func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error { + d.stopMachineRelays(record.ID) d.stopPublishedPortsForMachine(record.ID) if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil { return err diff --git a/internal/daemon/machine_relays.go b/internal/daemon/machine_relays.go new file mode 100644 index 0000000..91fc499 --- /dev/null +++ b/internal/daemon/machine_relays.go @@ -0,0 +1,184 @@ +package daemon + +import ( + "context" + "fmt" + "net" + "strconv" + "strings" + + "github.com/getcompanion-ai/computer-host/internal/model" + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +const ( + minMachineSSHRelayPort = uint16(40000) + maxMachineSSHRelayPort = uint16(44999) + minMachineVNCRelayPort = uint16(45000) + maxMachineVNCRelayPort = uint16(49999) +) + +func machineRelayListenerKey(machineID contracthost.MachineID, name contracthost.MachinePortName) string { + return string(machineID) + ":" + string(name) +} + +func machineRelayHostPort(record model.MachineRecord, name contracthost.MachinePortName) uint16 { + for _, port := range record.Ports { + if port.Name == name { + return port.HostPort + } + } + return 0 +} + +func machineRelayGuestPort(record model.MachineRecord, name contracthost.MachinePortName) uint16 { + for _, port := range record.Ports { + if port.Name == name { + return port.Port + } + } + switch name { + case contracthost.MachinePortNameVNC: + return defaultVNCPort + default: + return defaultSSHPort + } +} + +func (d *Daemon) usedMachineRelayPorts(ctx context.Context, machineID contracthost.MachineID, name contracthost.MachinePortName) (map[uint16]struct{}, error) { + records, err := d.store.ListMachines(ctx) + if err != nil { + return nil, err + } + + used := make(map[uint16]struct{}, len(records)) + for _, record := range records { + if record.ID == machineID { + continue + } + if port := machineRelayHostPort(record, name); port != 0 { + used[port] = struct{}{} + } + } + return used, nil +} + +func (d *Daemon) allocateMachineRelayProxy( + ctx context.Context, + current model.MachineRecord, + name contracthost.MachinePortName, + runtimeHost string, + guestPort uint16, + minPort uint16, + maxPort uint16, +) (uint16, error) { + existingPort := machineRelayHostPort(current, name) + if existingPort != 0 { + if err := d.startMachineRelayProxy(current.ID, name, existingPort, runtimeHost, guestPort); err == nil { + return existingPort, nil + } else if !isAddrInUseError(err) { + return 0, err + } + } + + used, err := d.usedMachineRelayPorts(ctx, current.ID, name) + if err != nil { + return 0, err + } + if existingPort != 0 { + used[existingPort] = struct{}{} + } + for hostPort := minPort; hostPort <= maxPort; hostPort++ { + if _, exists := used[hostPort]; exists { + continue + } + if err := d.startMachineRelayProxy(current.ID, name, hostPort, runtimeHost, guestPort); err != nil { + if isAddrInUseError(err) { + continue + } + return 0, err + } + return hostPort, nil + } + return 0, fmt.Errorf("no relay ports are available in range %d-%d", minPort, maxPort) +} + +func (d *Daemon) ensureMachineRelays(ctx context.Context, record *model.MachineRecord) error { + if record == nil { + return fmt.Errorf("machine record is required") + } + if record.Phase != contracthost.MachinePhaseRunning || strings.TrimSpace(record.RuntimeHost) == "" { + return nil + } + + d.relayAllocMu.Lock() + sshRelayPort, err := d.allocateMachineRelayProxy(ctx, *record, contracthost.MachinePortNameSSH, record.RuntimeHost, machineRelayGuestPort(*record, contracthost.MachinePortNameSSH), minMachineSSHRelayPort, maxMachineSSHRelayPort) + var vncRelayPort uint16 + if err == nil { + vncRelayPort, err = d.allocateMachineRelayProxy(ctx, *record, contracthost.MachinePortNameVNC, record.RuntimeHost, machineRelayGuestPort(*record, contracthost.MachinePortNameVNC), minMachineVNCRelayPort, maxMachineVNCRelayPort) + } + d.relayAllocMu.Unlock() + if err != nil { + d.stopMachineRelays(record.ID) + return err + } + + record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort) + if err := d.store.UpdateMachine(ctx, *record); err != nil { + d.stopMachineRelays(record.ID) + return err + } + return nil +} + +func (d *Daemon) startMachineRelayProxy(machineID contracthost.MachineID, name contracthost.MachinePortName, hostPort uint16, runtimeHost string, guestPort uint16) error { + targetHost := strings.TrimSpace(runtimeHost) + if targetHost == "" { + return fmt.Errorf("runtime host is required for machine relay %q", machineID) + } + + key := machineRelayListenerKey(machineID, name) + + d.machineRelaysMu.Lock() + if _, exists := d.machineRelayListeners[key]; exists { + d.machineRelaysMu.Unlock() + return nil + } + listener, err := net.Listen("tcp", ":"+strconv.Itoa(int(hostPort))) + if err != nil { + d.machineRelaysMu.Unlock() + return fmt.Errorf("listen on machine relay port %d: %w", hostPort, err) + } + d.machineRelayListeners[key] = listener + d.machineRelaysMu.Unlock() + + targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(int(guestPort))) + go serveTCPProxy(listener, targetAddr) + return nil +} + +func (d *Daemon) stopMachineRelayProxy(machineID contracthost.MachineID, name contracthost.MachinePortName) { + key := machineRelayListenerKey(machineID, name) + + d.machineRelaysMu.Lock() + listener, ok := d.machineRelayListeners[key] + if ok { + delete(d.machineRelayListeners, key) + } + d.machineRelaysMu.Unlock() + if ok { + _ = listener.Close() + } +} + +func (d *Daemon) stopMachineRelays(machineID contracthost.MachineID) { + d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameSSH) + d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameVNC) +} + +func isAddrInUseError(err error) bool { + if err == nil { + return false + } + return strings.Contains(strings.ToLower(err.Error()), "address already in use") +} diff --git a/internal/daemon/published_ports.go b/internal/daemon/published_ports.go index f91e71b..654a42b 100644 --- a/internal/daemon/published_ports.go +++ b/internal/daemon/published_ports.go @@ -177,11 +177,11 @@ func (d *Daemon) startPublishedPortProxy(port model.PublishedPortRecord, runtime d.publishedPortsMu.Unlock() targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(int(port.Port))) - go d.servePublishedPortProxy(port.ID, listener, targetAddr) + go serveTCPProxy(listener, targetAddr) return nil } -func (d *Daemon) servePublishedPortProxy(portID contracthost.PublishedPortID, listener net.Listener, targetAddr string) { +func serveTCPProxy(listener net.Listener, targetAddr string) { for { conn, err := listener.Accept() if err != nil { @@ -190,11 +190,11 @@ func (d *Daemon) servePublishedPortProxy(portID contracthost.PublishedPortID, li } continue } - go proxyPublishedPortConnection(conn, targetAddr) + go proxyTCPConnection(conn, targetAddr) } } -func proxyPublishedPortConnection(source net.Conn, targetAddr string) { +func proxyTCPConnection(source net.Conn, targetAddr string) { defer func() { _ = source.Close() }() diff --git a/internal/daemon/review_regressions_test.go b/internal/daemon/review_regressions_test.go index 08c7082..6d3bf2f 100644 --- a/internal/daemon/review_regressions_test.go +++ b/internal/daemon/review_regressions_test.go @@ -86,6 +86,7 @@ func TestCreatePublishedPortSerializesHostPortAllocationAcrossMachines(t *testin if err != nil { t.Fatalf("create daemon: %v", err) } + stubGuestSSHPublicKeyReader(hostDaemon) for _, machineID := range []contracthost.MachineID{"vm-1", "vm-2"} { if err := baseStore.CreateMachine(context.Background(), model.MachineRecord{ @@ -170,6 +171,7 @@ func TestGetStorageReportHandlesSparseSnapshotPathsAndIncludesPublishedPortPool( if err != nil { t.Fatalf("create daemon: %v", err) } + stubGuestSSHPublicKeyReader(hostDaemon) response, err := hostDaemon.GetStorageReport(context.Background()) if err != nil { @@ -209,6 +211,7 @@ func TestReconcileSnapshotPreservesArtifactsOnUnexpectedStoreError(t *testing.T) if err != nil { t.Fatalf("create daemon: %v", err) } + stubGuestSSHPublicKeyReader(hostDaemon) snapshotID := contracthost.SnapshotID("snap-1") operation := model.OperationRecord{ @@ -249,6 +252,7 @@ func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T) if err != nil { t.Fatalf("create daemon: %v", err) } + stubGuestSSHPublicKeyReader(hostDaemon) operation := model.OperationRecord{ MachineID: "vm-1", @@ -291,6 +295,7 @@ func TestCreateSnapshotRejectsDuplicateSnapshotIDWithoutTouchingExistingArtifact if err != nil { t.Fatalf("create daemon: %v", err) } + stubGuestSSHPublicKeyReader(hostDaemon) machineID := contracthost.MachineID("vm-1") snapshotID := contracthost.SnapshotID("snap-1") @@ -346,6 +351,7 @@ func TestReconcileUsesReconciledMachineStateForPublishedPorts(t *testing.T) { if err != nil { t.Fatalf("create daemon: %v", err) } + stubGuestSSHPublicKeyReader(hostDaemon) t.Cleanup(func() { hostDaemon.stopPublishedPortProxy("port-1") diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go index ac92944..d8329ac 100644 --- a/internal/daemon/snapshot.go +++ b/internal/daemon/snapshot.go @@ -260,6 +260,13 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn clearOperation = true return nil, fmt.Errorf("reconfigure restored guest identity: %w", err) } + guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, machineState.RuntimeHost) + if err != nil { + _ = d.runtime.Delete(ctx, *machineState) + _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) + clearOperation = true + return nil, fmt.Errorf("read restored guest ssh host key: %w", err) + } systemVolumeID := d.systemVolumeID(req.MachineID) now := time.Now().UTC() @@ -277,22 +284,42 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn } machineRecord := model.MachineRecord{ - ID: req.MachineID, - Artifact: snap.Artifact, - SystemVolumeID: systemVolumeID, - RuntimeHost: machineState.RuntimeHost, - TapDevice: machineState.TapName, - Ports: defaultMachinePorts(), - Phase: contracthost.MachinePhaseRunning, - PID: machineState.PID, - SocketPath: machineState.SocketPath, - CreatedAt: now, - StartedAt: machineState.StartedAt, + ID: req.MachineID, + Artifact: snap.Artifact, + SystemVolumeID: systemVolumeID, + RuntimeHost: machineState.RuntimeHost, + TapDevice: machineState.TapName, + Ports: defaultMachinePorts(), + GuestSSHPublicKey: guestSSHPublicKey, + Phase: contracthost.MachinePhaseRunning, + PID: machineState.PID, + SocketPath: machineState.SocketPath, + CreatedAt: now, + StartedAt: machineState.StartedAt, } + d.relayAllocMu.Lock() + sshRelayPort, err := d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameSSH, machineRecord.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort) + var vncRelayPort uint16 + if err == nil { + vncRelayPort, err = d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameVNC, machineRecord.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort) + } + d.relayAllocMu.Unlock() + if err != nil { + d.stopMachineRelays(machineRecord.ID) + return nil, err + } + machineRecord.Ports = buildMachinePorts(sshRelayPort, vncRelayPort) + startedRelays := true + defer func() { + if startedRelays { + d.stopMachineRelays(machineRecord.ID) + } + }() if err := d.store.CreateMachine(ctx, machineRecord); err != nil { return nil, err } + startedRelays = false clearOperation = true return &contracthost.RestoreSnapshotResponse{ Machine: machineToContract(machineRecord), diff --git a/internal/model/types.go b/internal/model/types.go index 5845767..e93b625 100644 --- a/internal/model/types.go +++ b/internal/model/types.go @@ -27,19 +27,20 @@ type ArtifactRecord struct { } type MachineRecord struct { - ID contracthost.MachineID - Artifact contracthost.ArtifactRef - SystemVolumeID contracthost.VolumeID - UserVolumeIDs []contracthost.VolumeID - RuntimeHost string - TapDevice string - Ports []contracthost.MachinePort - Phase contracthost.MachinePhase - Error string - PID int - SocketPath string - CreatedAt time.Time - StartedAt *time.Time + ID contracthost.MachineID + Artifact contracthost.ArtifactRef + SystemVolumeID contracthost.VolumeID + UserVolumeIDs []contracthost.VolumeID + RuntimeHost string + TapDevice string + Ports []contracthost.MachinePort + GuestSSHPublicKey string + Phase contracthost.MachinePhase + Error string + PID int + SocketPath string + CreatedAt time.Time + StartedAt *time.Time } type VolumeRecord struct {