From 07975fb459ec5cfff00ab64ecee81d0b12262bdd Mon Sep 17 00:00:00 2001 From: Harivansh Rathi Date: Fri, 10 Apr 2026 02:26:43 +0000 Subject: [PATCH] feat: vsock mmds snapshot --- internal/daemon/create.go | 3 +- internal/daemon/daemon.go | 23 ++-- internal/daemon/daemon_test.go | 105 +++++++++++++--- internal/daemon/guest_identity.go | 39 +++++- internal/daemon/guest_personalization.go | 135 +++++++++++++++++++++ internal/daemon/lifecycle.go | 5 +- internal/daemon/readiness.go | 41 ------- internal/daemon/review_regressions_test.go | 40 +++--- internal/daemon/snapshot.go | 56 +-------- internal/firecracker/api_test.go | 3 +- internal/firecracker/configure_test.go | 68 +++++++++++ internal/firecracker/launch.go | 9 +- internal/firecracker/runtime.go | 11 ++ 13 files changed, 390 insertions(+), 148 deletions(-) create mode 100644 internal/daemon/guest_personalization.go diff --git a/internal/daemon/create.go b/internal/daemon/create.go index 9449970..d89b859 100644 --- a/internal/daemon/create.go +++ b/internal/daemon/create.go @@ -7,10 +7,10 @@ import ( "path/filepath" "time" + contracthost "github.com/getcompanion-ai/computer-host/contract" "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/store" - contracthost "github.com/getcompanion-ai/computer-host/contract" ) func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachineRequest) (*contracthost.CreateMachineResponse, error) { @@ -184,6 +184,7 @@ func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *mo KernelArgs: defaultGuestKernelArgs, Drives: drives, MMDS: mmds, + Vsock: guestVsockSpec(machineID), } if err := spec.Validate(); err != nil { return firecracker.MachineSpec{}, err diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index afcb343..0d51f2b 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -8,22 +8,21 @@ import ( "sync" "time" + contracthost "github.com/getcompanion-ai/computer-host/contract" appconfig "github.com/getcompanion-ai/computer-host/internal/config" "github.com/getcompanion-ai/computer-host/internal/firecracker" + "github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/store" - contracthost "github.com/getcompanion-ai/computer-host/contract" ) const ( - defaultGuestKernelArgs = "console=ttyS0 reboot=k panic=1 pci=off" - defaultGuestMemoryMiB = int64(512) - defaultGuestVCPUs = int64(1) - defaultSSHPort = uint16(2222) - defaultVNCPort = uint16(6080) - defaultCopyBufferSize = 1024 * 1024 - defaultGuestDialTimeout = 500 * time.Millisecond - defaultGuestReadyPollInterval = 100 * time.Millisecond - defaultGuestReadyTimeout = 30 * time.Second + defaultGuestKernelArgs = "console=ttyS0 reboot=k panic=1 pci=off" + defaultGuestMemoryMiB = int64(512) + defaultGuestVCPUs = int64(1) + defaultSSHPort = uint16(2222) + defaultVNCPort = uint16(6080) + defaultCopyBufferSize = 1024 * 1024 + defaultGuestDialTimeout = 500 * time.Millisecond ) type Runtime interface { @@ -34,6 +33,7 @@ type Runtime interface { Resume(context.Context, firecracker.MachineState) error CreateSnapshot(context.Context, firecracker.MachineState, firecracker.SnapshotPaths) error RestoreBoot(context.Context, firecracker.SnapshotLoadSpec, []firecracker.NetworkAllocation) (*firecracker.MachineState, error) + PutMMDS(context.Context, firecracker.MachineState, any) error } type Daemon struct { @@ -44,6 +44,7 @@ type Daemon struct { reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error readGuestSSHPublicKey func(context.Context, string) (string, error) syncGuestFilesystem func(context.Context, string) error + personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) error locksMu sync.Mutex machineLocks map[contracthost.MachineID]*sync.Mutex @@ -78,6 +79,7 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err runtime: runtime, reconfigureGuestIdentity: nil, readGuestSSHPublicKey: nil, + personalizeGuest: nil, machineLocks: make(map[contracthost.MachineID]*sync.Mutex), artifactLocks: make(map[string]*sync.Mutex), machineRelayListeners: make(map[string]net.Listener), @@ -86,6 +88,7 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH daemon.readGuestSSHPublicKey = readGuestSSHPublicKey daemon.syncGuestFilesystem = daemon.syncGuestFilesystemOverSSH + daemon.personalizeGuest = daemon.personalizeGuestConfig if err := daemon.ensureBackendSSHKeyPair(); err != nil { return nil, err } diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 8743306..790da33 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -13,11 +13,11 @@ import ( "testing" "time" + contracthost "github.com/getcompanion-ai/computer-host/contract" appconfig "github.com/getcompanion-ai/computer-host/internal/config" "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/store" - contracthost "github.com/getcompanion-ai/computer-host/contract" ) type fakeRuntime struct { @@ -27,6 +27,7 @@ type fakeRuntime struct { deleteCalls []firecracker.MachineState lastSpec firecracker.MachineSpec lastLoadSpec firecracker.SnapshotLoadSpec + mmdsWrites []any } func (f *fakeRuntime) Boot(_ context.Context, spec firecracker.MachineSpec, _ []firecracker.NetworkAllocation) (*firecracker.MachineState, error) { @@ -64,6 +65,11 @@ func (f *fakeRuntime) RestoreBoot(_ context.Context, spec firecracker.SnapshotLo return &f.bootState, nil } +func (f *fakeRuntime) PutMMDS(_ context.Context, _ firecracker.MachineState, data any) error { + f.mmdsWrites = append(f.mmdsWrites, data) + return nil +} + func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { root := t.TempDir() cfg := testConfig(root) @@ -173,6 +179,15 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { if runtime.lastSpec.MMDS == nil { t.Fatalf("expected MMDS configuration on machine spec") } + if runtime.lastSpec.Vsock == nil { + t.Fatalf("expected vsock configuration on machine spec") + } + if runtime.lastSpec.Vsock.ID != defaultGuestPersonalizationVsockID { + t.Fatalf("vsock id mismatch: got %q", runtime.lastSpec.Vsock.ID) + } + if runtime.lastSpec.Vsock.CID < minGuestVsockCID { + t.Fatalf("vsock cid mismatch: got %d", runtime.lastSpec.Vsock.CID) + } if runtime.lastSpec.MMDS.Version != firecracker.MMDSVersionV2 { t.Fatalf("mmds version mismatch: got %q", runtime.lastSpec.MMDS.Version) } @@ -286,6 +301,70 @@ func TestStopMachineSyncsGuestFilesystemBeforeDelete(t *testing.T) { } } +func TestReconcileStartingMachinePersonalizesBeforeRunning(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + sshListener := listenTestPort(t, int(defaultSSHPort)) + defer func() { _ = sshListener.Close() }() + vncListener := listenTestPort(t, int(defaultVNCPort)) + defer func() { _ = vncListener.Close() }() + + startedAt := time.Unix(1700000100, 0).UTC() + runtime := &fakeRuntime{} + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + t.Cleanup(func() { hostDaemon.stopMachineRelays("vm-starting") }) + + personalized := false + hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, state firecracker.MachineState) error { + personalized = true + if record.ID != "vm-starting" { + t.Fatalf("personalized machine mismatch: got %q", record.ID) + } + if state.RuntimeHost != "127.0.0.1" || state.PID != 4321 { + t.Fatalf("personalized state mismatch: %#v", state) + } + return nil + } + stubGuestSSHPublicKeyReader(hostDaemon) + + if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{ + ID: "vm-starting", + SystemVolumeID: "vm-starting-system", + RuntimeHost: "127.0.0.1", + TapDevice: "fctap-starting", + Ports: defaultMachinePorts(), + Phase: contracthost.MachinePhaseStarting, + PID: 4321, + SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "vm-starting", "root", "run", "firecracker.sock"), + CreatedAt: time.Now().UTC(), + StartedAt: &startedAt, + }); err != nil { + t.Fatalf("create machine: %v", err) + } + + response, err := hostDaemon.GetMachine(context.Background(), "vm-starting") + if err != nil { + t.Fatalf("GetMachine returned error: %v", err) + } + if !personalized { + t.Fatalf("guest personalization was not called") + } + if response.Machine.Phase != contracthost.MachinePhaseRunning { + t.Fatalf("machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseRunning) + } + if response.Machine.GuestSSHPublicKey == "" { + t.Fatalf("guest ssh public key should be recorded after convergence") + } +} + func TestNewEnsuresBackendSSHKeyPair(t *testing.T) { root := t.TempDir() cfg := testConfig(root) @@ -406,6 +485,9 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) { if response.Machine.ID != "restored" { t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID) } + if response.Machine.Phase != contracthost.MachinePhaseStarting { + t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase) + } if runtime.restoreCalls != 1 { t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls) } @@ -462,13 +544,8 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { t.Fatalf("create daemon: %v", err) } stubGuestSSHPublicKeyReader(hostDaemon) - var reconfiguredHost string - var reconfiguredMachine contracthost.MachineID - var reconfiguredConfig *contracthost.GuestConfig hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) error { - reconfiguredHost = host - reconfiguredMachine = machineID - reconfiguredConfig = cloneGuestConfig(guestConfig) + t.Fatalf("restore snapshot should not synchronously reconfigure guest identity, host=%q machine=%q guest_config=%#v", host, machineID, guestConfig) return nil } @@ -509,6 +586,9 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { if response.Machine.ID != "restored" { t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID) } + if response.Machine.Phase != contracthost.MachinePhaseStarting { + t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase) + } if runtime.restoreCalls != 1 { t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls) } @@ -527,20 +607,17 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { }), "kernel")) { t.Fatalf("restore boot kernel path mismatch: got %q", runtime.lastLoadSpec.KernelImagePath) } - if reconfiguredHost != "127.0.0.1" || reconfiguredMachine != "restored" { - t.Fatalf("guest identity reconfigure mismatch: host=%q machine=%q", reconfiguredHost, reconfiguredMachine) - } - if reconfiguredConfig == nil || reconfiguredConfig.Hostname != "restored-shell" { - t.Fatalf("guest identity hostname mismatch: %#v", reconfiguredConfig) - } machine, err := fileStore.GetMachine(context.Background(), "restored") if err != nil { t.Fatalf("get restored machine: %v", err) } - if machine.Phase != contracthost.MachinePhaseRunning { + if machine.Phase != contracthost.MachinePhaseStarting { t.Fatalf("restored machine phase mismatch: got %q", machine.Phase) } + if machine.GuestConfig == nil || machine.GuestConfig.Hostname != "restored-shell" { + t.Fatalf("stored guest config mismatch: %#v", machine.GuestConfig) + } if len(machine.UserVolumeIDs) != 1 { t.Fatalf("restored machine user volumes mismatch: got %#v", machine.UserVolumeIDs) } diff --git a/internal/daemon/guest_identity.go b/internal/daemon/guest_identity.go index e3a7721..96e92e8 100644 --- a/internal/daemon/guest_identity.go +++ b/internal/daemon/guest_identity.go @@ -2,6 +2,7 @@ package daemon import ( "context" + "encoding/json" "fmt" "os/exec" "strconv" @@ -19,10 +20,24 @@ func (d *Daemon) reconfigureGuestIdentityOverSSH(ctx context.Context, runtimeHos if machineName == "" { return fmt.Errorf("machine id is required") } + mmds, err := d.guestMetadataSpec(machineID, guestConfig) + if err != nil { + return err + } + envelope, ok := mmds.Data.(guestMetadataEnvelope) + if !ok { + return fmt.Errorf("guest metadata payload has unexpected type %T", mmds.Data) + } + payloadBytes, err := json.Marshal(envelope.Latest.MetaData) + if err != nil { + return fmt.Errorf("marshal guest metadata payload: %w", err) + } privateKeyPath := d.backendSSHPrivateKeyPath() remoteScript := fmt.Sprintf(`set -euo pipefail -machine_name=%s +payload=%s +install -d -m 0755 /etc/microagent +machine_name="$(printf '%%s' "$payload" | jq -r '.hostname // .machine_id // empty')" printf '%%s\n' "$machine_name" >/etc/microagent/machine-name printf '%%s\n' "$machine_name" >/etc/hostname cat >/etc/hosts </dev/null 2>&1 || true -`, strconv.Quote(machineName)) +if printf '%%s' "$payload" | jq -e '.authorized_keys | length > 0' >/dev/null 2>&1; then + install -d -m 0700 -o node -g node /home/node/.ssh + printf '%%s' "$payload" | jq -r '.authorized_keys[]' >/home/node/.ssh/authorized_keys + chmod 0600 /home/node/.ssh/authorized_keys + chown node:node /home/node/.ssh/authorized_keys + printf '%%s' "$payload" | jq -r '.authorized_keys[]' >/etc/microagent/authorized_keys + chmod 0600 /etc/microagent/authorized_keys +else + rm -f /home/node/.ssh/authorized_keys /etc/microagent/authorized_keys +fi +if printf '%%s' "$payload" | jq -e '.trusted_user_ca_keys | length > 0' >/dev/null 2>&1; then + printf '%%s' "$payload" | jq -r '.trusted_user_ca_keys[]' >/etc/microagent/trusted_user_ca_keys + chmod 0644 /etc/microagent/trusted_user_ca_keys +else + rm -f /etc/microagent/trusted_user_ca_keys +fi +printf '%%s' "$payload" | jq '{authorized_keys, trusted_user_ca_keys, login_webhook}' >/etc/microagent/guest-config.json +chmod 0600 /etc/microagent/guest-config.json +`, strconv.Quote(string(payloadBytes))) cmd := exec.CommandContext( ctx, @@ -43,6 +76,7 @@ hostname "$machine_name" >/dev/null 2>&1 || true "-o", "UserKnownHostsFile=/dev/null", "-o", "IdentitiesOnly=yes", "-o", "BatchMode=yes", + "-o", "ConnectTimeout=2", "-p", strconv.Itoa(int(defaultSSHPort)), "node@"+runtimeHost, "sudo bash -lc "+shellSingleQuote(remoteScript), @@ -68,6 +102,7 @@ func (d *Daemon) syncGuestFilesystemOverSSH(ctx context.Context, runtimeHost str "-o", "UserKnownHostsFile=/dev/null", "-o", "IdentitiesOnly=yes", "-o", "BatchMode=yes", + "-o", "ConnectTimeout=2", "-p", strconv.Itoa(int(defaultSSHPort)), "node@"+runtimeHost, "sudo bash -lc "+shellSingleQuote("sync"), diff --git a/internal/daemon/guest_personalization.go b/internal/daemon/guest_personalization.go new file mode 100644 index 0000000..b594cbd --- /dev/null +++ b/internal/daemon/guest_personalization.go @@ -0,0 +1,135 @@ +package daemon + +import ( + "bufio" + "context" + "crypto/sha256" + "encoding/binary" + "encoding/json" + "fmt" + "net" + "path/filepath" + "strconv" + "strings" + "time" + + contracthost "github.com/getcompanion-ai/computer-host/contract" + "github.com/getcompanion-ai/computer-host/internal/firecracker" + "github.com/getcompanion-ai/computer-host/internal/model" +) + +const ( + defaultGuestPersonalizationVsockID = "microagent-personalizer" + defaultGuestPersonalizationVsockName = "microagent-personalizer.vsock" + defaultGuestPersonalizationVsockPort = uint32(1024) + defaultGuestPersonalizationTimeout = 2 * time.Second + minGuestVsockCID = uint32(3) + maxGuestVsockCID = uint32(1<<31 - 1) +) + +func guestVsockSpec(machineID contracthost.MachineID) *firecracker.VsockSpec { + return &firecracker.VsockSpec{ + ID: defaultGuestPersonalizationVsockID, + CID: guestVsockCID(machineID), + Path: defaultGuestPersonalizationVsockName, + } +} + +func guestVsockCID(machineID contracthost.MachineID) uint32 { + sum := sha256.Sum256([]byte(machineID)) + space := maxGuestVsockCID - minGuestVsockCID + 1 + return minGuestVsockCID + binary.BigEndian.Uint32(sum[:4])%space +} + +func (d *Daemon) personalizeGuestConfig(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) error { + if record == nil { + return fmt.Errorf("machine record is required") + } + + personalizeCtx, cancel := context.WithTimeout(ctx, defaultGuestPersonalizationTimeout) + defer cancel() + + mmds, err := d.guestMetadataSpec(record.ID, record.GuestConfig) + if err != nil { + return err + } + envelope, ok := mmds.Data.(guestMetadataEnvelope) + if !ok { + return fmt.Errorf("guest metadata payload has unexpected type %T", mmds.Data) + } + + if err := d.runtime.PutMMDS(personalizeCtx, state, mmds.Data); err != nil { + return d.personalizeGuestConfigViaSSH(ctx, record, state, fmt.Errorf("reseed guest mmds: %w", err)) + } + if err := sendGuestPersonalization(personalizeCtx, state, envelope.Latest.MetaData); err != nil { + return d.personalizeGuestConfigViaSSH(ctx, record, state, fmt.Errorf("apply guest config over vsock: %w", err)) + } + return nil +} + +func (d *Daemon) personalizeGuestConfigViaSSH(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState, primaryErr error) error { + fallbackErr := d.reconfigureGuestIdentity(ctx, state.RuntimeHost, record.ID, record.GuestConfig) + if fallbackErr == nil { + return nil + } + return fmt.Errorf("%w; ssh fallback failed: %v", primaryErr, fallbackErr) +} + +func sendGuestPersonalization(ctx context.Context, state firecracker.MachineState, payload guestMetadataPayload) error { + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("marshal guest personalization payload: %w", err) + } + + vsockPath, err := guestVsockHostPath(state) + if err != nil { + return err + } + connection, err := (&net.Dialer{}).DialContext(ctx, "unix", vsockPath) + if err != nil { + return fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err) + } + defer func() { + _ = connection.Close() + }() + setConnectionDeadline(ctx, connection) + + reader := bufio.NewReader(connection) + if _, err := fmt.Fprintf(connection, "CONNECT %d\n", defaultGuestPersonalizationVsockPort); err != nil { + return fmt.Errorf("write vsock connect request: %w", err) + } + response, err := reader.ReadString('\n') + if err != nil { + return fmt.Errorf("read vsock connect response: %w", err) + } + if !strings.HasPrefix(strings.TrimSpace(response), "OK ") { + return fmt.Errorf("unexpected vsock connect response %q", strings.TrimSpace(response)) + } + + if _, err := connection.Write(append(payloadBytes, '\n')); err != nil { + return fmt.Errorf("write guest personalization payload: %w", err) + } + response, err = reader.ReadString('\n') + if err != nil { + return fmt.Errorf("read guest personalization response: %w", err) + } + if strings.TrimSpace(response) != "OK" { + return fmt.Errorf("unexpected guest personalization response %q", strings.TrimSpace(response)) + } + return nil +} + +func guestVsockHostPath(state firecracker.MachineState) (string, error) { + if state.PID < 1 { + return "", fmt.Errorf("firecracker pid is required for guest vsock host path") + } + return filepath.Join("/proc", strconv.Itoa(state.PID), "root", "run", defaultGuestPersonalizationVsockName), nil +} + +func setConnectionDeadline(ctx context.Context, connection net.Conn) { + if deadline, ok := ctx.Deadline(); ok { + _ = connection.SetDeadline(deadline) + return + } + _ = connection.SetDeadline(time.Now().Add(defaultGuestPersonalizationTimeout)) +} diff --git a/internal/daemon/lifecycle.go b/internal/daemon/lifecycle.go index 6b54126..5bbf13c 100644 --- a/internal/daemon/lifecycle.go +++ b/internal/daemon/lifecycle.go @@ -9,10 +9,10 @@ import ( "strings" "time" + contracthost "github.com/getcompanion-ai/computer-host/contract" "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/store" - contracthost "github.com/getcompanion-ai/computer-host/contract" ) func (d *Daemon) GetMachine(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) { @@ -387,6 +387,9 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma if !ready { return record, nil } + if err := d.personalizeGuest(ctx, record, *state); err != nil { + return d.failMachineStartup(ctx, record, err.Error()) + } guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost) if err != nil { return d.failMachineStartup(ctx, record, err.Error()) diff --git a/internal/daemon/readiness.go b/internal/daemon/readiness.go index 63644b9..a1a629b 100644 --- a/internal/daemon/readiness.go +++ b/internal/daemon/readiness.go @@ -6,51 +6,10 @@ import ( "net" "strconv" "strings" - "time" contracthost "github.com/getcompanion-ai/computer-host/contract" ) -func waitForGuestReady(ctx context.Context, host string, ports []contracthost.MachinePort) error { - host = strings.TrimSpace(host) - if host == "" { - return fmt.Errorf("guest runtime host is required") - } - - waitContext, cancel := context.WithTimeout(ctx, defaultGuestReadyTimeout) - defer cancel() - - for _, port := range ports { - if err := waitForGuestPort(waitContext, host, port); err != nil { - return err - } - } - return nil -} - -func waitForGuestPort(ctx context.Context, host string, port contracthost.MachinePort) error { - address := net.JoinHostPort(host, strconv.Itoa(int(port.Port))) - ticker := time.NewTicker(defaultGuestReadyPollInterval) - defer ticker.Stop() - - var lastErr error - for { - probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout) - ready, err := guestPortReady(probeCtx, host, port) - cancel() - if err == nil && ready { - return nil - } - lastErr = err - - select { - case <-ctx.Done(): - return fmt.Errorf("wait for guest port %q on %s: %w (last_err=%v)", port.Name, address, ctx.Err(), lastErr) - case <-ticker.C: - } - } -} - func guestPortsReady(ctx context.Context, host string, ports []contracthost.MachinePort) (bool, error) { host = strings.TrimSpace(host) if host == "" { diff --git a/internal/daemon/review_regressions_test.go b/internal/daemon/review_regressions_test.go index 2d3613f..3badab2 100644 --- a/internal/daemon/review_regressions_test.go +++ b/internal/daemon/review_regressions_test.go @@ -14,10 +14,10 @@ import ( "testing" "time" + contracthost "github.com/getcompanion-ai/computer-host/contract" "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/model" hoststore "github.com/getcompanion-ai/computer-host/internal/store" - contracthost "github.com/getcompanion-ai/computer-host/contract" ) type blockingPublishedPortStore struct { @@ -408,7 +408,7 @@ func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) { } } -func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *testing.T) { +func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T) { root := t.TempDir() cfg := testConfig(root) baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath) @@ -421,15 +421,6 @@ func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *tes extraMachines: exhaustedMachineRelayRecords(), } - sshListener := listenTestPort(t, int(defaultSSHPort)) - defer func() { - _ = sshListener.Close() - }() - vncListener := listenTestPort(t, int(defaultVNCPort)) - defer func() { - _ = vncListener.Close() - }() - startedAt := time.Unix(1700000300, 0).UTC() runtime := &fakeRuntime{ bootState: firecracker.MachineState{ @@ -448,7 +439,10 @@ func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *tes t.Fatalf("create daemon: %v", err) } stubGuestSSHPublicKeyReader(hostDaemon) - hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil } + hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) error { + t.Fatalf("restore snapshot should not synchronously reconfigure guest identity, host=%q machine=%q guest_config=%#v", host, machineID, guestConfig) + return nil + } artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"} kernelPath := filepath.Join(root, "artifact-kernel") @@ -509,7 +503,7 @@ func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *tes }) defer server.Close() - _, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-exhausted", contracthost.RestoreSnapshotRequest{ + response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap-exhausted", contracthost.RestoreSnapshotRequest{ MachineID: "restored-exhausted", Artifact: contracthost.ArtifactRef{ KernelImageURL: server.URL + "/kernel", @@ -526,18 +520,20 @@ func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *tes }, }, }) - if err == nil || !strings.Contains(err.Error(), "allocate relay ports for restored machine") { - t.Fatalf("RestoreSnapshot error = %v, want relay allocation failure", err) + if err != nil { + t.Fatalf("RestoreSnapshot returned error: %v", err) } - - if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); !errors.Is(err, hoststore.ErrNotFound) { - t.Fatalf("restored system volume record should be deleted, get err = %v", err) + if response.Machine.Phase != contracthost.MachinePhaseStarting { + t.Fatalf("restored machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting) } - if _, err := os.Stat(hostDaemon.systemVolumePath("restored-exhausted")); !os.IsNotExist(err) { - t.Fatalf("restored system disk should be removed, stat err = %v", err) + if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); err != nil { + t.Fatalf("restored system volume record should exist: %v", err) } - if len(runtime.deleteCalls) != 1 { - t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls)) + if _, err := os.Stat(hostDaemon.systemVolumePath("restored-exhausted")); err != nil { + t.Fatalf("restored system disk should exist: %v", err) + } + if len(runtime.deleteCalls) != 0 { + t.Fatalf("runtime delete calls = %d, want 0", len(runtime.deleteCalls)) } assertOperationCount(t, baseStore, 0) } diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go index 93188ff..7cd593e 100644 --- a/internal/daemon/snapshot.go +++ b/internal/daemon/snapshot.go @@ -12,10 +12,10 @@ import ( "strings" "time" + contracthost "github.com/getcompanion-ai/computer-host/contract" "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/store" - contracthost "github.com/getcompanion-ai/computer-host/contract" ) func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error) { @@ -332,6 +332,9 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn restoredDrivePaths[driveID] = volumePath } + // Do not force vsock_override on restore: Firecracker rejects it for old + // snapshots without a vsock device, and the jailed /run path already + // relocates safely for snapshots created with the new vsock-backed guest. loadSpec := firecracker.SnapshotLoadSpec{ ID: firecracker.MachineID(req.MachineID), SnapshotPath: vmstateArtifact.LocalPath, @@ -349,27 +352,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn return nil, fmt.Errorf("restore boot: %w", err) } - // Wait for guest to become ready - if err := waitForGuestReady(ctx, machineState.RuntimeHost, defaultMachinePorts()); err != nil { - _ = d.runtime.Delete(ctx, *machineState) - _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) - clearOperation = true - return nil, fmt.Errorf("wait for restored guest ready: %w", err) - } - if err := d.reconfigureGuestIdentity(ctx, machineState.RuntimeHost, req.MachineID, guestConfig); err != nil { - _ = d.runtime.Delete(ctx, *machineState) - _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) - clearOperation = true - return nil, fmt.Errorf("reconfigure restored guest identity: %w", err) - } - guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, machineState.RuntimeHost) - if err != nil { - _ = d.runtime.Delete(ctx, *machineState) - _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) - clearOperation = true - return nil, fmt.Errorf("read restored guest ssh host key: %w", err) - } - systemVolumeID := d.systemVolumeID(req.MachineID) now := time.Now().UTC() @@ -419,38 +401,13 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn RuntimeHost: machineState.RuntimeHost, TapDevice: machineState.TapName, Ports: defaultMachinePorts(), - GuestSSHPublicKey: guestSSHPublicKey, - Phase: contracthost.MachinePhaseRunning, + GuestSSHPublicKey: "", + Phase: contracthost.MachinePhaseStarting, PID: machineState.PID, SocketPath: machineState.SocketPath, CreatedAt: now, StartedAt: machineState.StartedAt, } - d.relayAllocMu.Lock() - sshRelayPort, err := d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameSSH, machineRecord.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort) - var vncRelayPort uint16 - if err == nil { - vncRelayPort, err = d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameVNC, machineRecord.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort) - } - d.relayAllocMu.Unlock() - if err != nil { - d.stopMachineRelays(machineRecord.ID) - for _, restoredVolumeID := range restoredUserVolumeIDs { - _ = d.store.DeleteVolume(context.Background(), restoredVolumeID) - } - _ = d.store.DeleteVolume(context.Background(), systemVolumeID) - _ = d.runtime.Delete(ctx, *machineState) - _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) - clearOperation = true - return nil, fmt.Errorf("allocate relay ports for restored machine: %w", err) - } - machineRecord.Ports = buildMachinePorts(sshRelayPort, vncRelayPort) - startedRelays := true - defer func() { - if startedRelays { - d.stopMachineRelays(machineRecord.ID) - } - }() if err := d.store.CreateMachine(ctx, machineRecord); err != nil { for _, restoredVolumeID := range restoredUserVolumeIDs { _ = d.store.DeleteVolume(context.Background(), restoredVolumeID) @@ -462,7 +419,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn return nil, err } - startedRelays = false clearOperation = true return &contracthost.RestoreSnapshotResponse{ Machine: machineToContract(machineRecord), diff --git a/internal/firecracker/api_test.go b/internal/firecracker/api_test.go index e336bc2..6a3c900 100644 --- a/internal/firecracker/api_test.go +++ b/internal/firecracker/api_test.go @@ -38,6 +38,7 @@ func TestPutSnapshotLoadIncludesNetworkOverrides(t *testing.T) { HostDevName: "fctap7", }, }, + VsockOverride: &VsockOverride{UDSPath: "/run/microagent-personalizer.vsock"}, }) if err != nil { t.Fatalf("put snapshot load: %v", err) @@ -47,7 +48,7 @@ func TestPutSnapshotLoadIncludesNetworkOverrides(t *testing.T) { t.Fatalf("request path mismatch: got %q want %q", gotPath, "/snapshot/load") } - want := "{\"snapshot_path\":\"vmstate.bin\",\"mem_backend\":{\"backend_type\":\"File\",\"backend_path\":\"memory.bin\"},\"resume_vm\":false,\"network_overrides\":[{\"iface_id\":\"net0\",\"host_dev_name\":\"fctap7\"}]}" + want := "{\"snapshot_path\":\"vmstate.bin\",\"mem_backend\":{\"backend_type\":\"File\",\"backend_path\":\"memory.bin\"},\"resume_vm\":false,\"network_overrides\":[{\"iface_id\":\"net0\",\"host_dev_name\":\"fctap7\"}],\"vsock_override\":{\"uds_path\":\"/run/microagent-personalizer.vsock\"}}" if gotBody != want { t.Fatalf("request body mismatch:\n got: %s\nwant: %s", gotBody, want) } diff --git a/internal/firecracker/configure_test.go b/internal/firecracker/configure_test.go index 7fe00a5..b806666 100644 --- a/internal/firecracker/configure_test.go +++ b/internal/firecracker/configure_test.go @@ -168,6 +168,74 @@ func TestConfigureMachineConfiguresMMDSBeforeStart(t *testing.T) { } } +func TestConfigureMachineConfiguresVsockBeforeStart(t *testing.T) { + var requests []capturedRequest + + socketPath, shutdown := startUnixSocketServer(t, func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read request body: %v", err) + } + requests = append(requests, capturedRequest{ + Method: r.Method, + Path: r.URL.Path, + Body: string(body), + }) + w.WriteHeader(http.StatusNoContent) + }) + defer shutdown() + + client := newAPIClient(socketPath) + spec := MachineSpec{ + ID: "vm-3", + VCPUs: 1, + MemoryMiB: 512, + KernelImagePath: "/kernel", + RootFSPath: "/rootfs", + Vsock: &VsockSpec{ + ID: "microagent-personalizer", + CID: 42, + Path: "/run/microagent-personalizer.vsock", + }, + } + paths := machinePaths{JailedSerialLogPath: "/logs/serial.log"} + network := NetworkAllocation{ + InterfaceID: defaultInterfaceID, + TapName: "fctap0", + GuestMAC: "06:00:ac:10:00:02", + } + + if err := configureMachine(context.Background(), client, paths, spec, network); err != nil { + t.Fatalf("configure machine: %v", err) + } + + gotPaths := make([]string, 0, len(requests)) + for _, request := range requests { + gotPaths = append(gotPaths, request.Path) + } + wantPaths := []string{ + "/machine-config", + "/boot-source", + "/drives/root_drive", + "/network-interfaces/net0", + "/entropy", + "/serial", + "/vsock", + "/actions", + } + if len(gotPaths) != len(wantPaths) { + t.Fatalf("request count mismatch: got %d want %d (%v)", len(gotPaths), len(wantPaths), gotPaths) + } + for i := range wantPaths { + if gotPaths[i] != wantPaths[i] { + t.Fatalf("request %d mismatch: got %q want %q", i, gotPaths[i], wantPaths[i]) + } + } + if requests[6].Body != "{\"guest_cid\":42,\"uds_path\":\"/run/microagent-personalizer.vsock\"}" { + t.Fatalf("vsock body mismatch: got %q", requests[6].Body) + } +} + func startUnixSocketServer(t *testing.T, handler http.HandlerFunc) (string, func()) { t.Helper() diff --git a/internal/firecracker/launch.go b/internal/firecracker/launch.go index a01a743..3615655 100644 --- a/internal/firecracker/launch.go +++ b/internal/firecracker/launch.go @@ -132,7 +132,7 @@ func stageMachineFiles(spec MachineSpec, paths machinePaths) (MachineSpec, error if spec.Vsock != nil { vsock := *spec.Vsock - vsock.Path = jailedVSockPath(spec) + vsock.Path = jailedVSockDevicePath(*spec.Vsock) staged.Vsock = &vsock } @@ -244,11 +244,8 @@ func waitForPIDFile(ctx context.Context, pidFilePath string) (int, error) { } } -func jailedVSockPath(spec MachineSpec) string { - if spec.Vsock == nil { - return "" - } - return path.Join(defaultVSockRunDir, filepath.Base(strings.TrimSpace(spec.Vsock.Path))) +func jailedVSockDevicePath(spec VsockSpec) string { + return path.Join(defaultVSockRunDir, filepath.Base(strings.TrimSpace(spec.Path))) } func linkMachineFile(source string, target string) error { diff --git a/internal/firecracker/runtime.go b/internal/firecracker/runtime.go index a4d3c86..ec04de6 100644 --- a/internal/firecracker/runtime.go +++ b/internal/firecracker/runtime.go @@ -331,6 +331,11 @@ func (r *Runtime) RestoreBoot(ctx context.Context, loadSpec SnapshotLoadSpec, us } } + var vsockOverride *VsockOverride + if loadSpec.Vsock != nil { + vsockOverride = &VsockOverride{UDSPath: jailedVSockDevicePath(*loadSpec.Vsock)} + } + // Load snapshot (replaces the full configure+start sequence) if err := client.PutSnapshotLoad(ctx, SnapshotLoadParams{ SnapshotPath: chrootStatePath, @@ -345,6 +350,7 @@ func (r *Runtime) RestoreBoot(ctx context.Context, loadSpec SnapshotLoadSpec, us HostDevName: network.TapName, }, }, + VsockOverride: vsockOverride, }); err != nil { cleanup(network, paths, command, firecrackerPID) return nil, fmt.Errorf("load snapshot: %w", err) @@ -369,6 +375,11 @@ func (r *Runtime) RestoreBoot(ctx context.Context, loadSpec SnapshotLoadSpec, us return &state, nil } +func (r *Runtime) PutMMDS(ctx context.Context, state MachineState, data any) error { + client := newAPIClient(state.SocketPath) + return client.PutMMDS(ctx, data) +} + func processExists(pid int) bool { if pid < 1 { return false