feat: simplify snapshot restore to disk boot

This commit is contained in:
Harivansh Rathi 2026-04-11 14:04:12 +00:00
parent 149bc2985a
commit 2ded10a67a
4 changed files with 113 additions and 283 deletions

View file

@ -434,11 +434,9 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil } hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil }
server := newRestoreArtifactServer(t, map[string][]byte{ server := newRestoreArtifactServer(t, map[string][]byte{
"/kernel": []byte("kernel"), "/kernel": []byte("kernel"),
"/rootfs": []byte("rootfs"), "/rootfs": []byte("rootfs"),
"/memory": []byte("mem"), "/system": []byte("disk"),
"/vmstate": []byte("state"),
"/system": []byte("disk"),
}) })
defer server.Close() defer server.Close()
@ -465,8 +463,6 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
ID: "snap1", ID: "snap1",
MachineID: "source", MachineID: "source",
Artifact: artifactRef, Artifact: artifactRef,
MemFilePath: filepath.Join(root, "snapshots", "snap1", "memory.bin"),
StateFilePath: filepath.Join(root, "snapshots", "snap1", "vmstate.bin"),
DiskPaths: []string{filepath.Join(root, "snapshots", "snap1", "system.img")}, DiskPaths: []string{filepath.Join(root, "snapshots", "snap1", "system.img")},
SourceRuntimeHost: "172.16.0.2", SourceRuntimeHost: "172.16.0.2",
SourceTapDevice: "fctap0", SourceTapDevice: "fctap0",
@ -486,8 +482,6 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
MachineID: "source", MachineID: "source",
ImageID: "image-1", ImageID: "image-1",
Artifacts: []contracthost.SnapshotArtifact{ Artifacts: []contracthost.SnapshotArtifact{
{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"},
{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"},
{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"}, {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"},
}, },
}, },
@ -502,17 +496,11 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
if response.Machine.Phase != contracthost.MachinePhaseStarting { if response.Machine.Phase != contracthost.MachinePhaseStarting {
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase) t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
} }
if runtime.restoreCalls != 1 { if runtime.bootCalls != 1 {
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls) t.Fatalf("boot call count mismatch: got %d want 1", runtime.bootCalls)
} }
if runtime.lastLoadSpec.Network == nil { if runtime.restoreCalls != 0 {
t.Fatalf("restore boot should preserve snapshot network") t.Fatalf("restore boot call count mismatch: got %d want 0", runtime.restoreCalls)
}
if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" {
t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2")
}
if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" {
t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0")
} }
ops, err := fileStore.ListOperations(context.Background()) ops, err := fileStore.ListOperations(context.Background())
@ -593,28 +581,16 @@ func TestRestoreSnapshotUsesLocalSnapshotArtifacts(t *testing.T) {
if err := os.MkdirAll(snapshotDir, 0o755); err != nil { if err := os.MkdirAll(snapshotDir, 0o755); err != nil {
t.Fatalf("create snapshot dir: %v", err) t.Fatalf("create snapshot dir: %v", err)
} }
memoryPath := filepath.Join(snapshotDir, "memory.bin")
vmstatePath := filepath.Join(snapshotDir, "vmstate.bin")
systemPath := filepath.Join(snapshotDir, "system.img") systemPath := filepath.Join(snapshotDir, "system.img")
if err := os.WriteFile(memoryPath, []byte("mem"), 0o644); err != nil {
t.Fatalf("write memory: %v", err)
}
if err := os.WriteFile(vmstatePath, []byte("state"), 0o644); err != nil {
t.Fatalf("write vmstate: %v", err)
}
if err := os.WriteFile(systemPath, []byte("disk"), 0o644); err != nil { if err := os.WriteFile(systemPath, []byte("disk"), 0o644); err != nil {
t.Fatalf("write system disk: %v", err) t.Fatalf("write system disk: %v", err)
} }
if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{
ID: "snap-local", ID: "snap-local",
MachineID: "source", MachineID: "source",
Artifact: artifactRef, Artifact: artifactRef,
MemFilePath: memoryPath, DiskPaths: []string{systemPath},
StateFilePath: vmstatePath,
DiskPaths: []string{systemPath},
Artifacts: []model.SnapshotArtifactRecord{ Artifacts: []model.SnapshotArtifactRecord{
{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", LocalPath: memoryPath, SizeBytes: 3},
{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", LocalPath: vmstatePath, SizeBytes: 5},
{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", LocalPath: systemPath, SizeBytes: 4}, {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", LocalPath: systemPath, SizeBytes: 4},
}, },
SourceRuntimeHost: "172.16.0.2", SourceRuntimeHost: "172.16.0.2",
@ -638,17 +614,11 @@ func TestRestoreSnapshotUsesLocalSnapshotArtifacts(t *testing.T) {
if response.Machine.ID != "restored-local" { if response.Machine.ID != "restored-local" {
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID) t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
} }
if runtime.restoreCalls != 1 { if runtime.bootCalls != 1 {
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls) t.Fatalf("boot call count mismatch: got %d want 1", runtime.bootCalls)
} }
if runtime.lastLoadSpec.Network == nil { if runtime.restoreCalls != 0 {
t.Fatalf("restore boot should preserve local snapshot network") t.Fatalf("restore boot call count mismatch: got %d want 0", runtime.restoreCalls)
}
if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" {
t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2")
}
if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" {
t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0")
} }
} }
@ -741,12 +711,10 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
} }
server := newRestoreArtifactServer(t, map[string][]byte{ server := newRestoreArtifactServer(t, map[string][]byte{
"/kernel": []byte("kernel"), "/kernel": []byte("kernel"),
"/rootfs": []byte("rootfs"), "/rootfs": []byte("rootfs"),
"/memory": []byte("mem"), "/system": []byte("disk"),
"/vmstate": []byte("state"), "/user-0": []byte("user-disk"),
"/system": []byte("disk"),
"/user-0": []byte("user-disk"),
}) })
defer server.Close() defer server.Close()
@ -763,8 +731,6 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
SourceRuntimeHost: "172.16.0.2", SourceRuntimeHost: "172.16.0.2",
SourceTapDevice: "fctap0", SourceTapDevice: "fctap0",
Artifacts: []contracthost.SnapshotArtifact{ Artifacts: []contracthost.SnapshotArtifact{
{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"},
{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"},
{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"}, {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"},
{ID: "disk-user-0", Kind: contracthost.SnapshotArtifactKindDisk, Name: "user-0.img", DownloadURL: server.URL + "/user-0"}, {ID: "disk-user-0", Kind: contracthost.SnapshotArtifactKindDisk, Name: "user-0.img", DownloadURL: server.URL + "/user-0"},
}, },
@ -780,23 +746,17 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
if response.Machine.Phase != contracthost.MachinePhaseStarting { if response.Machine.Phase != contracthost.MachinePhaseStarting {
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase) t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
} }
if runtime.restoreCalls != 1 { if runtime.bootCalls != 1 {
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls) t.Fatalf("boot call count mismatch: got %d want 1", runtime.bootCalls)
} }
if runtime.lastLoadSpec.Network == nil { if runtime.restoreCalls != 0 {
t.Fatalf("restore boot should preserve durable snapshot network") t.Fatalf("restore boot call count mismatch: got %d want 0", runtime.restoreCalls)
} }
if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" { if !strings.Contains(runtime.lastSpec.KernelImagePath, filepath.Join("artifacts", artifactKey(contracthost.ArtifactRef{
t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2")
}
if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" {
t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0")
}
if !strings.Contains(runtime.lastLoadSpec.KernelImagePath, filepath.Join("artifacts", artifactKey(contracthost.ArtifactRef{
KernelImageURL: server.URL + "/kernel", KernelImageURL: server.URL + "/kernel",
RootFSURL: server.URL + "/rootfs", RootFSURL: server.URL + "/rootfs",
}), "kernel")) { }), "kernel")) {
t.Fatalf("restore boot kernel path mismatch: got %q", runtime.lastLoadSpec.KernelImagePath) t.Fatalf("restore boot kernel path mismatch: got %q", runtime.lastSpec.KernelImagePath)
} }
machine, err := fileStore.GetMachine(context.Background(), "restored") machine, err := fileStore.GetMachine(context.Background(), "restored")
@ -825,7 +785,7 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
} }
} }
func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) { func TestRestoreSnapshotBootsWithFreshNetworkWhenSourceNetworkInUseOnHost(t *testing.T) {
root := t.TempDir() root := t.TempDir()
cfg := testConfig(root) cfg := testConfig(root)
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
@ -838,6 +798,31 @@ func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("create daemon: %v", err) t.Fatalf("create daemon: %v", err)
} }
stubGuestSSHPublicKeyReader(hostDaemon)
hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil }
sshListener := listenTestPort(t, int(defaultSSHPort))
defer func() { _ = sshListener.Close() }()
vncListener := listenTestPort(t, int(defaultVNCPort))
defer func() { _ = vncListener.Close() }()
startedAt := time.Unix(1700000299, 0).UTC()
runtime.bootState = firecracker.MachineState{
ID: "restored",
Phase: firecracker.PhaseRunning,
PID: 1234,
RuntimeHost: "127.0.0.1",
SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "restored", "root", "run", "firecracker.sock"),
TapName: "fctap9",
StartedAt: &startedAt,
}
server := newRestoreArtifactServer(t, map[string][]byte{
"/kernel": []byte("kernel"),
"/rootfs": []byte("rootfs"),
"/system": []byte("disk"),
})
defer server.Close()
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{ if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
ID: "source", ID: "source",
@ -851,11 +836,11 @@ func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) {
t.Fatalf("create running source machine: %v", err) t.Fatalf("create running source machine: %v", err)
} }
_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{ response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{
MachineID: "restored", MachineID: "restored",
Artifact: contracthost.ArtifactRef{ Artifact: contracthost.ArtifactRef{
KernelImageURL: "https://example.com/kernel", KernelImageURL: server.URL + "/kernel",
RootFSURL: "https://example.com/rootfs", RootFSURL: server.URL + "/rootfs",
}, },
Snapshot: &contracthost.DurableSnapshotSpec{ Snapshot: &contracthost.DurableSnapshotSpec{
SnapshotID: "snap1", SnapshotID: "snap1",
@ -863,10 +848,19 @@ func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) {
ImageID: "image-1", ImageID: "image-1",
SourceRuntimeHost: "172.16.0.2", SourceRuntimeHost: "172.16.0.2",
SourceTapDevice: "fctap0", SourceTapDevice: "fctap0",
Artifacts: []contracthost.SnapshotArtifact{
{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"},
},
}, },
}) })
if err == nil || !strings.Contains(err.Error(), "still in use on this host") { if err != nil {
t.Fatalf("restore snapshot error = %v, want restore network in-use failure", err) t.Fatalf("restore snapshot error = %v, want success", err)
}
if response.Machine.Phase != contracthost.MachinePhaseStarting {
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
}
if runtime.bootCalls != 1 {
t.Fatalf("boot call count mismatch: got %d want 1", runtime.bootCalls)
} }
if runtime.restoreCalls != 0 { if runtime.restoreCalls != 0 {
t.Fatalf("restore boot should not be attempted, got %d calls", runtime.restoreCalls) t.Fatalf("restore boot should not be attempted, got %d calls", runtime.restoreCalls)

View file

@ -472,20 +472,10 @@ func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T
if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil { if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil {
t.Fatalf("write snapshot disk: %v", err) t.Fatalf("write snapshot disk: %v", err)
} }
memPath := filepath.Join(snapDir, "memory.bin")
if err := os.WriteFile(memPath, []byte("mem"), 0o644); err != nil {
t.Fatalf("write memory snapshot: %v", err)
}
statePath := filepath.Join(snapDir, "vmstate.bin")
if err := os.WriteFile(statePath, []byte("state"), 0o644); err != nil {
t.Fatalf("write vmstate snapshot: %v", err)
}
if err := baseStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ if err := baseStore.CreateSnapshot(context.Background(), model.SnapshotRecord{
ID: "snap-exhausted", ID: "snap-exhausted",
MachineID: "source", MachineID: "source",
Artifact: artifactRef, Artifact: artifactRef,
MemFilePath: memPath,
StateFilePath: statePath,
DiskPaths: []string{snapDisk}, DiskPaths: []string{snapDisk},
SourceRuntimeHost: "172.16.0.2", SourceRuntimeHost: "172.16.0.2",
SourceTapDevice: "fctap0", SourceTapDevice: "fctap0",
@ -495,11 +485,9 @@ func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T
} }
server := newRestoreArtifactServer(t, map[string][]byte{ server := newRestoreArtifactServer(t, map[string][]byte{
"/kernel": []byte("kernel"), "/kernel": []byte("kernel"),
"/rootfs": []byte("rootfs"), "/rootfs": []byte("rootfs"),
"/memory": []byte("mem"), "/system": []byte("disk"),
"/vmstate": []byte("state"),
"/system": []byte("disk"),
}) })
defer server.Close() defer server.Close()
@ -514,8 +502,6 @@ func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T
MachineID: "source", MachineID: "source",
ImageID: "image-1", ImageID: "image-1",
Artifacts: []contracthost.SnapshotArtifact{ Artifacts: []contracthost.SnapshotArtifact{
{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"},
{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"},
{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"}, {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"},
}, },
}, },
@ -694,11 +680,9 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterSuccess(t *testing.T) {
hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil } hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil }
server := newRestoreArtifactServer(t, map[string][]byte{ server := newRestoreArtifactServer(t, map[string][]byte{
"/kernel": []byte("kernel"), "/kernel": []byte("kernel"),
"/rootfs": []byte("rootfs"), "/rootfs": []byte("rootfs"),
"/memory": []byte("mem"), "/system": []byte("disk"),
"/vmstate": []byte("state"),
"/system": []byte("disk"),
}) })
defer server.Close() defer server.Close()
@ -715,8 +699,6 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterSuccess(t *testing.T) {
SourceRuntimeHost: "172.16.0.2", SourceRuntimeHost: "172.16.0.2",
SourceTapDevice: "fctap0", SourceTapDevice: "fctap0",
Artifacts: []contracthost.SnapshotArtifact{ Artifacts: []contracthost.SnapshotArtifact{
{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory", SHA256Hex: mustSHA256Hex(t, []byte("mem"))},
{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate", SHA256Hex: mustSHA256Hex(t, []byte("state"))},
{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system", SHA256Hex: mustSHA256Hex(t, []byte("disk"))}, {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system", SHA256Hex: mustSHA256Hex(t, []byte("disk"))},
}, },
}, },
@ -749,7 +731,7 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure(t *testing.T)
server := newRestoreArtifactServer(t, map[string][]byte{ server := newRestoreArtifactServer(t, map[string][]byte{
"/kernel": []byte("kernel"), "/kernel": []byte("kernel"),
"/rootfs": []byte("rootfs"), "/rootfs": []byte("rootfs"),
"/memory": []byte("mem"), "/system": []byte("disk"),
}) })
defer server.Close() defer server.Close()
@ -766,8 +748,7 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure(t *testing.T)
SourceRuntimeHost: "172.16.0.2", SourceRuntimeHost: "172.16.0.2",
SourceTapDevice: "fctap0", SourceTapDevice: "fctap0",
Artifacts: []contracthost.SnapshotArtifact{ Artifacts: []contracthost.SnapshotArtifact{
{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory", SHA256Hex: mustSHA256Hex(t, []byte("mem"))}, {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/missing"},
{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/missing"},
}, },
}, },
}) })

View file

@ -3,7 +3,6 @@ package daemon
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"sort" "sort"
@ -13,7 +12,6 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/httpapi" "github.com/getcompanion-ai/computer-host/internal/httpapi"
"github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store" "github.com/getcompanion-ai/computer-host/internal/store"
@ -85,20 +83,6 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
return nil, fmt.Errorf("pause machine %q: %w", machineID, err) return nil, fmt.Errorf("pause machine %q: %w", machineID, err)
} }
// Write snapshot inside the chroot (Firecracker can only write there)
// Use jailed paths relative to the chroot root
chrootMemPath := "memory.bin"
chrootStatePath := "vmstate.bin"
if err := d.runtime.CreateSnapshot(ctx, runtimeState, firecracker.SnapshotPaths{
MemFilePath: chrootMemPath,
StateFilePath: chrootStatePath,
}); err != nil {
_ = d.runtime.Resume(ctx, runtimeState)
_ = os.RemoveAll(snapshotDir)
return nil, fmt.Errorf("create snapshot for %q: %w", machineID, err)
}
// COW-copy disk files while paused for consistency // COW-copy disk files while paused for consistency
var diskPaths []string var diskPaths []string
systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID) systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID)
@ -137,24 +121,7 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
return nil, fmt.Errorf("resume machine %q: %w", machineID, err) return nil, fmt.Errorf("resume machine %q: %w", machineID, err)
} }
// Copy snapshot files from chroot to snapshot directory, then remove originals. artifacts, err := buildSnapshotArtifacts(diskPaths)
// os.Rename fails across filesystem boundaries (/proc/<pid>/root/ is on procfs).
chrootRoot := filepath.Dir(filepath.Dir(runtimeState.SocketPath)) // strip /run/firecracker.socket
srcMemPath := filepath.Join(chrootRoot, chrootMemPath)
srcStatePath := filepath.Join(chrootRoot, chrootStatePath)
dstMemPath := filepath.Join(snapshotDir, "memory.bin")
dstStatePath := filepath.Join(snapshotDir, "vmstate.bin")
if err := moveFile(srcMemPath, dstMemPath); err != nil {
_ = os.RemoveAll(snapshotDir)
return nil, fmt.Errorf("move memory file: %w", err)
}
if err := moveFile(srcStatePath, dstStatePath); err != nil {
_ = os.RemoveAll(snapshotDir)
return nil, fmt.Errorf("move vmstate file: %w", err)
}
artifacts, err := buildSnapshotArtifacts(dstMemPath, dstStatePath, diskPaths)
if err != nil { if err != nil {
_ = os.RemoveAll(snapshotDir) _ = os.RemoveAll(snapshotDir)
return nil, fmt.Errorf("build snapshot artifacts: %w", err) return nil, fmt.Errorf("build snapshot artifacts: %w", err)
@ -162,16 +129,12 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
now := time.Now().UTC() now := time.Now().UTC()
snapshotRecord := model.SnapshotRecord{ snapshotRecord := model.SnapshotRecord{
ID: snapshotID, ID: snapshotID,
MachineID: machineID, MachineID: machineID,
Artifact: record.Artifact, Artifact: record.Artifact,
MemFilePath: dstMemPath, DiskPaths: diskPaths,
StateFilePath: dstStatePath, Artifacts: artifacts,
DiskPaths: diskPaths, CreatedAt: now,
Artifacts: artifacts,
SourceRuntimeHost: record.RuntimeHost,
SourceTapDevice: record.TapDevice,
CreatedAt: now,
} }
if err := d.store.CreateSnapshot(ctx, snapshotRecord); err != nil { if err := d.store.CreateSnapshot(ctx, snapshotRecord); err != nil {
_ = os.RemoveAll(snapshotDir) _ = os.RemoveAll(snapshotDir)
@ -276,11 +239,7 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
} }
}() }()
usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID) restoredArtifacts, cleanupRestoreArtifacts, err := d.prepareRestoreArtifacts(ctx, snapshotID, req)
if err != nil {
return nil, err
}
restoredArtifacts, restoreNetwork, cleanupRestoreArtifacts, err := d.prepareRestoreArtifacts(ctx, snapshotID, req, usedNetworks)
if err != nil { if err != nil {
clearOperation = true clearOperation = true
return nil, err return nil, err
@ -302,16 +261,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
clearOperation = true clearOperation = true
return nil, fmt.Errorf("snapshot %q is missing system disk artifact", snapshotID) return nil, fmt.Errorf("snapshot %q is missing system disk artifact", snapshotID)
} }
memoryArtifact, ok := restoredArtifacts["memory.bin"]
if !ok {
clearOperation = true
return nil, fmt.Errorf("snapshot %q is missing memory artifact", snapshotID)
}
vmstateArtifact, ok := restoredArtifacts["vmstate.bin"]
if !ok {
clearOperation = true
return nil, fmt.Errorf("snapshot %q is missing vmstate artifact", snapshotID)
}
if err := cloneDiskFile(systemDiskPath.LocalPath, newSystemDiskPath, d.config.DiskCloneMode); err != nil { if err := cloneDiskFile(systemDiskPath.LocalPath, newSystemDiskPath, d.config.DiskCloneMode); err != nil {
clearOperation = true clearOperation = true
return nil, fmt.Errorf("copy system disk for restore: %w", err) return nil, fmt.Errorf("copy system disk for restore: %w", err)
@ -341,24 +290,31 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
restoredDrivePaths[driveID] = volumePath restoredDrivePaths[driveID] = volumePath
} }
// Do not force vsock_override on restore: Firecracker rejects it for old userVolumes := make([]model.VolumeRecord, 0, len(restoredUserVolumes))
// snapshots without a vsock device, and the jailed /run path already for _, volume := range restoredUserVolumes {
// relocates safely for snapshots created with the new vsock-backed guest. userVolumes = append(userVolumes, model.VolumeRecord{
loadSpec := firecracker.SnapshotLoadSpec{ ID: volume.ID,
ID: firecracker.MachineID(req.MachineID), Kind: contracthost.VolumeKindUser,
SnapshotPath: vmstateArtifact.LocalPath, Path: volume.Path,
MemFilePath: memoryArtifact.LocalPath, })
RootFSPath: newSystemDiskPath,
KernelImagePath: artifact.KernelImagePath,
DiskPaths: restoredDrivePaths,
Network: &restoreNetwork,
} }
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig)
machineState, err := d.runtime.RestoreBoot(ctx, loadSpec, usedNetworks)
if err != nil { if err != nil {
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) _ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true clearOperation = true
return nil, fmt.Errorf("restore boot: %w", err) return nil, fmt.Errorf("build machine spec for restore: %w", err)
}
usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID)
if err != nil {
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, err
}
machineState, err := d.runtime.Boot(ctx, spec, usedNetworks)
if err != nil {
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("boot restored machine: %w", err)
} }
systemVolumeID := d.systemVolumeID(req.MachineID) systemVolumeID := d.systemVolumeID(req.MachineID)
@ -556,86 +512,39 @@ func restoredUserDiskIndex(name string) (int, bool) {
return index, true return index, true
} }
func (d *Daemon) prepareRestoreArtifacts(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.RestoreSnapshotRequest, usedNetworks []firecracker.NetworkAllocation) (map[string]restoredSnapshotArtifact, firecracker.NetworkAllocation, func(), error) { func (d *Daemon) prepareRestoreArtifacts(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.RestoreSnapshotRequest) (map[string]restoredSnapshotArtifact, func(), error) {
if req.LocalSnapshot != nil { if req.LocalSnapshot != nil {
if req.LocalSnapshot.SnapshotID != "" && req.LocalSnapshot.SnapshotID != snapshotID { if req.LocalSnapshot.SnapshotID != "" && req.LocalSnapshot.SnapshotID != snapshotID {
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("local snapshot id mismatch: path=%q payload=%q", snapshotID, req.LocalSnapshot.SnapshotID) return nil, func() {}, fmt.Errorf("local snapshot id mismatch: path=%q payload=%q", snapshotID, req.LocalSnapshot.SnapshotID)
} }
snapshot, err := d.store.GetSnapshot(ctx, snapshotID) snapshot, err := d.store.GetSnapshot(ctx, snapshotID)
if err != nil { if err != nil {
if err == store.ErrNotFound { if err == store.ErrNotFound {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, "snapshot is not present on this host") return nil, func() {}, localSnapshotRestoreUnavailable(snapshotID, "snapshot is not present on this host")
} }
return nil, firecracker.NetworkAllocation{}, func() {}, err return nil, func() {}, err
}
restoreNetwork, err := restoreNetworkFromSnapshot(snapshot)
if err != nil {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error())
}
if networkAllocationInUse(restoreNetwork, usedNetworks) {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, fmt.Sprintf("restore network is still in use on this host (runtime_host=%s tap_device=%s)", restoreNetwork.GuestIP(), restoreNetwork.TapName))
} }
artifacts, err := localSnapshotArtifacts(snapshot) artifacts, err := localSnapshotArtifacts(snapshot)
if err != nil { if err != nil {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error()) return nil, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error())
} }
return artifacts, restoreNetwork, func() {}, nil return artifacts, func() {}, nil
} }
if req.Snapshot == nil { if req.Snapshot == nil {
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("durable snapshot spec is required") return nil, func() {}, fmt.Errorf("durable snapshot spec is required")
}
restoreNetwork, err := restoreNetworkFromDurableSpec(*req.Snapshot)
if err != nil {
snapshot, lookupErr := d.store.GetSnapshot(ctx, snapshotID)
if lookupErr == nil {
restoreNetwork, err = restoreNetworkFromSnapshot(snapshot)
} else if lookupErr != store.ErrNotFound {
return nil, firecracker.NetworkAllocation{}, func() {}, lookupErr
}
if err != nil {
return nil, firecracker.NetworkAllocation{}, func() {}, err
}
}
if networkAllocationInUse(restoreNetwork, usedNetworks) {
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("restore network for snapshot %q is still in use on this host (runtime_host=%s tap_device=%s)", snapshotID, restoreNetwork.GuestIP(), restoreNetwork.TapName)
} }
stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID)) stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID))
artifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts) artifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts)
if err != nil { if err != nil {
_ = os.RemoveAll(stagingDir) _ = os.RemoveAll(stagingDir)
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("download durable snapshot artifacts: %w", err) return nil, func() {}, fmt.Errorf("download durable snapshot artifacts: %w", err)
} }
return artifacts, restoreNetwork, func() { return artifacts, func() {
_ = os.RemoveAll(stagingDir) _ = os.RemoveAll(stagingDir)
}, nil }, nil
} }
// restoreNetworkFromDurableSpec derives the network allocation recorded in a
// durable snapshot spec.
//
// Both SourceRuntimeHost and SourceTapDevice must be non-blank; otherwise an
// error is returned without consulting the firecracker package. When they are
// present, the values are handed to firecracker.AllocationFromGuestIP and any
// failure is wrapped with the spec's snapshot ID.
func restoreNetworkFromDurableSpec(spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) {
	var none firecracker.NetworkAllocation

	hostBlank := strings.TrimSpace(spec.SourceRuntimeHost) == ""
	tapBlank := strings.TrimSpace(spec.SourceTapDevice) == ""
	if hostBlank || tapBlank {
		return none, fmt.Errorf("durable snapshot spec is missing restore network metadata")
	}

	allocation, allocErr := firecracker.AllocationFromGuestIP(spec.SourceRuntimeHost, spec.SourceTapDevice)
	if allocErr != nil {
		return none, fmt.Errorf("reconstruct durable snapshot %q network: %w", spec.SnapshotID, allocErr)
	}
	return allocation, nil
}
// restoreNetworkFromSnapshot derives the network allocation recorded on a
// locally stored snapshot record.
//
// It rejects a nil record and records whose SourceRuntimeHost or
// SourceTapDevice is blank. Otherwise the two values are passed to
// firecracker.AllocationFromGuestIP, and a failure there is wrapped with the
// snapshot ID.
func restoreNetworkFromSnapshot(snap *model.SnapshotRecord) (firecracker.NetworkAllocation, error) {
	var none firecracker.NetworkAllocation

	switch {
	case snap == nil:
		return none, fmt.Errorf("snapshot is required")
	case strings.TrimSpace(snap.SourceRuntimeHost) == "", strings.TrimSpace(snap.SourceTapDevice) == "":
		return none, fmt.Errorf("snapshot %q is missing restore network metadata", snap.ID)
	}

	allocation, allocErr := firecracker.AllocationFromGuestIP(snap.SourceRuntimeHost, snap.SourceTapDevice)
	if allocErr != nil {
		return none, fmt.Errorf("reconstruct snapshot %q network: %w", snap.ID, allocErr)
	}
	return allocation, nil
}
func localSnapshotArtifacts(snapshot *model.SnapshotRecord) (map[string]restoredSnapshotArtifact, error) { func localSnapshotArtifacts(snapshot *model.SnapshotRecord) (map[string]restoredSnapshotArtifact, error) {
if snapshot == nil { if snapshot == nil {
return nil, fmt.Errorf("snapshot is required") return nil, fmt.Errorf("snapshot is required")
@ -662,19 +571,6 @@ func localSnapshotArtifacts(snapshot *model.SnapshotRecord) (map[string]restored
return restored, nil return restored, nil
} }
// networkAllocationInUse reports whether target collides with any allocation
// in used. Two allocations collide when their GuestIP values compare equal,
// or when target has a non-blank tap name that matches the other allocation's
// tap name after trimming surrounding whitespace.
func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecracker.NetworkAllocation) bool {
	wantTap := strings.TrimSpace(target.TapName)
	for i := range used {
		candidate := used[i]
		switch {
		case candidate.GuestIP() == target.GuestIP():
			return true
		case wantTap != "" && strings.TrimSpace(candidate.TapName) == wantTap:
			// A blank target tap never matches, so only the IP check applies then.
			return true
		}
	}
	return false
}
func localSnapshotRestoreUnavailable(snapshotID contracthost.SnapshotID, message string) error { func localSnapshotRestoreUnavailable(snapshotID contracthost.SnapshotID, message string) error {
message = strings.TrimSpace(message) message = strings.TrimSpace(message)
if message == "" { if message == "" {
@ -682,31 +578,3 @@ func localSnapshotRestoreUnavailable(snapshotID contracthost.SnapshotID, message
} }
return fmt.Errorf("%s: snapshot %q %s", localSnapshotRestoreUnavailablePrefix, snapshotID, message) return fmt.Errorf("%s: snapshot %q %s", localSnapshotRestoreUnavailablePrefix, snapshotID, message)
} }
// moveFile copies src to dst and then removes src. Unlike os.Rename it works
// across filesystem boundaries, which is needed when moving files out of
// /proc/<pid>/root/.
//
// If dst already refers to the same file as src (the same path, or a hard
// link to it) the call is a no-op and returns nil, matching os.Rename.
// Without that check, creating dst would truncate the data before it could
// be copied and the final removal would delete it.
//
// When the copy or the close of dst fails, the partially written dst is
// removed and src is left in place. Errors from the underlying os/io calls
// are returned unwrapped so callers can add their own context.
func moveFile(src, dst string) error {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer func() {
		_ = in.Close()
	}()

	// os.Create truncates an existing dst, so detect aliasing before touching it.
	srcInfo, err := in.Stat()
	if err != nil {
		return err
	}
	if dstInfo, statErr := os.Stat(dst); statErr == nil && os.SameFile(srcInfo, dstInfo) {
		return nil
	}

	out, err := os.Create(dst)
	if err != nil {
		return err
	}
	if _, err := io.Copy(out, in); err != nil {
		_ = out.Close()
		_ = os.Remove(dst)
		return err
	}
	// Close can surface deferred write errors; treat them like a failed copy.
	if err := out.Close(); err != nil {
		_ = os.Remove(dst)
		return err
	}
	return os.Remove(src)
}

View file

@ -21,21 +21,8 @@ type restoredSnapshotArtifact struct {
LocalPath string LocalPath string
} }
func buildSnapshotArtifacts(memoryPath, vmstatePath string, diskPaths []string) ([]model.SnapshotArtifactRecord, error) { func buildSnapshotArtifacts(diskPaths []string) ([]model.SnapshotArtifactRecord, error) {
artifacts := make([]model.SnapshotArtifactRecord, 0, len(diskPaths)+2) artifacts := make([]model.SnapshotArtifactRecord, 0, len(diskPaths))
memoryArtifact, err := snapshotArtifactRecord("memory", contracthost.SnapshotArtifactKindMemory, filepath.Base(memoryPath), memoryPath)
if err != nil {
return nil, err
}
artifacts = append(artifacts, memoryArtifact)
vmstateArtifact, err := snapshotArtifactRecord("vmstate", contracthost.SnapshotArtifactKindVMState, filepath.Base(vmstatePath), vmstatePath)
if err != nil {
return nil, err
}
artifacts = append(artifacts, vmstateArtifact)
for _, diskPath := range diskPaths { for _, diskPath := range diskPaths {
base := filepath.Base(diskPath) base := filepath.Base(diskPath)
diskArtifact, err := snapshotArtifactRecord("disk-"+strings.TrimSuffix(base, filepath.Ext(base)), contracthost.SnapshotArtifactKindDisk, base, diskPath) diskArtifact, err := snapshotArtifactRecord("disk-"+strings.TrimSuffix(base, filepath.Ext(base)), contracthost.SnapshotArtifactKindDisk, base, diskPath)