From 2ded10a67adfadc4e568c30e4c234971cb6c8873 Mon Sep 17 00:00:00 2001 From: Harivansh Rathi Date: Sat, 11 Apr 2026 14:04:12 +0000 Subject: [PATCH] feat: simplify snapshot restore to disk boot --- internal/daemon/daemon_test.go | 136 +++++++------- internal/daemon/review_regressions_test.go | 35 +--- internal/daemon/snapshot.go | 208 ++++----------------- internal/daemon/snapshot_transfer.go | 17 +- 4 files changed, 113 insertions(+), 283 deletions(-) diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index a510bc7..3e2ce6e 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -434,11 +434,9 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) { hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil } server := newRestoreArtifactServer(t, map[string][]byte{ - "/kernel": []byte("kernel"), - "/rootfs": []byte("rootfs"), - "/memory": []byte("mem"), - "/vmstate": []byte("state"), - "/system": []byte("disk"), + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + "/system": []byte("disk"), }) defer server.Close() @@ -465,8 +463,6 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) { ID: "snap1", MachineID: "source", Artifact: artifactRef, - MemFilePath: filepath.Join(root, "snapshots", "snap1", "memory.bin"), - StateFilePath: filepath.Join(root, "snapshots", "snap1", "vmstate.bin"), DiskPaths: []string{filepath.Join(root, "snapshots", "snap1", "system.img")}, SourceRuntimeHost: "172.16.0.2", SourceTapDevice: "fctap0", @@ -486,8 +482,6 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) { MachineID: "source", ImageID: "image-1", Artifacts: []contracthost.SnapshotArtifact{ - {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"}, - {ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"}, {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"}, }, }, @@ -502,17 +496,11 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) { if response.Machine.Phase != contracthost.MachinePhaseStarting { t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase) } - if runtime.restoreCalls != 1 { - t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls) + if runtime.bootCalls != 1 { + t.Fatalf("boot call count mismatch: got %d want 1", runtime.bootCalls) } - if runtime.lastLoadSpec.Network == nil { - t.Fatalf("restore boot should preserve snapshot network") - } - if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" { - t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2") - } - if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" { - t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0") + if runtime.restoreCalls != 0 { + t.Fatalf("restore boot call count mismatch: got %d want 0", runtime.restoreCalls) } ops, err := fileStore.ListOperations(context.Background()) @@ -593,28 +581,16 @@ func TestRestoreSnapshotUsesLocalSnapshotArtifacts(t *testing.T) { if err := os.MkdirAll(snapshotDir, 0o755); err != nil { t.Fatalf("create snapshot dir: %v", err) } - memoryPath := filepath.Join(snapshotDir, "memory.bin") - vmstatePath := filepath.Join(snapshotDir, "vmstate.bin") systemPath := filepath.Join(snapshotDir, "system.img") - if err := os.WriteFile(memoryPath, []byte("mem"), 0o644); err != nil { - t.Fatalf("write memory: %v", err) - } - if err := os.WriteFile(vmstatePath, []byte("state"), 0o644); err != nil { - t.Fatalf("write vmstate: %v", err) - } if err := os.WriteFile(systemPath, []byte("disk"), 0o644); err != nil { t.Fatalf("write system disk: %v", err) } if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ - ID: "snap-local", - MachineID: "source", - Artifact: artifactRef, - MemFilePath: memoryPath, - StateFilePath: vmstatePath, - DiskPaths: []string{systemPath}, + ID: "snap-local", + MachineID: "source", + Artifact: artifactRef, + DiskPaths: []string{systemPath}, Artifacts: []model.SnapshotArtifactRecord{ - {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", LocalPath: memoryPath, SizeBytes: 3}, - {ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", LocalPath: vmstatePath, SizeBytes: 5}, {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", LocalPath: systemPath, SizeBytes: 4}, }, SourceRuntimeHost: "172.16.0.2", @@ -638,17 +614,11 @@ func TestRestoreSnapshotUsesLocalSnapshotArtifacts(t *testing.T) { if response.Machine.ID != "restored-local" { t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID) } - if runtime.restoreCalls != 1 { - t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls) + if runtime.bootCalls != 1 { + t.Fatalf("boot call count mismatch: got %d want 1", runtime.bootCalls) } - if runtime.lastLoadSpec.Network == nil { - t.Fatalf("restore boot should preserve local snapshot network") - } - if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" { - t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2") - } - if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" { - t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0") + if runtime.restoreCalls != 0 { + t.Fatalf("restore boot call count mismatch: got %d want 0", runtime.restoreCalls) } } @@ -741,12 +711,10 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { } server := newRestoreArtifactServer(t, map[string][]byte{ - "/kernel": []byte("kernel"), - "/rootfs": []byte("rootfs"), - "/memory": []byte("mem"), - "/vmstate": []byte("state"), - "/system": []byte("disk"), - "/user-0": []byte("user-disk"), + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + "/system": []byte("disk"), + "/user-0": []byte("user-disk"), }) defer server.Close() @@ -763,8 +731,6 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { SourceRuntimeHost: "172.16.0.2", SourceTapDevice: "fctap0", Artifacts: []contracthost.SnapshotArtifact{ - {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"}, - {ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"}, {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"}, {ID: "disk-user-0", Kind: contracthost.SnapshotArtifactKindDisk, Name: "user-0.img", DownloadURL: server.URL + "/user-0"}, }, @@ -780,23 +746,17 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { if response.Machine.Phase != contracthost.MachinePhaseStarting { t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase) } - if runtime.restoreCalls != 1 { - t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls) + if runtime.bootCalls != 1 { + t.Fatalf("boot call count mismatch: got %d want 1", runtime.bootCalls) } - if runtime.lastLoadSpec.Network == nil { - t.Fatalf("restore boot should preserve durable snapshot network") + if runtime.restoreCalls != 0 { + t.Fatalf("restore boot call count mismatch: got %d want 0", runtime.restoreCalls) } - if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" { - t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2") - } - if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" { - t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0") - } - if !strings.Contains(runtime.lastLoadSpec.KernelImagePath, filepath.Join("artifacts", artifactKey(contracthost.ArtifactRef{ + if !strings.Contains(runtime.lastSpec.KernelImagePath, filepath.Join("artifacts", artifactKey(contracthost.ArtifactRef{ KernelImageURL: server.URL + "/kernel", RootFSURL: server.URL + "/rootfs", }), "kernel")) { - t.Fatalf("restore boot kernel path mismatch: got %q", runtime.lastLoadSpec.KernelImagePath) + t.Fatalf("restore boot kernel path mismatch: got %q", runtime.lastSpec.KernelImagePath) } machine, err := fileStore.GetMachine(context.Background(), "restored") @@ -825,7 +785,7 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { } } -func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) { +func TestRestoreSnapshotBootsWithFreshNetworkWhenSourceNetworkInUseOnHost(t *testing.T) { root := t.TempDir() cfg := testConfig(root) fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) @@ -838,6 +798,31 @@ func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) { if err != nil { t.Fatalf("create daemon: %v", err) } + stubGuestSSHPublicKeyReader(hostDaemon) + hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil } + + sshListener := listenTestPort(t, int(defaultSSHPort)) + defer func() { _ = sshListener.Close() }() + vncListener := listenTestPort(t, int(defaultVNCPort)) + defer func() { _ = vncListener.Close() }() + + startedAt := time.Unix(1700000299, 0).UTC() + runtime.bootState = firecracker.MachineState{ + ID: "restored", + Phase: firecracker.PhaseRunning, + PID: 1234, + RuntimeHost: "127.0.0.1", + SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "restored", "root", "run", "firecracker.sock"), + TapName: "fctap9", + StartedAt: &startedAt, + } + + server := newRestoreArtifactServer(t, map[string][]byte{ + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + "/system": []byte("disk"), + }) + defer server.Close() if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{ ID: "source", @@ -851,11 +836,11 @@ func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) { t.Fatalf("create running source machine: %v", err) } - _, err = hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{ + response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{ MachineID: "restored", Artifact: contracthost.ArtifactRef{ - KernelImageURL: "https://example.com/kernel", - RootFSURL: "https://example.com/rootfs", + KernelImageURL: server.URL + "/kernel", + RootFSURL: server.URL + "/rootfs", }, Snapshot: &contracthost.DurableSnapshotSpec{ SnapshotID: "snap1", @@ -863,10 +848,19 @@ func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) { ImageID: "image-1", SourceRuntimeHost: "172.16.0.2", SourceTapDevice: "fctap0", + Artifacts: []contracthost.SnapshotArtifact{ + {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"}, + }, }, }) - if err == nil || !strings.Contains(err.Error(), "still in use on this host") { - t.Fatalf("restore snapshot error = %v, want restore network in-use failure", err) + if err != nil { + t.Fatalf("restore snapshot error = %v, want success", err) + } + if response.Machine.Phase != contracthost.MachinePhaseStarting { + t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase) + } + if runtime.bootCalls != 1 { + t.Fatalf("boot call count mismatch: got %d want 1", runtime.bootCalls) } if runtime.restoreCalls != 0 { t.Fatalf("restore boot should not be attempted, got %d calls", runtime.restoreCalls) diff --git a/internal/daemon/review_regressions_test.go b/internal/daemon/review_regressions_test.go index 5322097..af9edd2 100644 --- a/internal/daemon/review_regressions_test.go +++ b/internal/daemon/review_regressions_test.go @@ -472,20 +472,10 @@ func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil { t.Fatalf("write snapshot disk: %v", err) } - memPath := filepath.Join(snapDir, "memory.bin") - if err := os.WriteFile(memPath, []byte("mem"), 0o644); err != nil { - t.Fatalf("write memory snapshot: %v", err) - } - statePath := filepath.Join(snapDir, "vmstate.bin") - if err := os.WriteFile(statePath, []byte("state"), 0o644); err != nil { - t.Fatalf("write vmstate snapshot: %v", err) - } if err := baseStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ ID: "snap-exhausted", MachineID: "source", Artifact: artifactRef, - MemFilePath: memPath, - StateFilePath: statePath, DiskPaths: []string{snapDisk}, SourceRuntimeHost: "172.16.0.2", SourceTapDevice: "fctap0", @@ -495,11 +485,9 @@ func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T } server := newRestoreArtifactServer(t, map[string][]byte{ - "/kernel": []byte("kernel"), - "/rootfs": []byte("rootfs"), - "/memory": []byte("mem"), - "/vmstate": []byte("state"), - "/system": []byte("disk"), + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + "/system": []byte("disk"), }) defer server.Close() @@ -514,8 +502,6 @@ func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T MachineID: "source", ImageID: "image-1", Artifacts: []contracthost.SnapshotArtifact{ - {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"}, - {ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"}, {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"}, }, }, @@ -694,11 +680,9 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterSuccess(t *testing.T) { hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil } server := newRestoreArtifactServer(t, map[string][]byte{ - "/kernel": []byte("kernel"), - "/rootfs": []byte("rootfs"), - "/memory": []byte("mem"), - "/vmstate": []byte("state"), - "/system": []byte("disk"), + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + "/system": []byte("disk"), }) defer server.Close() @@ -715,8 +699,6 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterSuccess(t *testing.T) { SourceRuntimeHost: "172.16.0.2", SourceTapDevice: "fctap0", Artifacts: []contracthost.SnapshotArtifact{ - {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory", SHA256Hex: mustSHA256Hex(t, []byte("mem"))}, - {ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate", SHA256Hex: mustSHA256Hex(t, []byte("state"))}, {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system", SHA256Hex: mustSHA256Hex(t, []byte("disk"))}, }, }, @@ -749,7 +731,7 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure(t *testing.T) server := newRestoreArtifactServer(t, map[string][]byte{ "/kernel": []byte("kernel"), "/rootfs": []byte("rootfs"), - "/memory": []byte("mem"), + "/system": []byte("disk"), }) defer server.Close() @@ -766,8 +748,7 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure(t *testing.T) SourceRuntimeHost: "172.16.0.2", SourceTapDevice: "fctap0", Artifacts: []contracthost.SnapshotArtifact{ - {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory", SHA256Hex: mustSHA256Hex(t, []byte("mem"))}, - {ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/missing"}, + {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/missing"}, }, }, }) diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go index ae72130..4f04616 100644 --- a/internal/daemon/snapshot.go +++ b/internal/daemon/snapshot.go @@ -3,7 +3,6 @@ package daemon import ( "context" "fmt" - "io" "os" "path/filepath" "sort" @@ -13,7 +12,6 @@ import ( "golang.org/x/sync/errgroup" - "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/httpapi" "github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/store" @@ -85,20 +83,6 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach return nil, fmt.Errorf("pause machine %q: %w", machineID, err) } - // Write snapshot inside the chroot (Firecracker can only write there) - // Use jailed paths relative to the chroot root - chrootMemPath := "memory.bin" - chrootStatePath := "vmstate.bin" - - if err := d.runtime.CreateSnapshot(ctx, runtimeState, firecracker.SnapshotPaths{ - MemFilePath: chrootMemPath, - StateFilePath: chrootStatePath, - }); err != nil { - _ = d.runtime.Resume(ctx, runtimeState) - _ = os.RemoveAll(snapshotDir) - return nil, fmt.Errorf("create snapshot for %q: %w", machineID, err) - } - // COW-copy disk files while paused for consistency var diskPaths []string systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID) @@ -137,24 +121,7 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach return nil, fmt.Errorf("resume machine %q: %w", machineID, err) } - // Copy snapshot files from chroot to snapshot directory, then remove originals. - // os.Rename fails across filesystem boundaries (/proc//root/ is on procfs). - chrootRoot := filepath.Dir(filepath.Dir(runtimeState.SocketPath)) // strip /run/firecracker.socket - srcMemPath := filepath.Join(chrootRoot, chrootMemPath) - srcStatePath := filepath.Join(chrootRoot, chrootStatePath) - dstMemPath := filepath.Join(snapshotDir, "memory.bin") - dstStatePath := filepath.Join(snapshotDir, "vmstate.bin") - - if err := moveFile(srcMemPath, dstMemPath); err != nil { - _ = os.RemoveAll(snapshotDir) - return nil, fmt.Errorf("move memory file: %w", err) - } - if err := moveFile(srcStatePath, dstStatePath); err != nil { - _ = os.RemoveAll(snapshotDir) - return nil, fmt.Errorf("move vmstate file: %w", err) - } - - artifacts, err := buildSnapshotArtifacts(dstMemPath, dstStatePath, diskPaths) + artifacts, err := buildSnapshotArtifacts(diskPaths) if err != nil { _ = os.RemoveAll(snapshotDir) return nil, fmt.Errorf("build snapshot artifacts: %w", err) @@ -162,16 +129,12 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach now := time.Now().UTC() snapshotRecord := model.SnapshotRecord{ - ID: snapshotID, - MachineID: machineID, - Artifact: record.Artifact, - MemFilePath: dstMemPath, - StateFilePath: dstStatePath, - DiskPaths: diskPaths, - Artifacts: artifacts, - SourceRuntimeHost: record.RuntimeHost, - SourceTapDevice: record.TapDevice, - CreatedAt: now, + ID: snapshotID, + MachineID: machineID, + Artifact: record.Artifact, + DiskPaths: diskPaths, + Artifacts: artifacts, + CreatedAt: now, } if err := d.store.CreateSnapshot(ctx, snapshotRecord); err != nil { _ = os.RemoveAll(snapshotDir) @@ -276,11 +239,7 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn } }() - usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID) - if err != nil { - return nil, err - } - restoredArtifacts, restoreNetwork, cleanupRestoreArtifacts, err := d.prepareRestoreArtifacts(ctx, snapshotID, req, usedNetworks) + restoredArtifacts, cleanupRestoreArtifacts, err := d.prepareRestoreArtifacts(ctx, snapshotID, req) if err != nil { clearOperation = true return nil, err @@ -302,16 +261,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn clearOperation = true return nil, fmt.Errorf("snapshot %q is missing system disk artifact", snapshotID) } - memoryArtifact, ok := restoredArtifacts["memory.bin"] - if !ok { - clearOperation = true - return nil, fmt.Errorf("snapshot %q is missing memory artifact", snapshotID) - } - vmstateArtifact, ok := restoredArtifacts["vmstate.bin"] - if !ok { - clearOperation = true - return nil, fmt.Errorf("snapshot %q is missing vmstate artifact", snapshotID) - } if err := cloneDiskFile(systemDiskPath.LocalPath, newSystemDiskPath, d.config.DiskCloneMode); err != nil { clearOperation = true return nil, fmt.Errorf("copy system disk for restore: %w", err) @@ -341,24 +290,31 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn restoredDrivePaths[driveID] = volumePath } - // Do not force vsock_override on restore: Firecracker rejects it for old - // snapshots without a vsock device, and the jailed /run path already - // relocates safely for snapshots created with the new vsock-backed guest. - loadSpec := firecracker.SnapshotLoadSpec{ - ID: firecracker.MachineID(req.MachineID), - SnapshotPath: vmstateArtifact.LocalPath, - MemFilePath: memoryArtifact.LocalPath, - RootFSPath: newSystemDiskPath, - KernelImagePath: artifact.KernelImagePath, - DiskPaths: restoredDrivePaths, - Network: &restoreNetwork, + userVolumes := make([]model.VolumeRecord, 0, len(restoredUserVolumes)) + for _, volume := range restoredUserVolumes { + userVolumes = append(userVolumes, model.VolumeRecord{ + ID: volume.ID, + Kind: contracthost.VolumeKindUser, + Path: volume.Path, + }) } - - machineState, err := d.runtime.RestoreBoot(ctx, loadSpec, usedNetworks) + spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig) if err != nil { _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) clearOperation = true - return nil, fmt.Errorf("restore boot: %w", err) + return nil, fmt.Errorf("build machine spec for restore: %w", err) + } + usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID) + if err != nil { + _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) + clearOperation = true + return nil, err + } + machineState, err := d.runtime.Boot(ctx, spec, usedNetworks) + if err != nil { + _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) + clearOperation = true + return nil, fmt.Errorf("boot restored machine: %w", err) } systemVolumeID := d.systemVolumeID(req.MachineID) @@ -556,86 +512,39 @@ func restoredUserDiskIndex(name string) (int, bool) { return index, true } -func (d *Daemon) prepareRestoreArtifacts(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.RestoreSnapshotRequest, usedNetworks []firecracker.NetworkAllocation) (map[string]restoredSnapshotArtifact, firecracker.NetworkAllocation, func(), error) { +func (d *Daemon) prepareRestoreArtifacts(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.RestoreSnapshotRequest) (map[string]restoredSnapshotArtifact, func(), error) { if req.LocalSnapshot != nil { if req.LocalSnapshot.SnapshotID != "" && req.LocalSnapshot.SnapshotID != snapshotID { - return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("local snapshot id mismatch: path=%q payload=%q", snapshotID, req.LocalSnapshot.SnapshotID) + return nil, func() {}, fmt.Errorf("local snapshot id mismatch: path=%q payload=%q", snapshotID, req.LocalSnapshot.SnapshotID) } snapshot, err := d.store.GetSnapshot(ctx, snapshotID) if err != nil { if err == store.ErrNotFound { - return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, "snapshot is not present on this host") + return nil, func() {}, localSnapshotRestoreUnavailable(snapshotID, "snapshot is not present on this host") } - return nil, firecracker.NetworkAllocation{}, func() {}, err - } - restoreNetwork, err := restoreNetworkFromSnapshot(snapshot) - if err != nil { - return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error()) - } - if networkAllocationInUse(restoreNetwork, usedNetworks) { - return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, fmt.Sprintf("restore network is still in use on this host (runtime_host=%s tap_device=%s)", restoreNetwork.GuestIP(), restoreNetwork.TapName)) + return nil, func() {}, err } artifacts, err := localSnapshotArtifacts(snapshot) if err != nil { - return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error()) + return nil, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error()) } - return artifacts, restoreNetwork, func() {}, nil + return artifacts, func() {}, nil } if req.Snapshot == nil { - return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("durable snapshot spec is required") - } - restoreNetwork, err := restoreNetworkFromDurableSpec(*req.Snapshot) - if err != nil { - snapshot, lookupErr := d.store.GetSnapshot(ctx, snapshotID) - if lookupErr == nil { - restoreNetwork, err = restoreNetworkFromSnapshot(snapshot) - } else if lookupErr != store.ErrNotFound { - return nil, firecracker.NetworkAllocation{}, func() {}, lookupErr - } - if err != nil { - return nil, firecracker.NetworkAllocation{}, func() {}, err - } - } - if networkAllocationInUse(restoreNetwork, usedNetworks) { - return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("restore network for snapshot %q is still in use on this host (runtime_host=%s tap_device=%s)", snapshotID, restoreNetwork.GuestIP(), restoreNetwork.TapName) + return nil, func() {}, fmt.Errorf("durable snapshot spec is required") } stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID)) artifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts) if err != nil { _ = os.RemoveAll(stagingDir) - return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("download durable snapshot artifacts: %w", err) + return nil, func() {}, fmt.Errorf("download durable snapshot artifacts: %w", err) } - return artifacts, restoreNetwork, func() { + return artifacts, func() { _ = os.RemoveAll(stagingDir) }, nil } -func restoreNetworkFromDurableSpec(spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) { - if strings.TrimSpace(spec.SourceRuntimeHost) == "" || strings.TrimSpace(spec.SourceTapDevice) == "" { - return firecracker.NetworkAllocation{}, fmt.Errorf("durable snapshot spec is missing restore network metadata") - } - network, err := firecracker.AllocationFromGuestIP(spec.SourceRuntimeHost, spec.SourceTapDevice) - if err != nil { - return firecracker.NetworkAllocation{}, fmt.Errorf("reconstruct durable snapshot %q network: %w", spec.SnapshotID, err) - } - return network, nil -} - -func restoreNetworkFromSnapshot(snap *model.SnapshotRecord) (firecracker.NetworkAllocation, error) { - if snap == nil { - return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot is required") - } - if strings.TrimSpace(snap.SourceRuntimeHost) == "" || strings.TrimSpace(snap.SourceTapDevice) == "" { - return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot %q is missing restore network metadata", snap.ID) - } - network, err := firecracker.AllocationFromGuestIP(snap.SourceRuntimeHost, snap.SourceTapDevice) - if err != nil { - return firecracker.NetworkAllocation{}, fmt.Errorf("reconstruct snapshot %q network: %w", snap.ID, err) - } - return network, nil -} - func localSnapshotArtifacts(snapshot *model.SnapshotRecord) (map[string]restoredSnapshotArtifact, error) { if snapshot == nil { return nil, fmt.Errorf("snapshot is required") @@ -662,19 +571,6 @@ func localSnapshotArtifacts(snapshot *model.SnapshotRecord) (map[string]restored return restored, nil } -func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecracker.NetworkAllocation) bool { - targetTap := strings.TrimSpace(target.TapName) - for _, network := range used { - if network.GuestIP() == target.GuestIP() { - return true - } - if targetTap != "" && strings.TrimSpace(network.TapName) == targetTap { - return true - } - } - return false -} - func localSnapshotRestoreUnavailable(snapshotID contracthost.SnapshotID, message string) error { message = strings.TrimSpace(message) if message == "" { @@ -682,31 +578,3 @@ func localSnapshotRestoreUnavailable(snapshotID contracthost.SnapshotID, message } return fmt.Errorf("%s: snapshot %q %s", localSnapshotRestoreUnavailablePrefix, snapshotID, message) } - -// moveFile copies src to dst then removes src. Works across filesystem boundaries -// unlike os.Rename, which is needed when moving files out of /proc//root/. -func moveFile(src, dst string) error { - in, err := os.Open(src) - if err != nil { - return err - } - defer func() { - _ = in.Close() - }() - - out, err := os.Create(dst) - if err != nil { - return err - } - - if _, err := io.Copy(out, in); err != nil { - _ = out.Close() - _ = os.Remove(dst) - return err - } - if err := out.Close(); err != nil { - _ = os.Remove(dst) - return err - } - return os.Remove(src) -} diff --git a/internal/daemon/snapshot_transfer.go b/internal/daemon/snapshot_transfer.go index 6656b4d..a3c4a74 100644 --- a/internal/daemon/snapshot_transfer.go +++ b/internal/daemon/snapshot_transfer.go @@ -21,21 +21,8 @@ type restoredSnapshotArtifact struct { LocalPath string } -func buildSnapshotArtifacts(memoryPath, vmstatePath string, diskPaths []string) ([]model.SnapshotArtifactRecord, error) { - artifacts := make([]model.SnapshotArtifactRecord, 0, len(diskPaths)+2) - - memoryArtifact, err := snapshotArtifactRecord("memory", contracthost.SnapshotArtifactKindMemory, filepath.Base(memoryPath), memoryPath) - if err != nil { - return nil, err - } - artifacts = append(artifacts, memoryArtifact) - - vmstateArtifact, err := snapshotArtifactRecord("vmstate", contracthost.SnapshotArtifactKindVMState, filepath.Base(vmstatePath), vmstatePath) - if err != nil { - return nil, err - } - artifacts = append(artifacts, vmstateArtifact) - +func buildSnapshotArtifacts(diskPaths []string) ([]model.SnapshotArtifactRecord, error) { + artifacts := make([]model.SnapshotArtifactRecord, 0, len(diskPaths)) for _, diskPath := range diskPaths { base := filepath.Base(diskPath) diskArtifact, err := snapshotArtifactRecord("disk-"+strings.TrimSuffix(base, filepath.Ext(base)), contracthost.SnapshotArtifactKindDisk, base, diskPath)