feat: simplify snapshot restore to disk boot

This commit is contained in:
Harivansh Rathi 2026-04-11 14:04:12 +00:00
parent 149bc2985a
commit 2ded10a67a
4 changed files with 113 additions and 283 deletions

View file

@ -3,7 +3,6 @@ package daemon
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"sort"
@ -13,7 +12,6 @@ import (
"golang.org/x/sync/errgroup"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/httpapi"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
@ -85,20 +83,6 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
return nil, fmt.Errorf("pause machine %q: %w", machineID, err)
}
// Write snapshot inside the chroot (Firecracker can only write there)
// Use jailed paths relative to the chroot root
chrootMemPath := "memory.bin"
chrootStatePath := "vmstate.bin"
if err := d.runtime.CreateSnapshot(ctx, runtimeState, firecracker.SnapshotPaths{
MemFilePath: chrootMemPath,
StateFilePath: chrootStatePath,
}); err != nil {
_ = d.runtime.Resume(ctx, runtimeState)
_ = os.RemoveAll(snapshotDir)
return nil, fmt.Errorf("create snapshot for %q: %w", machineID, err)
}
// COW-copy disk files while paused for consistency
var diskPaths []string
systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID)
@ -137,24 +121,7 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
return nil, fmt.Errorf("resume machine %q: %w", machineID, err)
}
// Copy snapshot files from chroot to snapshot directory, then remove originals.
// os.Rename fails across filesystem boundaries (/proc/<pid>/root/ is on procfs).
chrootRoot := filepath.Dir(filepath.Dir(runtimeState.SocketPath)) // strip /run/firecracker.socket
srcMemPath := filepath.Join(chrootRoot, chrootMemPath)
srcStatePath := filepath.Join(chrootRoot, chrootStatePath)
dstMemPath := filepath.Join(snapshotDir, "memory.bin")
dstStatePath := filepath.Join(snapshotDir, "vmstate.bin")
if err := moveFile(srcMemPath, dstMemPath); err != nil {
_ = os.RemoveAll(snapshotDir)
return nil, fmt.Errorf("move memory file: %w", err)
}
if err := moveFile(srcStatePath, dstStatePath); err != nil {
_ = os.RemoveAll(snapshotDir)
return nil, fmt.Errorf("move vmstate file: %w", err)
}
artifacts, err := buildSnapshotArtifacts(dstMemPath, dstStatePath, diskPaths)
artifacts, err := buildSnapshotArtifacts(diskPaths)
if err != nil {
_ = os.RemoveAll(snapshotDir)
return nil, fmt.Errorf("build snapshot artifacts: %w", err)
@ -162,16 +129,12 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
now := time.Now().UTC()
snapshotRecord := model.SnapshotRecord{
ID: snapshotID,
MachineID: machineID,
Artifact: record.Artifact,
MemFilePath: dstMemPath,
StateFilePath: dstStatePath,
DiskPaths: diskPaths,
Artifacts: artifacts,
SourceRuntimeHost: record.RuntimeHost,
SourceTapDevice: record.TapDevice,
CreatedAt: now,
ID: snapshotID,
MachineID: machineID,
Artifact: record.Artifact,
DiskPaths: diskPaths,
Artifacts: artifacts,
CreatedAt: now,
}
if err := d.store.CreateSnapshot(ctx, snapshotRecord); err != nil {
_ = os.RemoveAll(snapshotDir)
@ -276,11 +239,7 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
}
}()
usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID)
if err != nil {
return nil, err
}
restoredArtifacts, restoreNetwork, cleanupRestoreArtifacts, err := d.prepareRestoreArtifacts(ctx, snapshotID, req, usedNetworks)
restoredArtifacts, cleanupRestoreArtifacts, err := d.prepareRestoreArtifacts(ctx, snapshotID, req)
if err != nil {
clearOperation = true
return nil, err
@ -302,16 +261,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
clearOperation = true
return nil, fmt.Errorf("snapshot %q is missing system disk artifact", snapshotID)
}
memoryArtifact, ok := restoredArtifacts["memory.bin"]
if !ok {
clearOperation = true
return nil, fmt.Errorf("snapshot %q is missing memory artifact", snapshotID)
}
vmstateArtifact, ok := restoredArtifacts["vmstate.bin"]
if !ok {
clearOperation = true
return nil, fmt.Errorf("snapshot %q is missing vmstate artifact", snapshotID)
}
if err := cloneDiskFile(systemDiskPath.LocalPath, newSystemDiskPath, d.config.DiskCloneMode); err != nil {
clearOperation = true
return nil, fmt.Errorf("copy system disk for restore: %w", err)
@ -341,24 +290,31 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
restoredDrivePaths[driveID] = volumePath
}
// Do not force vsock_override on restore: Firecracker rejects it for old
// snapshots without a vsock device, and the jailed /run path already
// relocates safely for snapshots created with the new vsock-backed guest.
loadSpec := firecracker.SnapshotLoadSpec{
ID: firecracker.MachineID(req.MachineID),
SnapshotPath: vmstateArtifact.LocalPath,
MemFilePath: memoryArtifact.LocalPath,
RootFSPath: newSystemDiskPath,
KernelImagePath: artifact.KernelImagePath,
DiskPaths: restoredDrivePaths,
Network: &restoreNetwork,
userVolumes := make([]model.VolumeRecord, 0, len(restoredUserVolumes))
for _, volume := range restoredUserVolumes {
userVolumes = append(userVolumes, model.VolumeRecord{
ID: volume.ID,
Kind: contracthost.VolumeKindUser,
Path: volume.Path,
})
}
machineState, err := d.runtime.RestoreBoot(ctx, loadSpec, usedNetworks)
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig)
if err != nil {
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("restore boot: %w", err)
return nil, fmt.Errorf("build machine spec for restore: %w", err)
}
usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID)
if err != nil {
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, err
}
machineState, err := d.runtime.Boot(ctx, spec, usedNetworks)
if err != nil {
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("boot restored machine: %w", err)
}
systemVolumeID := d.systemVolumeID(req.MachineID)
@ -556,86 +512,39 @@ func restoredUserDiskIndex(name string) (int, bool) {
return index, true
}
func (d *Daemon) prepareRestoreArtifacts(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.RestoreSnapshotRequest, usedNetworks []firecracker.NetworkAllocation) (map[string]restoredSnapshotArtifact, firecracker.NetworkAllocation, func(), error) {
func (d *Daemon) prepareRestoreArtifacts(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.RestoreSnapshotRequest) (map[string]restoredSnapshotArtifact, func(), error) {
if req.LocalSnapshot != nil {
if req.LocalSnapshot.SnapshotID != "" && req.LocalSnapshot.SnapshotID != snapshotID {
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("local snapshot id mismatch: path=%q payload=%q", snapshotID, req.LocalSnapshot.SnapshotID)
return nil, func() {}, fmt.Errorf("local snapshot id mismatch: path=%q payload=%q", snapshotID, req.LocalSnapshot.SnapshotID)
}
snapshot, err := d.store.GetSnapshot(ctx, snapshotID)
if err != nil {
if err == store.ErrNotFound {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, "snapshot is not present on this host")
return nil, func() {}, localSnapshotRestoreUnavailable(snapshotID, "snapshot is not present on this host")
}
return nil, firecracker.NetworkAllocation{}, func() {}, err
}
restoreNetwork, err := restoreNetworkFromSnapshot(snapshot)
if err != nil {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error())
}
if networkAllocationInUse(restoreNetwork, usedNetworks) {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, fmt.Sprintf("restore network is still in use on this host (runtime_host=%s tap_device=%s)", restoreNetwork.GuestIP(), restoreNetwork.TapName))
return nil, func() {}, err
}
artifacts, err := localSnapshotArtifacts(snapshot)
if err != nil {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error())
return nil, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error())
}
return artifacts, restoreNetwork, func() {}, nil
return artifacts, func() {}, nil
}
if req.Snapshot == nil {
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("durable snapshot spec is required")
}
restoreNetwork, err := restoreNetworkFromDurableSpec(*req.Snapshot)
if err != nil {
snapshot, lookupErr := d.store.GetSnapshot(ctx, snapshotID)
if lookupErr == nil {
restoreNetwork, err = restoreNetworkFromSnapshot(snapshot)
} else if lookupErr != store.ErrNotFound {
return nil, firecracker.NetworkAllocation{}, func() {}, lookupErr
}
if err != nil {
return nil, firecracker.NetworkAllocation{}, func() {}, err
}
}
if networkAllocationInUse(restoreNetwork, usedNetworks) {
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("restore network for snapshot %q is still in use on this host (runtime_host=%s tap_device=%s)", snapshotID, restoreNetwork.GuestIP(), restoreNetwork.TapName)
return nil, func() {}, fmt.Errorf("durable snapshot spec is required")
}
stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID))
artifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts)
if err != nil {
_ = os.RemoveAll(stagingDir)
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("download durable snapshot artifacts: %w", err)
return nil, func() {}, fmt.Errorf("download durable snapshot artifacts: %w", err)
}
return artifacts, restoreNetwork, func() {
return artifacts, func() {
_ = os.RemoveAll(stagingDir)
}, nil
}
// restoreNetworkFromDurableSpec rebuilds the network allocation recorded in a
// durable snapshot spec.
//
// Both spec.SourceRuntimeHost and spec.SourceTapDevice must be non-blank after
// trimming whitespace; otherwise an error is returned. The allocation itself is
// produced by firecracker.AllocationFromGuestIP, and any error from that call
// is wrapped together with spec.SnapshotID. On every error path the zero
// NetworkAllocation is returned.
func restoreNetworkFromDurableSpec(spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) {
	var none firecracker.NetworkAllocation

	hostMissing := strings.TrimSpace(spec.SourceRuntimeHost) == ""
	tapMissing := strings.TrimSpace(spec.SourceTapDevice) == ""
	if hostMissing || tapMissing {
		return none, fmt.Errorf("durable snapshot spec is missing restore network metadata")
	}

	allocation, err := firecracker.AllocationFromGuestIP(spec.SourceRuntimeHost, spec.SourceTapDevice)
	if err != nil {
		return none, fmt.Errorf("reconstruct durable snapshot %q network: %w", spec.SnapshotID, err)
	}
	return allocation, nil
}
// restoreNetworkFromSnapshot rebuilds the network allocation recorded on a
// stored snapshot record.
//
// It fails when snap is nil, or when either snap.SourceRuntimeHost or
// snap.SourceTapDevice is blank after trimming whitespace. Otherwise the
// allocation comes from firecracker.AllocationFromGuestIP, whose error (if any)
// is wrapped together with snap.ID. On every error path the zero
// NetworkAllocation is returned.
func restoreNetworkFromSnapshot(snap *model.SnapshotRecord) (firecracker.NetworkAllocation, error) {
	var none firecracker.NetworkAllocation

	// Cases are evaluated top to bottom, so snap is known to be non-nil before
	// its fields are read.
	switch {
	case snap == nil:
		return none, fmt.Errorf("snapshot is required")
	case strings.TrimSpace(snap.SourceRuntimeHost) == "", strings.TrimSpace(snap.SourceTapDevice) == "":
		return none, fmt.Errorf("snapshot %q is missing restore network metadata", snap.ID)
	}

	allocation, err := firecracker.AllocationFromGuestIP(snap.SourceRuntimeHost, snap.SourceTapDevice)
	if err != nil {
		return none, fmt.Errorf("reconstruct snapshot %q network: %w", snap.ID, err)
	}
	return allocation, nil
}
func localSnapshotArtifacts(snapshot *model.SnapshotRecord) (map[string]restoredSnapshotArtifact, error) {
if snapshot == nil {
return nil, fmt.Errorf("snapshot is required")
@ -662,19 +571,6 @@ func localSnapshotArtifacts(snapshot *model.SnapshotRecord) (map[string]restored
return restored, nil
}
// networkAllocationInUse reports whether target collides with any allocation
// in used.
//
// Two allocations collide when their GuestIP() values are equal, or when
// target has a non-blank tap name (after trimming whitespace) that equals the
// trimmed tap name of the other allocation. An empty used slice yields false.
func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecracker.NetworkAllocation) bool {
	wantTap := strings.TrimSpace(target.TapName)
	for i := range used {
		candidate := used[i]
		switch {
		case candidate.GuestIP() == target.GuestIP():
			return true
		case wantTap != "" && strings.TrimSpace(candidate.TapName) == wantTap:
			// A blank target tap name never matches, even against another
			// blank tap name.
			return true
		}
	}
	return false
}
func localSnapshotRestoreUnavailable(snapshotID contracthost.SnapshotID, message string) error {
message = strings.TrimSpace(message)
if message == "" {
@ -682,31 +578,3 @@ func localSnapshotRestoreUnavailable(snapshotID contracthost.SnapshotID, message
}
return fmt.Errorf("%s: snapshot %q %s", localSnapshotRestoreUnavailablePrefix, snapshotID, message)
}
// moveFile copies src to dst and then removes src. Unlike os.Rename it works
// across filesystem boundaries, which is needed when moving files out of
// /proc/<pid>/root/.
//
// If dst already names the same file as src the call is a no-op that returns
// nil: creating dst would otherwise truncate src before a single byte was
// copied, and the final remove would then delete the only copy.
//
// On a copy or close failure the partially written dst is removed and src is
// left in place. An error from removing src is returned unchanged; in that
// case dst already holds the complete copy.
func moveFile(src, dst string) error {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer func() {
		// Read-only handle: a close error cannot lose data, so it is ignored.
		_ = in.Close()
	}()

	// Detect src and dst being the same file before os.Create truncates it.
	// A failing stat on dst (typically "does not exist") means there is
	// nothing to collide with, so the move proceeds normally.
	srcInfo, err := in.Stat()
	if err != nil {
		return err
	}
	if dstInfo, statErr := os.Stat(dst); statErr == nil && os.SameFile(srcInfo, dstInfo) {
		return nil
	}

	out, err := os.Create(dst)
	if err != nil {
		return err
	}
	if _, err := io.Copy(out, in); err != nil {
		_ = out.Close()
		_ = os.Remove(dst) // best-effort cleanup of the partial copy
		return err
	}
	// Close errors on the destination can signal lost writes, so they are
	// treated as a failed copy.
	if err := out.Close(); err != nil {
		_ = os.Remove(dst)
		return err
	}
	return os.Remove(src)
}