fix: address gateway review findings

This commit is contained in:
Harivansh Rathi 2026-04-09 17:52:14 +00:00
parent 59d3290bb9
commit 500354cd9b
14 changed files with 441 additions and 66 deletions

View file

@ -103,6 +103,11 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
if err != nil {
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
now := time.Now().UTC()
systemVolumeRecord := model.VolumeRecord{
@ -138,19 +143,44 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
}
record := model.MachineRecord{
ID: req.MachineID,
Artifact: req.Artifact,
SystemVolumeID: systemVolumeRecord.ID,
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
RuntimeHost: state.RuntimeHost,
TapDevice: state.TapName,
Ports: ports,
Phase: contracthost.MachinePhaseRunning,
PID: state.PID,
SocketPath: state.SocketPath,
CreatedAt: now,
StartedAt: state.StartedAt,
ID: req.MachineID,
Artifact: req.Artifact,
SystemVolumeID: systemVolumeRecord.ID,
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
RuntimeHost: state.RuntimeHost,
TapDevice: state.TapName,
Ports: ports,
GuestSSHPublicKey: guestSSHPublicKey,
Phase: contracthost.MachinePhaseRunning,
PID: state.PID,
SocketPath: state.SocketPath,
CreatedAt: now,
StartedAt: state.StartedAt,
}
d.relayAllocMu.Lock()
sshRelayPort, err := d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameSSH, record.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort)
var vncRelayPort uint16
if err == nil {
vncRelayPort, err = d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameVNC, record.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort)
}
d.relayAllocMu.Unlock()
if err != nil {
d.stopMachineRelays(record.ID)
for _, volume := range userVolumes {
volume.AttachedMachineID = nil
_ = d.store.UpdateVolume(context.Background(), volume)
}
_ = d.store.DeleteVolume(context.Background(), systemVolumeRecord.ID)
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
startedRelays := true
defer func() {
if startedRelays {
d.stopMachineRelays(record.ID)
}
}()
if err := d.store.CreateMachine(ctx, record); err != nil {
for _, volume := range userVolumes {
volume.AttachedMachineID = nil
@ -162,6 +192,7 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
}
removeSystemVolumeOnFailure = false
startedRelays = false
clearOperation = true
return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil
}

View file

@ -42,11 +42,15 @@ type Daemon struct {
runtime Runtime
reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID) error
readGuestSSHPublicKey func(context.Context, string) (string, error)
locksMu sync.Mutex
machineLocks map[contracthost.MachineID]*sync.Mutex
artifactLocks map[string]*sync.Mutex
relayAllocMu sync.Mutex
machineRelaysMu sync.Mutex
machineRelayListeners map[string]net.Listener
publishedPortAllocMu sync.Mutex
publishedPortsMu sync.Mutex
publishedPortListeners map[contracthost.PublishedPortID]net.Listener
@ -72,11 +76,14 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
store: store,
runtime: runtime,
reconfigureGuestIdentity: nil,
readGuestSSHPublicKey: nil,
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
artifactLocks: make(map[string]*sync.Mutex),
machineRelayListeners: make(map[string]net.Listener),
publishedPortListeners: make(map[contracthost.PublishedPortID]net.Listener),
}
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
daemon.readGuestSSHPublicKey = readGuestSSHPublicKey
if err := daemon.ensureBackendSSHKeyPair(); err != nil {
return nil, err
}

View file

@ -98,6 +98,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
if err != nil {
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
kernelPayload := []byte("kernel-image")
rootFSImagePath := filepath.Join(root, "guest-rootfs.ext4")
@ -261,6 +262,7 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
if err != nil {
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID) error { return nil }
artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
@ -367,6 +369,7 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
if err != nil {
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
var reconfiguredHost string
var reconfiguredMachine contracthost.MachineID
hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID) error {
@ -508,6 +511,7 @@ func TestDeleteMachineMissingIsNoOp(t *testing.T) {
if err != nil {
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
if err := hostDaemon.DeleteMachine(context.Background(), "missing"); err != nil {
t.Fatalf("delete missing machine: %v", err)
@ -533,6 +537,12 @@ func testConfig(root string) appconfig.Config {
}
}
// stubGuestSSHPublicKeyReader swaps the daemon's guest SSH host key reader for
// a stub that ignores its arguments and always returns the same fixed
// ssh-ed25519 public key line with a nil error.
func stubGuestSSHPublicKeyReader(hostDaemon *Daemon) {
	const stubbedKey = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3"

	reader := func(_ context.Context, _ string) (string, error) {
		return stubbedKey, nil
	}
	hostDaemon.readGuestSSHPublicKey = reader
}
func listenTestPort(t *testing.T, port int) net.Listener {
t.Helper()

View file

@ -187,9 +187,13 @@ func isZeroChunk(chunk []byte) bool {
}
// defaultMachinePorts returns the machine port list produced by
// buildMachinePorts when both relay host ports are zero.
func defaultMachinePorts() []contracthost.MachinePort {
	const unassigned uint16 = 0
	return buildMachinePorts(unassigned, unassigned)
}
func buildMachinePorts(sshRelayPort, vncRelayPort uint16) []contracthost.MachinePort {
return []contracthost.MachinePort{
{Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, Protocol: contracthost.PortProtocolTCP},
{Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, Protocol: contracthost.PortProtocolTCP},
{Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, HostPort: sshRelayPort, Protocol: contracthost.PortProtocolTCP},
{Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, HostPort: vncRelayPort, Protocol: contracthost.PortProtocolTCP},
}
}
@ -247,7 +251,14 @@ func (d *Daemon) mergedGuestConfig(config *contracthost.GuestConfig) (*contracth
}
merged := &contracthost.GuestConfig{
AuthorizedKeys: authorizedKeys,
AuthorizedKeys: authorizedKeys,
TrustedUserCAKeys: nil,
}
if strings.TrimSpace(d.config.GuestLoginCAPublicKey) != "" {
merged.TrustedUserCAKeys = append(merged.TrustedUserCAKeys, d.config.GuestLoginCAPublicKey)
}
if config != nil {
merged.TrustedUserCAKeys = append(merged.TrustedUserCAKeys, config.TrustedUserCAKeys...)
}
if config != nil && config.LoginWebhook != nil {
loginWebhook := *config.LoginWebhook
@ -260,7 +271,7 @@ func hasGuestConfig(config *contracthost.GuestConfig) bool {
if config == nil {
return false
}
return len(config.AuthorizedKeys) > 0 || config.LoginWebhook != nil
return len(config.AuthorizedKeys) > 0 || len(config.TrustedUserCAKeys) > 0 || config.LoginWebhook != nil
}
func injectGuestConfig(ctx context.Context, imagePath string, config *contracthost.GuestConfig) error {
@ -286,6 +297,17 @@ func injectGuestConfig(ctx context.Context, imagePath string, config *contractho
}
}
if len(config.TrustedUserCAKeys) > 0 {
trustedCAPath := filepath.Join(stagingDir, "trusted_user_ca_keys")
payload := []byte(strings.Join(config.TrustedUserCAKeys, "\n") + "\n")
if err := os.WriteFile(trustedCAPath, payload, 0o644); err != nil {
return fmt.Errorf("write trusted_user_ca_keys staging file: %w", err)
}
if err := replaceExt4File(ctx, imagePath, trustedCAPath, "/etc/microagent/trusted_user_ca_keys"); err != nil {
return err
}
}
if config.LoginWebhook != nil {
guestConfigPath := filepath.Join(stagingDir, "guest-config.json")
payload, err := json.Marshal(config)
@ -363,16 +385,17 @@ func machineIDPtr(machineID contracthost.MachineID) *contracthost.MachineID {
func machineToContract(record model.MachineRecord) contracthost.Machine {
return contracthost.Machine{
ID: record.ID,
Artifact: record.Artifact,
SystemVolumeID: record.SystemVolumeID,
UserVolumeIDs: append([]contracthost.VolumeID(nil), record.UserVolumeIDs...),
RuntimeHost: record.RuntimeHost,
Ports: append([]contracthost.MachinePort(nil), record.Ports...),
Phase: record.Phase,
Error: record.Error,
CreatedAt: record.CreatedAt,
StartedAt: record.StartedAt,
ID: record.ID,
Artifact: record.Artifact,
SystemVolumeID: record.SystemVolumeID,
UserVolumeIDs: append([]contracthost.VolumeID(nil), record.UserVolumeIDs...),
RuntimeHost: record.RuntimeHost,
Ports: append([]contracthost.MachinePort(nil), record.Ports...),
GuestSSHPublicKey: record.GuestSSHPublicKey,
Phase: record.Phase,
Error: record.Error,
CreatedAt: record.CreatedAt,
StartedAt: record.StartedAt,
}
}
@ -427,6 +450,11 @@ func validateGuestConfig(config *contracthost.GuestConfig) error {
return fmt.Errorf("guest_config.authorized_keys[%d] is required", i)
}
}
for i, key := range config.TrustedUserCAKeys {
if strings.TrimSpace(key) == "" {
return fmt.Errorf("guest_config.trusted_user_ca_keys[%d] is required", i)
}
}
if config.LoginWebhook != nil {
if err := validateDownloadURL("guest_config.login_webhook.url", config.LoginWebhook.URL); err != nil {
return err

View file

@ -0,0 +1,51 @@
package daemon
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"golang.org/x/crypto/ssh"
)
// readGuestSSHPublicKey connects to the guest's SSH port on runtimeHost and
// returns the ed25519 host key the server presents, formatted as an
// authorized_keys line with surrounding whitespace trimmed.
//
// The probe never authenticates: the host key callback records the presented
// key and then deliberately fails the handshake, so a handshake error is
// expected and ignored whenever a key was captured.
//
// The TCP dial and the SSH handshake share a single deadline derived from ctx
// and defaultGuestDialTimeout. An error is returned when runtimeHost is blank,
// the dial fails, the deadline cannot be applied, or the handshake ends before
// the server presented a host key.
func readGuestSSHPublicKey(ctx context.Context, runtimeHost string) (string, error) {
	host := strings.TrimSpace(runtimeHost)
	if host == "" {
		return "", fmt.Errorf("runtime host is required")
	}

	probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout)
	defer cancel()

	targetAddr := net.JoinHostPort(host, strconv.Itoa(int(defaultSSHPort)))
	netConn, err := (&net.Dialer{}).DialContext(probeCtx, "tcp", targetAddr)
	if err != nil {
		return "", fmt.Errorf("dial guest ssh for host key: %w", err)
	}
	defer func() {
		_ = netConn.Close()
	}()

	// ClientConfig.Timeout only bounds the TCP connect performed by ssh.Dial;
	// ssh.NewClientConn on an existing connection applies no timeout at all.
	// Put the probe deadline on the connection so a guest that accepts TCP but
	// never completes the SSH exchange cannot block the caller indefinitely.
	if deadline, ok := probeCtx.Deadline(); ok {
		if err := netConn.SetDeadline(deadline); err != nil {
			return "", fmt.Errorf("set guest ssh probe deadline: %w", err)
		}
	}

	var captured ssh.PublicKey
	clientConfig := &ssh.ClientConfig{
		User:              "host-key-probe",
		Auth:              []ssh.AuthMethod{ssh.Password("invalid")},
		HostKeyAlgorithms: []string{ssh.KeyAlgoED25519},
		HostKeyCallback: func(_ string, _ net.Addr, key ssh.PublicKey) error {
			// Record the key, then abort: the probe only wants the host key.
			captured = key
			return fmt.Errorf("guest ssh host key captured")
		},
		Timeout:       defaultGuestDialTimeout,
		ClientVersion: "SSH-2.0-agentcomputer-firecracker-host",
	}

	_, _, _, err = ssh.NewClientConn(netConn, targetAddr, clientConfig)
	if captured == nil {
		if err == nil {
			return "", fmt.Errorf("guest ssh host key probe returned without a host key")
		}
		return "", fmt.Errorf("handshake guest ssh for host key: %w", err)
	}
	return strings.TrimSpace(string(ssh.MarshalAuthorizedKey(captured))), nil
}

View file

@ -99,10 +99,16 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
if err != nil {
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
record.RuntimeHost = state.RuntimeHost
record.TapDevice = state.TapName
record.Ports = ports
record.GuestSSHPublicKey = guestSSHPublicKey
record.Phase = contracthost.MachinePhaseRunning
record.Error = ""
record.PID = state.PID
@ -112,7 +118,13 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
if err := d.ensureMachineRelays(ctx, record); err != nil {
d.stopMachineRelays(id)
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
d.stopMachineRelays(id)
d.stopPublishedPortsForMachine(id)
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
@ -238,10 +250,14 @@ func (d *Daemon) Reconcile(ctx context.Context) error {
return err
}
if reconciled.Phase == contracthost.MachinePhaseRunning {
if err := d.ensureMachineRelays(ctx, reconciled); err != nil {
return err
}
if err := d.ensurePublishedPortsForMachine(ctx, *reconciled); err != nil {
return err
}
} else {
d.stopMachineRelays(reconciled.ID)
d.stopPublishedPortsForMachine(reconciled.ID)
}
}
@ -322,6 +338,9 @@ func (d *Daemon) reconcileStart(ctx context.Context, machineID contracthost.Mach
return err
}
if record.Phase == contracthost.MachinePhaseRunning {
if err := d.ensureMachineRelays(ctx, record); err != nil {
return err
}
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
return err
}
@ -375,12 +394,16 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
return nil, err
}
if state.Phase == firecracker.PhaseRunning {
if err := d.ensureMachineRelays(ctx, record); err != nil {
return nil, err
}
return record, nil
}
if err := d.runtime.Delete(ctx, *state); err != nil {
return nil, err
}
d.stopMachineRelays(record.ID)
d.stopPublishedPortsForMachine(record.ID)
record.Phase = contracthost.MachinePhaseFailed
record.Error = state.Error
@ -396,6 +419,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
}
func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error {
d.stopMachineRelays(record.ID)
d.stopPublishedPortsForMachine(record.ID)
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
return err
@ -426,6 +450,7 @@ func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineR
}
func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error {
d.stopMachineRelays(record.ID)
d.stopPublishedPortsForMachine(record.ID)
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
return err

View file

@ -0,0 +1,184 @@
package daemon
import (
	"context"
	"errors"
	"fmt"
	"net"
	"strconv"
	"strings"
	"syscall"

	contracthost "github.com/getcompanion-ai/computer-host/contract"
	"github.com/getcompanion-ai/computer-host/internal/model"
)
// Host port ranges (inclusive) scanned by allocateMachineRelayProxy when it
// picks a relay port. SSH and VNC relays use separate, non-overlapping ranges.
const (
// SSH relay listeners are allocated from 40000-44999.
minMachineSSHRelayPort = uint16(40000)
maxMachineSSHRelayPort = uint16(44999)
// VNC relay listeners are allocated from 45000-49999.
minMachineVNCRelayPort = uint16(45000)
maxMachineVNCRelayPort = uint16(49999)
)
// machineRelayListenerKey builds the map key for a machine's relay listener:
// the machine ID and the port name joined by a colon.
func machineRelayListenerKey(machineID contracthost.MachineID, name contracthost.MachinePortName) string {
	parts := []string{string(machineID), string(name)}
	return strings.Join(parts, ":")
}
// machineRelayHostPort returns the HostPort of the first entry in record.Ports
// whose Name equals name, or 0 when no entry matches.
func machineRelayHostPort(record model.MachineRecord, name contracthost.MachinePortName) uint16 {
	for i := range record.Ports {
		if record.Ports[i].Name != name {
			continue
		}
		return record.Ports[i].HostPort
	}
	return 0
}
// machineRelayGuestPort returns the guest-side Port of the first entry in
// record.Ports whose Name equals name. When no entry matches it falls back to
// defaultVNCPort for the VNC port name and defaultSSHPort for anything else.
func machineRelayGuestPort(record model.MachineRecord, name contracthost.MachinePortName) uint16 {
	for i := range record.Ports {
		if entry := record.Ports[i]; entry.Name == name {
			return entry.Port
		}
	}
	if name == contracthost.MachinePortNameVNC {
		return defaultVNCPort
	}
	return defaultSSHPort
}
// usedMachineRelayPorts lists every machine in the store and collects the
// non-zero relay host ports recorded for the port called name, skipping the
// machine identified by machineID. A store listing error is returned as-is.
func (d *Daemon) usedMachineRelayPorts(ctx context.Context, machineID contracthost.MachineID, name contracthost.MachinePortName) (map[uint16]struct{}, error) {
	machines, err := d.store.ListMachines(ctx)
	if err != nil {
		return nil, err
	}

	taken := make(map[uint16]struct{}, len(machines))
	for _, machine := range machines {
		if machine.ID == machineID {
			// The caller's own machine does not count against itself.
			continue
		}
		hostPort := machineRelayHostPort(machine, name)
		if hostPort == 0 {
			continue
		}
		taken[hostPort] = struct{}{}
	}
	return taken, nil
}
// allocateMachineRelayProxy starts a relay listener for the port called name
// on machine current and returns the host port the relay is bound to.
//
// If current already records a non-zero host port for name, that port is tried
// first and returned on success. When it is reported as already in use, or no
// port was recorded, the range minPort..maxPort (inclusive) is scanned in
// ascending order, skipping ports recorded for other machines in the store and
// the previously recorded port, and skipping ports that are in use on the
// host. Any other listener error aborts the allocation.
//
// An error is returned when the store cannot be listed, a listener fails for a
// reason other than "address in use", or no port in the range is available.
// The callers visible in this package hold d.relayAllocMu around the call.
func (d *Daemon) allocateMachineRelayProxy(
	ctx context.Context,
	current model.MachineRecord,
	name contracthost.MachinePortName,
	runtimeHost string,
	guestPort uint16,
	minPort uint16,
	maxPort uint16,
) (uint16, error) {
	existingPort := machineRelayHostPort(current, name)
	if existingPort != 0 {
		err := d.startMachineRelayProxy(current.ID, name, existingPort, runtimeHost, guestPort)
		if err == nil {
			return existingPort, nil
		}
		if !isAddrInUseError(err) {
			return 0, err
		}
	}

	used, err := d.usedMachineRelayPorts(ctx, current.ID, name)
	if err != nil {
		return 0, err
	}
	if existingPort != 0 {
		// The recorded port was just found busy; do not retry it in the scan.
		used[existingPort] = struct{}{}
	}

	// Iterate with an int counter: a uint16 loop variable wraps to 0 after
	// 65535, so `hostPort <= maxPort` would never become false when maxPort is
	// the largest uint16 value.
	for candidate := int(minPort); candidate <= int(maxPort); candidate++ {
		hostPort := uint16(candidate)
		if _, exists := used[hostPort]; exists {
			continue
		}
		if err := d.startMachineRelayProxy(current.ID, name, hostPort, runtimeHost, guestPort); err != nil {
			if isAddrInUseError(err) {
				continue
			}
			return 0, err
		}
		return hostPort, nil
	}
	return 0, fmt.Errorf("no relay ports are available in range %d-%d", minPort, maxPort)
}
// ensureMachineRelays makes sure SSH and VNC relay listeners exist for a
// running machine and persists the resulting host ports on the record.
//
// A nil record is an error. Records that are not in the running phase, or that
// have a blank RuntimeHost, are left untouched and nil is returned. Both
// allocations happen under d.relayAllocMu. If either allocation or the store
// update fails, the machine's relays are stopped and the error is returned.
func (d *Daemon) ensureMachineRelays(ctx context.Context, record *model.MachineRecord) error {
	if record == nil {
		return fmt.Errorf("machine record is required")
	}
	if record.Phase != contracthost.MachinePhaseRunning {
		return nil
	}
	if strings.TrimSpace(record.RuntimeHost) == "" {
		return nil
	}

	// Allocate both relays inside one critical section.
	allocate := func() (uint16, uint16, error) {
		d.relayAllocMu.Lock()
		defer d.relayAllocMu.Unlock()

		sshGuestPort := machineRelayGuestPort(*record, contracthost.MachinePortNameSSH)
		sshPort, err := d.allocateMachineRelayProxy(ctx, *record, contracthost.MachinePortNameSSH, record.RuntimeHost, sshGuestPort, minMachineSSHRelayPort, maxMachineSSHRelayPort)
		if err != nil {
			return sshPort, 0, err
		}
		vncGuestPort := machineRelayGuestPort(*record, contracthost.MachinePortNameVNC)
		vncPort, err := d.allocateMachineRelayProxy(ctx, *record, contracthost.MachinePortNameVNC, record.RuntimeHost, vncGuestPort, minMachineVNCRelayPort, maxMachineVNCRelayPort)
		return sshPort, vncPort, err
	}

	sshRelayPort, vncRelayPort, err := allocate()
	if err != nil {
		d.stopMachineRelays(record.ID)
		return err
	}

	record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
	if err := d.store.UpdateMachine(ctx, *record); err != nil {
		d.stopMachineRelays(record.ID)
		return err
	}
	return nil
}
// startMachineRelayProxy opens a TCP listener on hostPort (all interfaces) and
// serves it with serveTCPProxy, forwarding to guestPort on runtimeHost.
//
// It is a no-op returning nil when a listener is already registered for this
// machine and port name. A blank runtimeHost or a listen failure is returned as
// an error; the listen error is wrapped so callers can inspect its cause.
func (d *Daemon) startMachineRelayProxy(machineID contracthost.MachineID, name contracthost.MachinePortName, hostPort uint16, runtimeHost string, guestPort uint16) error {
	targetHost := strings.TrimSpace(runtimeHost)
	if targetHost == "" {
		return fmt.Errorf("runtime host is required for machine relay %q", machineID)
	}

	listenerKey := machineRelayListenerKey(machineID, name)

	// Check for, open, and register the listener under the relay map lock.
	// The boolean reports whether a new listener was opened.
	register := func() (net.Listener, bool, error) {
		d.machineRelaysMu.Lock()
		defer d.machineRelaysMu.Unlock()

		if _, running := d.machineRelayListeners[listenerKey]; running {
			return nil, false, nil
		}
		opened, err := net.Listen("tcp", ":"+strconv.Itoa(int(hostPort)))
		if err != nil {
			return nil, false, fmt.Errorf("listen on machine relay port %d: %w", hostPort, err)
		}
		d.machineRelayListeners[listenerKey] = opened
		return opened, true, nil
	}

	listener, started, err := register()
	if err != nil {
		return err
	}
	if !started {
		return nil
	}

	guestAddr := net.JoinHostPort(targetHost, strconv.Itoa(int(guestPort)))
	go serveTCPProxy(listener, guestAddr)
	return nil
}
// stopMachineRelayProxy removes the relay listener registered for the given
// machine and port name, if any, and closes it. The listener is closed after
// the relay map lock has been released; the Close error is discarded.
func (d *Daemon) stopMachineRelayProxy(machineID contracthost.MachineID, name contracthost.MachinePortName) {
	listenerKey := machineRelayListenerKey(machineID, name)

	// Unregister under the lock and hand the listener back to the caller.
	unregister := func() (net.Listener, bool) {
		d.machineRelaysMu.Lock()
		defer d.machineRelaysMu.Unlock()

		registered, found := d.machineRelayListeners[listenerKey]
		if !found {
			return registered, false
		}
		delete(d.machineRelayListeners, listenerKey)
		return registered, true
	}

	if listener, found := unregister(); found {
		_ = listener.Close()
	}
}
// stopMachineRelays stops the machine's SSH relay and then its VNC relay.
func (d *Daemon) stopMachineRelays(machineID contracthost.MachineID) {
	relayNames := []contracthost.MachinePortName{
		contracthost.MachinePortNameSSH,
		contracthost.MachinePortNameVNC,
	}
	for _, relayName := range relayNames {
		d.stopMachineRelayProxy(machineID, relayName)
	}
}
// isAddrInUseError reports whether err indicates that a listen address is
// already bound. A nil err yields false.
//
// The error chain is checked for syscall.EADDRINUSE first, so wrapped listen
// errors are recognised regardless of their message text. The
// case-insensitive "address already in use" substring match is kept as a
// fallback for errors that carry only the message.
func isAddrInUseError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, syscall.EADDRINUSE) {
		return true
	}
	return strings.Contains(strings.ToLower(err.Error()), "address already in use")
}

View file

@ -177,11 +177,11 @@ func (d *Daemon) startPublishedPortProxy(port model.PublishedPortRecord, runtime
d.publishedPortsMu.Unlock()
targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(int(port.Port)))
go d.servePublishedPortProxy(port.ID, listener, targetAddr)
go serveTCPProxy(listener, targetAddr)
return nil
}
func (d *Daemon) servePublishedPortProxy(portID contracthost.PublishedPortID, listener net.Listener, targetAddr string) {
func serveTCPProxy(listener net.Listener, targetAddr string) {
for {
conn, err := listener.Accept()
if err != nil {
@ -190,11 +190,11 @@ func (d *Daemon) servePublishedPortProxy(portID contracthost.PublishedPortID, li
}
continue
}
go proxyPublishedPortConnection(conn, targetAddr)
go proxyTCPConnection(conn, targetAddr)
}
}
func proxyPublishedPortConnection(source net.Conn, targetAddr string) {
func proxyTCPConnection(source net.Conn, targetAddr string) {
defer func() {
_ = source.Close()
}()

View file

@ -86,6 +86,7 @@ func TestCreatePublishedPortSerializesHostPortAllocationAcrossMachines(t *testin
if err != nil {
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
for _, machineID := range []contracthost.MachineID{"vm-1", "vm-2"} {
if err := baseStore.CreateMachine(context.Background(), model.MachineRecord{
@ -170,6 +171,7 @@ func TestGetStorageReportHandlesSparseSnapshotPathsAndIncludesPublishedPortPool(
if err != nil {
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
response, err := hostDaemon.GetStorageReport(context.Background())
if err != nil {
@ -209,6 +211,7 @@ func TestReconcileSnapshotPreservesArtifactsOnUnexpectedStoreError(t *testing.T)
if err != nil {
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
snapshotID := contracthost.SnapshotID("snap-1")
operation := model.OperationRecord{
@ -249,6 +252,7 @@ func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T)
if err != nil {
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
operation := model.OperationRecord{
MachineID: "vm-1",
@ -291,6 +295,7 @@ func TestCreateSnapshotRejectsDuplicateSnapshotIDWithoutTouchingExistingArtifact
if err != nil {
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
machineID := contracthost.MachineID("vm-1")
snapshotID := contracthost.SnapshotID("snap-1")
@ -346,6 +351,7 @@ func TestReconcileUsesReconciledMachineStateForPublishedPorts(t *testing.T) {
if err != nil {
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
t.Cleanup(func() {
hostDaemon.stopPublishedPortProxy("port-1")

View file

@ -260,6 +260,13 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
clearOperation = true
return nil, fmt.Errorf("reconfigure restored guest identity: %w", err)
}
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, machineState.RuntimeHost)
if err != nil {
_ = d.runtime.Delete(ctx, *machineState)
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("read restored guest ssh host key: %w", err)
}
systemVolumeID := d.systemVolumeID(req.MachineID)
now := time.Now().UTC()
@ -277,22 +284,42 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
}
machineRecord := model.MachineRecord{
ID: req.MachineID,
Artifact: snap.Artifact,
SystemVolumeID: systemVolumeID,
RuntimeHost: machineState.RuntimeHost,
TapDevice: machineState.TapName,
Ports: defaultMachinePorts(),
Phase: contracthost.MachinePhaseRunning,
PID: machineState.PID,
SocketPath: machineState.SocketPath,
CreatedAt: now,
StartedAt: machineState.StartedAt,
ID: req.MachineID,
Artifact: snap.Artifact,
SystemVolumeID: systemVolumeID,
RuntimeHost: machineState.RuntimeHost,
TapDevice: machineState.TapName,
Ports: defaultMachinePorts(),
GuestSSHPublicKey: guestSSHPublicKey,
Phase: contracthost.MachinePhaseRunning,
PID: machineState.PID,
SocketPath: machineState.SocketPath,
CreatedAt: now,
StartedAt: machineState.StartedAt,
}
d.relayAllocMu.Lock()
sshRelayPort, err := d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameSSH, machineRecord.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort)
var vncRelayPort uint16
if err == nil {
vncRelayPort, err = d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameVNC, machineRecord.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort)
}
d.relayAllocMu.Unlock()
if err != nil {
d.stopMachineRelays(machineRecord.ID)
return nil, err
}
machineRecord.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
startedRelays := true
defer func() {
if startedRelays {
d.stopMachineRelays(machineRecord.ID)
}
}()
if err := d.store.CreateMachine(ctx, machineRecord); err != nil {
return nil, err
}
startedRelays = false
clearOperation = true
return &contracthost.RestoreSnapshotResponse{
Machine: machineToContract(machineRecord),