feat: vsock mmds snapshot

This commit is contained in:
Harivansh Rathi 2026-04-10 02:26:43 +00:00
parent 39f8882c30
commit 07975fb459
13 changed files with 390 additions and 148 deletions

View file

@ -7,10 +7,10 @@ import (
"path/filepath"
"time"
contracthost "github.com/getcompanion-ai/computer-host/contract"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachineRequest) (*contracthost.CreateMachineResponse, error) {
@ -184,6 +184,7 @@ func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *mo
KernelArgs: defaultGuestKernelArgs,
Drives: drives,
MMDS: mmds,
Vsock: guestVsockSpec(machineID),
}
if err := spec.Validate(); err != nil {
return firecracker.MachineSpec{}, err

View file

@ -8,22 +8,21 @@ import (
"sync"
"time"
contracthost "github.com/getcompanion-ai/computer-host/contract"
appconfig "github.com/getcompanion-ai/computer-host/internal/config"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
const (
defaultGuestKernelArgs = "console=ttyS0 reboot=k panic=1 pci=off"
defaultGuestMemoryMiB = int64(512)
defaultGuestVCPUs = int64(1)
defaultSSHPort = uint16(2222)
defaultVNCPort = uint16(6080)
defaultCopyBufferSize = 1024 * 1024
defaultGuestDialTimeout = 500 * time.Millisecond
defaultGuestReadyPollInterval = 100 * time.Millisecond
defaultGuestReadyTimeout = 30 * time.Second
defaultGuestKernelArgs = "console=ttyS0 reboot=k panic=1 pci=off"
defaultGuestMemoryMiB = int64(512)
defaultGuestVCPUs = int64(1)
defaultSSHPort = uint16(2222)
defaultVNCPort = uint16(6080)
defaultCopyBufferSize = 1024 * 1024
defaultGuestDialTimeout = 500 * time.Millisecond
)
type Runtime interface {
@ -34,6 +33,7 @@ type Runtime interface {
Resume(context.Context, firecracker.MachineState) error
CreateSnapshot(context.Context, firecracker.MachineState, firecracker.SnapshotPaths) error
RestoreBoot(context.Context, firecracker.SnapshotLoadSpec, []firecracker.NetworkAllocation) (*firecracker.MachineState, error)
PutMMDS(context.Context, firecracker.MachineState, any) error
}
type Daemon struct {
@ -44,6 +44,7 @@ type Daemon struct {
reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error
readGuestSSHPublicKey func(context.Context, string) (string, error)
syncGuestFilesystem func(context.Context, string) error
personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) error
locksMu sync.Mutex
machineLocks map[contracthost.MachineID]*sync.Mutex
@ -78,6 +79,7 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
runtime: runtime,
reconfigureGuestIdentity: nil,
readGuestSSHPublicKey: nil,
personalizeGuest: nil,
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
artifactLocks: make(map[string]*sync.Mutex),
machineRelayListeners: make(map[string]net.Listener),
@ -86,6 +88,7 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
daemon.readGuestSSHPublicKey = readGuestSSHPublicKey
daemon.syncGuestFilesystem = daemon.syncGuestFilesystemOverSSH
daemon.personalizeGuest = daemon.personalizeGuestConfig
if err := daemon.ensureBackendSSHKeyPair(); err != nil {
return nil, err
}

View file

@ -13,11 +13,11 @@ import (
"testing"
"time"
contracthost "github.com/getcompanion-ai/computer-host/contract"
appconfig "github.com/getcompanion-ai/computer-host/internal/config"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
type fakeRuntime struct {
@ -27,6 +27,7 @@ type fakeRuntime struct {
deleteCalls []firecracker.MachineState
lastSpec firecracker.MachineSpec
lastLoadSpec firecracker.SnapshotLoadSpec
mmdsWrites []any
}
func (f *fakeRuntime) Boot(_ context.Context, spec firecracker.MachineSpec, _ []firecracker.NetworkAllocation) (*firecracker.MachineState, error) {
@ -64,6 +65,11 @@ func (f *fakeRuntime) RestoreBoot(_ context.Context, spec firecracker.SnapshotLo
return &f.bootState, nil
}
// PutMMDS is the test double for the runtime MMDS write. It performs no I/O:
// the payload is appended to f.mmdsWrites and the call always succeeds.
func (f *fakeRuntime) PutMMDS(_ context.Context, _ firecracker.MachineState, data any) error {
	recorded := append(f.mmdsWrites, data)
	f.mmdsWrites = recorded
	return nil
}
func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
@ -173,6 +179,15 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
if runtime.lastSpec.MMDS == nil {
t.Fatalf("expected MMDS configuration on machine spec")
}
if runtime.lastSpec.Vsock == nil {
t.Fatalf("expected vsock configuration on machine spec")
}
if runtime.lastSpec.Vsock.ID != defaultGuestPersonalizationVsockID {
t.Fatalf("vsock id mismatch: got %q", runtime.lastSpec.Vsock.ID)
}
if runtime.lastSpec.Vsock.CID < minGuestVsockCID {
t.Fatalf("vsock cid mismatch: got %d", runtime.lastSpec.Vsock.CID)
}
if runtime.lastSpec.MMDS.Version != firecracker.MMDSVersionV2 {
t.Fatalf("mmds version mismatch: got %q", runtime.lastSpec.MMDS.Version)
}
@ -286,6 +301,70 @@ func TestStopMachineSyncsGuestFilesystemBeforeDelete(t *testing.T) {
}
}
// TestReconcileStartingMachinePersonalizesBeforeRunning stores a machine record
// in the Starting phase, fetches it through GetMachine, and verifies that the
// personalizeGuest hook was invoked with the stored record/state and that the
// machine is reported as Running with a guest SSH public key recorded.
func TestReconcileStartingMachinePersonalizesBeforeRunning(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}

	// Hold listeners on the default guest SSH and VNC ports for the duration of
	// the test. NOTE(review): presumably this is what lets the readiness probe
	// against 127.0.0.1 succeed — confirm against guestPortsReady.
	sshListener := listenTestPort(t, int(defaultSSHPort))
	defer func() { _ = sshListener.Close() }()
	vncListener := listenTestPort(t, int(defaultVNCPort))
	defer func() { _ = vncListener.Close() }()

	startedAt := time.Unix(1700000100, 0).UTC()
	runtime := &fakeRuntime{}
	hostDaemon, err := New(cfg, fileStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	t.Cleanup(func() { hostDaemon.stopMachineRelays("vm-starting") })

	// Replace the personalization hook with a stub that records the call and
	// checks it receives the record and runtime state created below.
	personalized := false
	hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, state firecracker.MachineState) error {
		personalized = true
		if record.ID != "vm-starting" {
			t.Fatalf("personalized machine mismatch: got %q", record.ID)
		}
		if state.RuntimeHost != "127.0.0.1" || state.PID != 4321 {
			t.Fatalf("personalized state mismatch: %#v", state)
		}
		return nil
	}
	stubGuestSSHPublicKeyReader(hostDaemon)

	// Seed a machine that is still in the Starting phase.
	if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
		ID:             "vm-starting",
		SystemVolumeID: "vm-starting-system",
		RuntimeHost:    "127.0.0.1",
		TapDevice:      "fctap-starting",
		Ports:          defaultMachinePorts(),
		Phase:          contracthost.MachinePhaseStarting,
		PID:            4321,
		SocketPath:     filepath.Join(cfg.RuntimeDir, "machines", "vm-starting", "root", "run", "firecracker.sock"),
		CreatedAt:      time.Now().UTC(),
		StartedAt:      &startedAt,
	}); err != nil {
		t.Fatalf("create machine: %v", err)
	}

	// GetMachine is the entry point under test; the assertions below expect it
	// to have driven the machine from Starting to Running.
	response, err := hostDaemon.GetMachine(context.Background(), "vm-starting")
	if err != nil {
		t.Fatalf("GetMachine returned error: %v", err)
	}
	if !personalized {
		t.Fatalf("guest personalization was not called")
	}
	if response.Machine.Phase != contracthost.MachinePhaseRunning {
		t.Fatalf("machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseRunning)
	}
	if response.Machine.GuestSSHPublicKey == "" {
		t.Fatalf("guest ssh public key should be recorded after convergence")
	}
}
func TestNewEnsuresBackendSSHKeyPair(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
@ -406,6 +485,9 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
if response.Machine.ID != "restored" {
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
}
if response.Machine.Phase != contracthost.MachinePhaseStarting {
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
}
if runtime.restoreCalls != 1 {
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls)
}
@ -462,13 +544,8 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
var reconfiguredHost string
var reconfiguredMachine contracthost.MachineID
var reconfiguredConfig *contracthost.GuestConfig
hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) error {
reconfiguredHost = host
reconfiguredMachine = machineID
reconfiguredConfig = cloneGuestConfig(guestConfig)
t.Fatalf("restore snapshot should not synchronously reconfigure guest identity, host=%q machine=%q guest_config=%#v", host, machineID, guestConfig)
return nil
}
@ -509,6 +586,9 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
if response.Machine.ID != "restored" {
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
}
if response.Machine.Phase != contracthost.MachinePhaseStarting {
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
}
if runtime.restoreCalls != 1 {
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls)
}
@ -527,20 +607,17 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
}), "kernel")) {
t.Fatalf("restore boot kernel path mismatch: got %q", runtime.lastLoadSpec.KernelImagePath)
}
if reconfiguredHost != "127.0.0.1" || reconfiguredMachine != "restored" {
t.Fatalf("guest identity reconfigure mismatch: host=%q machine=%q", reconfiguredHost, reconfiguredMachine)
}
if reconfiguredConfig == nil || reconfiguredConfig.Hostname != "restored-shell" {
t.Fatalf("guest identity hostname mismatch: %#v", reconfiguredConfig)
}
machine, err := fileStore.GetMachine(context.Background(), "restored")
if err != nil {
t.Fatalf("get restored machine: %v", err)
}
if machine.Phase != contracthost.MachinePhaseRunning {
if machine.Phase != contracthost.MachinePhaseStarting {
t.Fatalf("restored machine phase mismatch: got %q", machine.Phase)
}
if machine.GuestConfig == nil || machine.GuestConfig.Hostname != "restored-shell" {
t.Fatalf("stored guest config mismatch: %#v", machine.GuestConfig)
}
if len(machine.UserVolumeIDs) != 1 {
t.Fatalf("restored machine user volumes mismatch: got %#v", machine.UserVolumeIDs)
}

View file

@ -2,6 +2,7 @@ package daemon
import (
"context"
"encoding/json"
"fmt"
"os/exec"
"strconv"
@ -19,10 +20,24 @@ func (d *Daemon) reconfigureGuestIdentityOverSSH(ctx context.Context, runtimeHos
if machineName == "" {
return fmt.Errorf("machine id is required")
}
mmds, err := d.guestMetadataSpec(machineID, guestConfig)
if err != nil {
return err
}
envelope, ok := mmds.Data.(guestMetadataEnvelope)
if !ok {
return fmt.Errorf("guest metadata payload has unexpected type %T", mmds.Data)
}
payloadBytes, err := json.Marshal(envelope.Latest.MetaData)
if err != nil {
return fmt.Errorf("marshal guest metadata payload: %w", err)
}
privateKeyPath := d.backendSSHPrivateKeyPath()
remoteScript := fmt.Sprintf(`set -euo pipefail
machine_name=%s
payload=%s
install -d -m 0755 /etc/microagent
machine_name="$(printf '%%s' "$payload" | jq -r '.hostname // .machine_id // empty')"
printf '%%s\n' "$machine_name" >/etc/microagent/machine-name
printf '%%s\n' "$machine_name" >/etc/hostname
cat >/etc/hosts <<EOF
@ -33,7 +48,25 @@ ff02::1 ip6-allnodes
ff02::2 ip6-allrouters
EOF
hostname "$machine_name" >/dev/null 2>&1 || true
`, strconv.Quote(machineName))
if printf '%%s' "$payload" | jq -e '.authorized_keys | length > 0' >/dev/null 2>&1; then
install -d -m 0700 -o node -g node /home/node/.ssh
printf '%%s' "$payload" | jq -r '.authorized_keys[]' >/home/node/.ssh/authorized_keys
chmod 0600 /home/node/.ssh/authorized_keys
chown node:node /home/node/.ssh/authorized_keys
printf '%%s' "$payload" | jq -r '.authorized_keys[]' >/etc/microagent/authorized_keys
chmod 0600 /etc/microagent/authorized_keys
else
rm -f /home/node/.ssh/authorized_keys /etc/microagent/authorized_keys
fi
if printf '%%s' "$payload" | jq -e '.trusted_user_ca_keys | length > 0' >/dev/null 2>&1; then
printf '%%s' "$payload" | jq -r '.trusted_user_ca_keys[]' >/etc/microagent/trusted_user_ca_keys
chmod 0644 /etc/microagent/trusted_user_ca_keys
else
rm -f /etc/microagent/trusted_user_ca_keys
fi
printf '%%s' "$payload" | jq '{authorized_keys, trusted_user_ca_keys, login_webhook}' >/etc/microagent/guest-config.json
chmod 0600 /etc/microagent/guest-config.json
`, strconv.Quote(string(payloadBytes)))
cmd := exec.CommandContext(
ctx,
@ -43,6 +76,7 @@ hostname "$machine_name" >/dev/null 2>&1 || true
"-o", "UserKnownHostsFile=/dev/null",
"-o", "IdentitiesOnly=yes",
"-o", "BatchMode=yes",
"-o", "ConnectTimeout=2",
"-p", strconv.Itoa(int(defaultSSHPort)),
"node@"+runtimeHost,
"sudo bash -lc "+shellSingleQuote(remoteScript),
@ -68,6 +102,7 @@ func (d *Daemon) syncGuestFilesystemOverSSH(ctx context.Context, runtimeHost str
"-o", "UserKnownHostsFile=/dev/null",
"-o", "IdentitiesOnly=yes",
"-o", "BatchMode=yes",
"-o", "ConnectTimeout=2",
"-p", strconv.Itoa(int(defaultSSHPort)),
"node@"+runtimeHost,
"sudo bash -lc "+shellSingleQuote("sync"),

View file

@ -0,0 +1,135 @@
package daemon
import (
"bufio"
"context"
"crypto/sha256"
"encoding/binary"
"encoding/json"
"fmt"
"net"
"path/filepath"
"strconv"
"strings"
"time"
contracthost "github.com/getcompanion-ai/computer-host/contract"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
)
// Defaults for the host-to-guest personalization channel carried over vsock.
const (
	// defaultGuestPersonalizationVsockID is the ID placed on the machine's vsock spec.
	defaultGuestPersonalizationVsockID = "microagent-personalizer"
	// defaultGuestPersonalizationVsockName is the socket file name used both as
	// the vsock spec path and as the last element of the host-side socket path.
	defaultGuestPersonalizationVsockName = "microagent-personalizer.vsock"
	// defaultGuestPersonalizationVsockPort is the port number sent in the
	// "CONNECT <port>" request written to the vsock socket.
	defaultGuestPersonalizationVsockPort = uint32(1024)
	// defaultGuestPersonalizationTimeout bounds the MMDS reseed plus vsock
	// exchange, and is the fallback connection deadline when the context has none.
	defaultGuestPersonalizationTimeout = 2 * time.Second
	// minGuestVsockCID and maxGuestVsockCID are the inclusive bounds of the CID
	// range that guestVsockCID maps machine IDs into.
	minGuestVsockCID = uint32(3)
	maxGuestVsockCID = uint32(1<<31 - 1)
)
// guestVsockSpec returns the vsock device spec for a machine: the fixed
// personalization ID and socket name, plus a CID derived from machineID.
func guestVsockSpec(machineID contracthost.MachineID) *firecracker.VsockSpec {
	spec := new(firecracker.VsockSpec)
	spec.ID = defaultGuestPersonalizationVsockID
	spec.CID = guestVsockCID(machineID)
	spec.Path = defaultGuestPersonalizationVsockName
	return spec
}
// guestVsockCID deterministically maps a machine ID to a CID in the inclusive
// range [minGuestVsockCID, maxGuestVsockCID]. It reads the first four bytes of
// the ID's SHA-256 digest as a big-endian uint32 and reduces that value modulo
// the size of the range.
func guestVsockCID(machineID contracthost.MachineID) uint32 {
	digest := sha256.Sum256([]byte(machineID))
	prefix := binary.BigEndian.Uint32(digest[0:4])
	cidRange := maxGuestVsockCID - minGuestVsockCID + 1
	return prefix%cidRange + minGuestVsockCID
}
// personalizeGuestConfig pushes the machine's guest metadata into the guest.
//
// It first rewrites MMDS through the runtime and then sends the metadata
// payload over the personalization vsock; both steps share one context bounded
// by defaultGuestPersonalizationTimeout. If either step fails, the error is
// wrapped and handed to personalizeGuestConfigViaSSH, which retries over SSH
// using the caller's (unbounded) context.
//
// It returns an error when record is nil, when the metadata spec cannot be
// built or has an unexpected payload type, or when both the primary path and
// the SSH fallback fail.
func (d *Daemon) personalizeGuestConfig(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) error {
	if record == nil {
		return fmt.Errorf("machine record is required")
	}

	boundedCtx, release := context.WithTimeout(ctx, defaultGuestPersonalizationTimeout)
	defer release()

	metadata, err := d.guestMetadataSpec(record.ID, record.GuestConfig)
	if err != nil {
		return err
	}
	envelope, isEnvelope := metadata.Data.(guestMetadataEnvelope)
	if !isEnvelope {
		return fmt.Errorf("guest metadata payload has unexpected type %T", metadata.Data)
	}

	// Run the two primary steps in order; the first failure short-circuits
	// and becomes the error reported alongside any fallback failure.
	var primaryErr error
	if putErr := d.runtime.PutMMDS(boundedCtx, state, metadata.Data); putErr != nil {
		primaryErr = fmt.Errorf("reseed guest mmds: %w", putErr)
	} else if sendErr := sendGuestPersonalization(boundedCtx, state, envelope.Latest.MetaData); sendErr != nil {
		primaryErr = fmt.Errorf("apply guest config over vsock: %w", sendErr)
	}
	if primaryErr == nil {
		return nil
	}
	return d.personalizeGuestConfigViaSSH(ctx, record, state, primaryErr)
}
// personalizeGuestConfigViaSSH is the fallback used after the MMDS/vsock path
// failed with primaryErr. It invokes the daemon's reconfigureGuestIdentity hook
// against the machine's runtime host. On success the primary failure is
// discarded and nil is returned; otherwise the result wraps primaryErr and
// includes the fallback error's text.
func (d *Daemon) personalizeGuestConfigViaSSH(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState, primaryErr error) error {
	if sshErr := d.reconfigureGuestIdentity(ctx, state.RuntimeHost, record.ID, record.GuestConfig); sshErr != nil {
		return fmt.Errorf("%w; ssh fallback failed: %v", primaryErr, sshErr)
	}
	return nil
}
// sendGuestPersonalization delivers payload to the guest over the machine's
// vsock Unix socket.
//
// The exchange is line-oriented:
//  1. write "CONNECT <port>\n" and expect a reply line starting with "OK ";
//  2. write the JSON-encoded payload followed by a newline;
//  3. expect a reply line that is exactly "OK" after trimming whitespace.
//
// Every read and write is subject to the deadline installed by
// setConnectionDeadline. Any marshalling, dialing, I/O, or protocol mismatch
// is returned as an error.
func sendGuestPersonalization(ctx context.Context, state firecracker.MachineState, payload guestMetadataPayload) error {
	encoded, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal guest personalization payload: %w", err)
	}

	socketPath, err := guestVsockHostPath(state)
	if err != nil {
		return err
	}

	var dialer net.Dialer
	conn, err := dialer.DialContext(ctx, "unix", socketPath)
	if err != nil {
		return fmt.Errorf("dial guest personalization vsock %q: %w", socketPath, err)
	}
	defer func() {
		_ = conn.Close()
	}()
	setConnectionDeadline(ctx, conn)

	// One buffered reader serves both replies so bytes read ahead during the
	// handshake are not lost before the second read.
	lines := bufio.NewReader(conn)

	if _, err := fmt.Fprintf(conn, "CONNECT %d\n", defaultGuestPersonalizationVsockPort); err != nil {
		return fmt.Errorf("write vsock connect request: %w", err)
	}
	handshake, err := lines.ReadString('\n')
	if err != nil {
		return fmt.Errorf("read vsock connect response: %w", err)
	}
	if ack := strings.TrimSpace(handshake); !strings.HasPrefix(ack, "OK ") {
		return fmt.Errorf("unexpected vsock connect response %q", ack)
	}

	// Payload and terminating newline go out in a single Write call.
	frame := append(encoded, '\n')
	if _, err := conn.Write(frame); err != nil {
		return fmt.Errorf("write guest personalization payload: %w", err)
	}
	reply, err := lines.ReadString('\n')
	if err != nil {
		return fmt.Errorf("read guest personalization response: %w", err)
	}
	if status := strings.TrimSpace(reply); status != "OK" {
		return fmt.Errorf("unexpected guest personalization response %q", status)
	}
	return nil
}
// guestVsockHostPath returns the host path of the machine's personalization
// vsock socket, built as /proc/<pid>/root/run/<socket name> from state.PID.
// A PID below 1 is rejected with an error.
func guestVsockHostPath(state firecracker.MachineState) (string, error) {
	pid := state.PID
	if pid < 1 {
		return "", fmt.Errorf("firecracker pid is required for guest vsock host path")
	}
	procRoot := filepath.Join("/proc", strconv.Itoa(pid), "root")
	return filepath.Join(procRoot, "run", defaultGuestPersonalizationVsockName), nil
}
// setConnectionDeadline applies a read/write deadline to connection. The
// context's deadline is used when it has one; otherwise the deadline is now
// plus defaultGuestPersonalizationTimeout. An error from SetDeadline is
// deliberately ignored (best effort).
func setConnectionDeadline(ctx context.Context, connection net.Conn) {
	deadline, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		deadline = time.Now().Add(defaultGuestPersonalizationTimeout)
	}
	_ = connection.SetDeadline(deadline)
}

View file

@ -9,10 +9,10 @@ import (
"strings"
"time"
contracthost "github.com/getcompanion-ai/computer-host/contract"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func (d *Daemon) GetMachine(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) {
@ -387,6 +387,9 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
if !ready {
return record, nil
}
if err := d.personalizeGuest(ctx, record, *state); err != nil {
return d.failMachineStartup(ctx, record, err.Error())
}
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
if err != nil {
return d.failMachineStartup(ctx, record, err.Error())

View file

@ -6,51 +6,10 @@ import (
"net"
"strconv"
"strings"
"time"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
// waitForGuestReady blocks until every port in ports is reported ready on
// host, checking them one after another under a single overall timeout of
// defaultGuestReadyTimeout. It fails immediately when host is blank and
// otherwise returns the first error from waitForGuestPort.
func waitForGuestReady(ctx context.Context, host string, ports []contracthost.MachinePort) error {
	trimmedHost := strings.TrimSpace(host)
	if len(trimmedHost) == 0 {
		return fmt.Errorf("guest runtime host is required")
	}

	readyCtx, stop := context.WithTimeout(ctx, defaultGuestReadyTimeout)
	defer stop()

	for i := range ports {
		if err := waitForGuestPort(readyCtx, trimmedHost, ports[i]); err != nil {
			return err
		}
	}
	return nil
}
// waitForGuestPort polls a single guest port until guestPortReady reports it
// ready without error. Each probe gets its own defaultGuestDialTimeout, and
// probes are spaced by defaultGuestReadyPollInterval. When ctx ends first, the
// returned error wraps ctx.Err() and includes the most recent probe error.
func waitForGuestPort(ctx context.Context, host string, port contracthost.MachinePort) error {
	address := net.JoinHostPort(host, strconv.Itoa(int(port.Port)))

	poll := time.NewTicker(defaultGuestReadyPollInterval)
	defer poll.Stop()

	// probe runs one readiness check under its own short timeout; the timeout
	// context is released as soon as the check returns.
	probe := func() (bool, error) {
		attemptCtx, done := context.WithTimeout(ctx, defaultGuestDialTimeout)
		defer done()
		return guestPortReady(attemptCtx, host, port)
	}

	var lastErr error
	for {
		isReady, probeErr := probe()
		if probeErr == nil && isReady {
			return nil
		}
		lastErr = probeErr

		select {
		case <-poll.C:
			// Next tick: probe again.
		case <-ctx.Done():
			return fmt.Errorf("wait for guest port %q on %s: %w (last_err=%v)", port.Name, address, ctx.Err(), lastErr)
		}
	}
}
func guestPortsReady(ctx context.Context, host string, ports []contracthost.MachinePort) (bool, error) {
host = strings.TrimSpace(host)
if host == "" {

View file

@ -14,10 +14,10 @@ import (
"testing"
"time"
contracthost "github.com/getcompanion-ai/computer-host/contract"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
hoststore "github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
type blockingPublishedPortStore struct {
@ -408,7 +408,7 @@ func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
}
}
func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *testing.T) {
func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
@ -421,15 +421,6 @@ func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *tes
extraMachines: exhaustedMachineRelayRecords(),
}
sshListener := listenTestPort(t, int(defaultSSHPort))
defer func() {
_ = sshListener.Close()
}()
vncListener := listenTestPort(t, int(defaultVNCPort))
defer func() {
_ = vncListener.Close()
}()
startedAt := time.Unix(1700000300, 0).UTC()
runtime := &fakeRuntime{
bootState: firecracker.MachineState{
@ -448,7 +439,10 @@ func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *tes
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil }
hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) error {
t.Fatalf("restore snapshot should not synchronously reconfigure guest identity, host=%q machine=%q guest_config=%#v", host, machineID, guestConfig)
return nil
}
artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
kernelPath := filepath.Join(root, "artifact-kernel")
@ -509,7 +503,7 @@ func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *tes
})
defer server.Close()
_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-exhausted", contracthost.RestoreSnapshotRequest{
response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap-exhausted", contracthost.RestoreSnapshotRequest{
MachineID: "restored-exhausted",
Artifact: contracthost.ArtifactRef{
KernelImageURL: server.URL + "/kernel",
@ -526,18 +520,20 @@ func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *tes
},
},
})
if err == nil || !strings.Contains(err.Error(), "allocate relay ports for restored machine") {
t.Fatalf("RestoreSnapshot error = %v, want relay allocation failure", err)
if err != nil {
t.Fatalf("RestoreSnapshot returned error: %v", err)
}
if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); !errors.Is(err, hoststore.ErrNotFound) {
t.Fatalf("restored system volume record should be deleted, get err = %v", err)
if response.Machine.Phase != contracthost.MachinePhaseStarting {
t.Fatalf("restored machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting)
}
if _, err := os.Stat(hostDaemon.systemVolumePath("restored-exhausted")); !os.IsNotExist(err) {
t.Fatalf("restored system disk should be removed, stat err = %v", err)
if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); err != nil {
t.Fatalf("restored system volume record should exist: %v", err)
}
if len(runtime.deleteCalls) != 1 {
t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls))
if _, err := os.Stat(hostDaemon.systemVolumePath("restored-exhausted")); err != nil {
t.Fatalf("restored system disk should exist: %v", err)
}
if len(runtime.deleteCalls) != 0 {
t.Fatalf("runtime delete calls = %d, want 0", len(runtime.deleteCalls))
}
assertOperationCount(t, baseStore, 0)
}

View file

@ -12,10 +12,10 @@ import (
"strings"
"time"
contracthost "github.com/getcompanion-ai/computer-host/contract"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error) {
@ -332,6 +332,9 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
restoredDrivePaths[driveID] = volumePath
}
// Do not force vsock_override on restore: Firecracker rejects it for old
// snapshots without a vsock device, and the jailed /run path already
// relocates safely for snapshots created with the new vsock-backed guest.
loadSpec := firecracker.SnapshotLoadSpec{
ID: firecracker.MachineID(req.MachineID),
SnapshotPath: vmstateArtifact.LocalPath,
@ -349,27 +352,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
return nil, fmt.Errorf("restore boot: %w", err)
}
// Wait for guest to become ready
if err := waitForGuestReady(ctx, machineState.RuntimeHost, defaultMachinePorts()); err != nil {
_ = d.runtime.Delete(ctx, *machineState)
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("wait for restored guest ready: %w", err)
}
if err := d.reconfigureGuestIdentity(ctx, machineState.RuntimeHost, req.MachineID, guestConfig); err != nil {
_ = d.runtime.Delete(ctx, *machineState)
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("reconfigure restored guest identity: %w", err)
}
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, machineState.RuntimeHost)
if err != nil {
_ = d.runtime.Delete(ctx, *machineState)
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("read restored guest ssh host key: %w", err)
}
systemVolumeID := d.systemVolumeID(req.MachineID)
now := time.Now().UTC()
@ -419,38 +401,13 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
RuntimeHost: machineState.RuntimeHost,
TapDevice: machineState.TapName,
Ports: defaultMachinePorts(),
GuestSSHPublicKey: guestSSHPublicKey,
Phase: contracthost.MachinePhaseRunning,
GuestSSHPublicKey: "",
Phase: contracthost.MachinePhaseStarting,
PID: machineState.PID,
SocketPath: machineState.SocketPath,
CreatedAt: now,
StartedAt: machineState.StartedAt,
}
d.relayAllocMu.Lock()
sshRelayPort, err := d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameSSH, machineRecord.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort)
var vncRelayPort uint16
if err == nil {
vncRelayPort, err = d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameVNC, machineRecord.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort)
}
d.relayAllocMu.Unlock()
if err != nil {
d.stopMachineRelays(machineRecord.ID)
for _, restoredVolumeID := range restoredUserVolumeIDs {
_ = d.store.DeleteVolume(context.Background(), restoredVolumeID)
}
_ = d.store.DeleteVolume(context.Background(), systemVolumeID)
_ = d.runtime.Delete(ctx, *machineState)
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("allocate relay ports for restored machine: %w", err)
}
machineRecord.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
startedRelays := true
defer func() {
if startedRelays {
d.stopMachineRelays(machineRecord.ID)
}
}()
if err := d.store.CreateMachine(ctx, machineRecord); err != nil {
for _, restoredVolumeID := range restoredUserVolumeIDs {
_ = d.store.DeleteVolume(context.Background(), restoredVolumeID)
@ -462,7 +419,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
return nil, err
}
startedRelays = false
clearOperation = true
return &contracthost.RestoreSnapshotResponse{
Machine: machineToContract(machineRecord),

View file

@ -38,6 +38,7 @@ func TestPutSnapshotLoadIncludesNetworkOverrides(t *testing.T) {
HostDevName: "fctap7",
},
},
VsockOverride: &VsockOverride{UDSPath: "/run/microagent-personalizer.vsock"},
})
if err != nil {
t.Fatalf("put snapshot load: %v", err)
@ -47,7 +48,7 @@ func TestPutSnapshotLoadIncludesNetworkOverrides(t *testing.T) {
t.Fatalf("request path mismatch: got %q want %q", gotPath, "/snapshot/load")
}
want := "{\"snapshot_path\":\"vmstate.bin\",\"mem_backend\":{\"backend_type\":\"File\",\"backend_path\":\"memory.bin\"},\"resume_vm\":false,\"network_overrides\":[{\"iface_id\":\"net0\",\"host_dev_name\":\"fctap7\"}]}"
want := "{\"snapshot_path\":\"vmstate.bin\",\"mem_backend\":{\"backend_type\":\"File\",\"backend_path\":\"memory.bin\"},\"resume_vm\":false,\"network_overrides\":[{\"iface_id\":\"net0\",\"host_dev_name\":\"fctap7\"}],\"vsock_override\":{\"uds_path\":\"/run/microagent-personalizer.vsock\"}}"
if gotBody != want {
t.Fatalf("request body mismatch:\n got: %s\nwant: %s", gotBody, want)
}

View file

@ -168,6 +168,74 @@ func TestConfigureMachineConfiguresMMDSBeforeStart(t *testing.T) {
}
}
// TestConfigureMachineConfiguresVsockBeforeStart runs configureMachine against
// a fake API server on a Unix socket and checks that, for a spec carrying a
// vsock device, the /vsock request is issued after /serial and before
// /actions, with the expected guest CID and socket path in its body.
func TestConfigureMachineConfiguresVsockBeforeStart(t *testing.T) {
	// Capture every request the client sends, in arrival order.
	var requests []capturedRequest
	socketPath, shutdown := startUnixSocketServer(t, func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Fatalf("read request body: %v", err)
		}
		requests = append(requests, capturedRequest{
			Method: r.Method,
			Path:   r.URL.Path,
			Body:   string(body),
		})
		w.WriteHeader(http.StatusNoContent)
	})
	defer shutdown()

	client := newAPIClient(socketPath)
	spec := MachineSpec{
		ID:              "vm-3",
		VCPUs:           1,
		MemoryMiB:       512,
		KernelImagePath: "/kernel",
		RootFSPath:      "/rootfs",
		Vsock: &VsockSpec{
			ID:   "microagent-personalizer",
			CID:  42,
			Path: "/run/microagent-personalizer.vsock",
		},
	}
	paths := machinePaths{JailedSerialLogPath: "/logs/serial.log"}
	network := NetworkAllocation{
		InterfaceID: defaultInterfaceID,
		TapName:     "fctap0",
		GuestMAC:    "06:00:ac:10:00:02",
	}

	if err := configureMachine(context.Background(), client, paths, spec, network); err != nil {
		t.Fatalf("configure machine: %v", err)
	}

	gotPaths := make([]string, 0, len(requests))
	for _, request := range requests {
		gotPaths = append(gotPaths, request.Path)
	}
	// Expected request order; "/vsock" sits at index 6, immediately before the
	// final "/actions" request.
	wantPaths := []string{
		"/machine-config",
		"/boot-source",
		"/drives/root_drive",
		"/network-interfaces/net0",
		"/entropy",
		"/serial",
		"/vsock",
		"/actions",
	}
	if len(gotPaths) != len(wantPaths) {
		t.Fatalf("request count mismatch: got %d want %d (%v)", len(gotPaths), len(wantPaths), gotPaths)
	}
	for i := range wantPaths {
		if gotPaths[i] != wantPaths[i] {
			t.Fatalf("request %d mismatch: got %q want %q", i, gotPaths[i], wantPaths[i])
		}
	}
	// requests[6] is the "/vsock" call verified by the path comparison above.
	if requests[6].Body != "{\"guest_cid\":42,\"uds_path\":\"/run/microagent-personalizer.vsock\"}" {
		t.Fatalf("vsock body mismatch: got %q", requests[6].Body)
	}
}
func startUnixSocketServer(t *testing.T, handler http.HandlerFunc) (string, func()) {
t.Helper()

View file

@ -132,7 +132,7 @@ func stageMachineFiles(spec MachineSpec, paths machinePaths) (MachineSpec, error
if spec.Vsock != nil {
vsock := *spec.Vsock
vsock.Path = jailedVSockPath(spec)
vsock.Path = jailedVSockDevicePath(*spec.Vsock)
staged.Vsock = &vsock
}
@ -244,11 +244,8 @@ func waitForPIDFile(ctx context.Context, pidFilePath string) (int, error) {
}
}
func jailedVSockPath(spec MachineSpec) string {
if spec.Vsock == nil {
return ""
}
return path.Join(defaultVSockRunDir, filepath.Base(strings.TrimSpace(spec.Vsock.Path)))
func jailedVSockDevicePath(spec VsockSpec) string {
return path.Join(defaultVSockRunDir, filepath.Base(strings.TrimSpace(spec.Path)))
}
func linkMachineFile(source string, target string) error {

View file

@ -331,6 +331,11 @@ func (r *Runtime) RestoreBoot(ctx context.Context, loadSpec SnapshotLoadSpec, us
}
}
var vsockOverride *VsockOverride
if loadSpec.Vsock != nil {
vsockOverride = &VsockOverride{UDSPath: jailedVSockDevicePath(*loadSpec.Vsock)}
}
// Load snapshot (replaces the full configure+start sequence)
if err := client.PutSnapshotLoad(ctx, SnapshotLoadParams{
SnapshotPath: chrootStatePath,
@ -345,6 +350,7 @@ func (r *Runtime) RestoreBoot(ctx context.Context, loadSpec SnapshotLoadSpec, us
HostDevName: network.TapName,
},
},
VsockOverride: vsockOverride,
}); err != nil {
cleanup(network, paths, command, firecrackerPID)
return nil, fmt.Errorf("load snapshot: %w", err)
@ -369,6 +375,11 @@ func (r *Runtime) RestoreBoot(ctx context.Context, loadSpec SnapshotLoadSpec, us
return &state, nil
}
// PutMMDS writes data to the MMDS of the machine described by state. It builds
// an API client for state.SocketPath and returns that client's PutMMDS result.
func (r *Runtime) PutMMDS(ctx context.Context, state MachineState, data any) error {
	return newAPIClient(state.SocketPath).PutMMDS(ctx, data)
}
func processExists(pid int) bool {
if pid < 1 {
return false