feat: local-first snapshot implementation end to end

This commit is contained in:
Harivansh Rathi 2026-04-10 21:32:09 +00:00
parent fc21e897ea
commit 30282928f5
4 changed files with 279 additions and 60 deletions

View file

@ -422,7 +422,19 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
stubGuestSSHPublicKeyReader(hostDaemon)
hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil }
artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
server := newRestoreArtifactServer(t, map[string][]byte{
"/kernel": []byte("kernel"),
"/rootfs": []byte("rootfs"),
"/memory": []byte("mem"),
"/vmstate": []byte("state"),
"/system": []byte("disk"),
})
defer server.Close()
artifactRef := contracthost.ArtifactRef{
KernelImageURL: server.URL + "/kernel",
RootFSURL: server.URL + "/rootfs",
}
kernelPath := filepath.Join(root, "artifact-kernel")
if err := os.WriteFile(kernelPath, []byte("kernel"), 0o644); err != nil {
t.Fatalf("write kernel: %v", err)
@ -452,22 +464,13 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
t.Fatalf("create snapshot: %v", err)
}
server := newRestoreArtifactServer(t, map[string][]byte{
"/kernel": []byte("kernel"),
"/rootfs": []byte("rootfs"),
"/memory": []byte("mem"),
"/vmstate": []byte("state"),
"/system": []byte("disk"),
})
defer server.Close()
response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{
MachineID: "restored",
Artifact: contracthost.ArtifactRef{
KernelImageURL: server.URL + "/kernel",
RootFSURL: server.URL + "/rootfs",
},
Snapshot: contracthost.DurableSnapshotSpec{
Snapshot: &contracthost.DurableSnapshotSpec{
SnapshotID: "snap1",
MachineID: "source",
ImageID: "image-1",
@ -510,6 +513,134 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
}
}
// TestRestoreSnapshotUsesLocalSnapshotArtifacts restores a machine from a
// request that carries only a LocalSnapshot spec (no durable snapshot spec).
// The snapshot record and its memory/vmstate/disk files are seeded on the host
// beforehand; the test then checks that exactly one restore boot happened and
// that the boot used the network recorded in the stored snapshot.
func TestRestoreSnapshotUsesLocalSnapshotArtifacts(t *testing.T) {
	ctx := context.Background()
	root := t.TempDir()
	cfg := testConfig(root)

	fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}

	// NOTE(review): presumably these listeners stand in for the guest's SSH and
	// VNC endpoints so the daemon's post-restore checks succeed — confirm.
	sshListener := listenTestPort(t, int(defaultSSHPort))
	defer func() { _ = sshListener.Close() }()
	vncListener := listenTestPort(t, int(defaultVNCPort))
	defer func() { _ = vncListener.Close() }()

	// Fake runtime that reports an already-running machine on restore.
	bootedAt := time.Unix(1700000199, 0).UTC()
	socketPath := filepath.Join(cfg.RuntimeDir, "machines", "restored-local", "root", "run", "firecracker.sock")
	rt := &fakeRuntime{
		bootState: firecracker.MachineState{
			ID:          "restored-local",
			Phase:       firecracker.PhaseRunning,
			PID:         1234,
			RuntimeHost: "127.0.0.1",
			SocketPath:  socketPath,
			TapName:     "fctap0",
			StartedAt:   &bootedAt,
		},
	}

	hostDaemon, err := New(cfg, fileStore, rt)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	stubGuestSSHPublicKeyReader(hostDaemon)
	hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error {
		return nil
	}

	// The HTTP server only knows the base artifact; it serves no snapshot
	// artifacts (memory, vmstate, disk).
	served := map[string][]byte{
		"/kernel": []byte("kernel"),
		"/rootfs": []byte("rootfs"),
	}
	server := newRestoreArtifactServer(t, served)
	defer server.Close()
	artifactRef := contracthost.ArtifactRef{
		KernelImageURL: server.URL + "/kernel",
		RootFSURL:      server.URL + "/rootfs",
	}

	// Seed the base artifact (kernel + rootfs) on disk and in the store.
	artifactDir := filepath.Join(root, "artifact")
	err = os.MkdirAll(artifactDir, 0o755)
	if err != nil {
		t.Fatalf("create artifact dir: %v", err)
	}
	kernelPath := filepath.Join(artifactDir, "vmlinux")
	rootFSPath := filepath.Join(artifactDir, "rootfs.ext4")
	err = os.WriteFile(kernelPath, []byte("kernel"), 0o644)
	if err != nil {
		t.Fatalf("write kernel: %v", err)
	}
	err = os.WriteFile(rootFSPath, []byte("rootfs"), 0o644)
	if err != nil {
		t.Fatalf("write rootfs: %v", err)
	}
	artifactRecord := model.ArtifactRecord{
		Ref:             artifactRef,
		LocalKey:        "artifact",
		LocalDir:        artifactDir,
		KernelImagePath: kernelPath,
		RootFSPath:      rootFSPath,
		CreatedAt:       time.Now().UTC(),
	}
	err = fileStore.PutArtifact(ctx, artifactRecord)
	if err != nil {
		t.Fatalf("put artifact: %v", err)
	}

	// Seed the local snapshot files and the matching snapshot record.
	snapshotDir := filepath.Join(root, "snapshots", "snap-local")
	err = os.MkdirAll(snapshotDir, 0o755)
	if err != nil {
		t.Fatalf("create snapshot dir: %v", err)
	}
	memoryPath := filepath.Join(snapshotDir, "memory.bin")
	vmstatePath := filepath.Join(snapshotDir, "vmstate.bin")
	systemPath := filepath.Join(snapshotDir, "system.img")
	err = os.WriteFile(memoryPath, []byte("mem"), 0o644)
	if err != nil {
		t.Fatalf("write memory: %v", err)
	}
	err = os.WriteFile(vmstatePath, []byte("state"), 0o644)
	if err != nil {
		t.Fatalf("write vmstate: %v", err)
	}
	err = os.WriteFile(systemPath, []byte("disk"), 0o644)
	if err != nil {
		t.Fatalf("write system disk: %v", err)
	}
	snapshotRecord := model.SnapshotRecord{
		ID:            "snap-local",
		MachineID:     "source",
		Artifact:      artifactRef,
		MemFilePath:   memoryPath,
		StateFilePath: vmstatePath,
		DiskPaths:     []string{systemPath},
		Artifacts: []model.SnapshotArtifactRecord{
			{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", LocalPath: memoryPath, SizeBytes: 3},
			{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", LocalPath: vmstatePath, SizeBytes: 5},
			{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", LocalPath: systemPath, SizeBytes: 4},
		},
		SourceRuntimeHost: "172.16.0.2",
		SourceTapDevice:   "fctap0",
		CreatedAt:         time.Now().UTC(),
	}
	err = fileStore.CreateSnapshot(ctx, snapshotRecord)
	if err != nil {
		t.Fatalf("create snapshot: %v", err)
	}

	// Restore using only the local snapshot spec.
	restoreReq := contracthost.RestoreSnapshotRequest{
		MachineID: "restored-local",
		Artifact:  artifactRef,
		LocalSnapshot: &contracthost.LocalSnapshotSpec{
			SnapshotID: "snap-local",
		},
		GuestConfig: &contracthost.GuestConfig{Hostname: "restored-local-shell"},
	}
	response, err := hostDaemon.RestoreSnapshot(ctx, "snap-local", restoreReq)
	if err != nil {
		t.Fatalf("restore snapshot: %v", err)
	}

	if response.Machine.ID != "restored-local" {
		t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
	}
	if rt.restoreCalls != 1 {
		t.Fatalf("restore boot call count mismatch: got %d want 1", rt.restoreCalls)
	}

	// The boot must reuse the network stored with the snapshot record.
	restoredNetwork := rt.lastLoadSpec.Network
	if restoredNetwork == nil {
		t.Fatalf("restore boot should preserve local snapshot network")
	}
	if got := restoredNetwork.GuestIP().String(); got != "172.16.0.2" {
		t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2")
	}
	if got := restoredNetwork.TapName; got != "fctap0" {
		t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0")
	}
}
func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
@ -565,7 +696,7 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
KernelImageURL: server.URL + "/kernel",
RootFSURL: server.URL + "/rootfs",
},
Snapshot: contracthost.DurableSnapshotSpec{
Snapshot: &contracthost.DurableSnapshotSpec{
SnapshotID: "snap1",
MachineID: "source",
ImageID: "image-1",
@ -666,7 +797,7 @@ func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) {
KernelImageURL: "https://example.com/kernel",
RootFSURL: "https://example.com/rootfs",
},
Snapshot: contracthost.DurableSnapshotSpec{
Snapshot: &contracthost.DurableSnapshotSpec{
SnapshotID: "snap1",
MachineID: "source",
ImageID: "image-1",

View file

@ -14,10 +14,10 @@ import (
"testing"
"time"
contracthost "github.com/getcompanion-ai/computer-host/contract"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
hoststore "github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
type blockingPublishedPortStore struct {
@ -509,7 +509,7 @@ func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T
KernelImageURL: server.URL + "/kernel",
RootFSURL: server.URL + "/rootfs",
},
Snapshot: contracthost.DurableSnapshotSpec{
Snapshot: &contracthost.DurableSnapshotSpec{
SnapshotID: "snap-exhausted",
MachineID: "source",
ImageID: "image-1",
@ -708,7 +708,7 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterSuccess(t *testing.T) {
KernelImageURL: server.URL + "/kernel",
RootFSURL: server.URL + "/rootfs",
},
Snapshot: contracthost.DurableSnapshotSpec{
Snapshot: &contracthost.DurableSnapshotSpec{
SnapshotID: "snap-clean",
MachineID: "source",
ImageID: "image-1",
@ -759,7 +759,7 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure(t *testing.T)
KernelImageURL: server.URL + "/kernel",
RootFSURL: server.URL + "/rootfs",
},
Snapshot: contracthost.DurableSnapshotSpec{
Snapshot: &contracthost.DurableSnapshotSpec{
SnapshotID: "snap-fail-clean",
MachineID: "source",
ImageID: "image-1",

View file

@ -11,12 +11,16 @@ import (
"strings"
"time"
"golang.org/x/sync/errgroup"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
const localSnapshotRestoreUnavailablePrefix = "local snapshot restore unavailable"
func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error) {
unlock := d.lockMachine(machineID)
defer unlock()
@ -193,20 +197,31 @@ func (d *Daemon) UploadSnapshot(ctx context.Context, snapshotID contracthost.Sna
response := &contracthost.UploadSnapshotResponse{
Artifacts: make([]contracthost.UploadedSnapshotArtifact, 0, len(req.Artifacts)),
}
for _, upload := range req.Artifacts {
artifact, ok := artifactIndex[upload.ArtifactID]
if !ok {
return nil, fmt.Errorf("snapshot %q artifact %q not found", snapshotID, upload.ArtifactID)
}
completedParts, err := uploadSnapshotArtifact(ctx, artifact.LocalPath, upload.Parts)
if err != nil {
return nil, fmt.Errorf("upload snapshot artifact %q: %w", upload.ArtifactID, err)
}
response.Artifacts = append(response.Artifacts, contracthost.UploadedSnapshotArtifact{
ArtifactID: upload.ArtifactID,
CompletedParts: completedParts,
uploads := make([]contracthost.UploadedSnapshotArtifact, len(req.Artifacts))
group, groupCtx := errgroup.WithContext(ctx)
for i, upload := range req.Artifacts {
i := i
upload := upload
group.Go(func() error {
artifact, ok := artifactIndex[upload.ArtifactID]
if !ok {
return fmt.Errorf("snapshot %q artifact %q not found", snapshotID, upload.ArtifactID)
}
completedParts, err := uploadSnapshotArtifact(groupCtx, artifact.LocalPath, upload.Parts)
if err != nil {
return fmt.Errorf("upload snapshot artifact %q: %w", upload.ArtifactID, err)
}
uploads[i] = contracthost.UploadedSnapshotArtifact{
ArtifactID: upload.ArtifactID,
CompletedParts: completedParts,
}
return nil
})
}
if err := group.Wait(); err != nil {
return nil, err
}
response.Artifacts = append(response.Artifacts, uploads...)
return response, nil
}
@ -215,12 +230,18 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
if err := validateMachineID(req.MachineID); err != nil {
return nil, err
}
if req.Snapshot.SnapshotID != "" && req.Snapshot.SnapshotID != snapshotID {
return nil, fmt.Errorf("snapshot id mismatch: path=%q payload=%q", snapshotID, req.Snapshot.SnapshotID)
}
if err := validateArtifactRef(req.Artifact); err != nil {
return nil, err
}
if req.LocalSnapshot == nil && req.Snapshot == nil {
return nil, fmt.Errorf("restore request must include local_snapshot or snapshot")
}
if req.LocalSnapshot != nil && req.LocalSnapshot.SnapshotID != "" && req.LocalSnapshot.SnapshotID != snapshotID {
return nil, fmt.Errorf("local snapshot id mismatch: path=%q payload=%q", snapshotID, req.LocalSnapshot.SnapshotID)
}
if req.Snapshot != nil && req.Snapshot.SnapshotID != "" && req.Snapshot.SnapshotID != snapshotID {
return nil, fmt.Errorf("snapshot id mismatch: path=%q payload=%q", snapshotID, req.Snapshot.SnapshotID)
}
if err := validateGuestConfig(req.GuestConfig); err != nil {
return nil, err
}
@ -258,30 +279,18 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
if err != nil {
return nil, err
}
restoreNetwork, err := d.resolveRestoreNetwork(ctx, snapshotID, req.Snapshot)
restoredArtifacts, restoreNetwork, cleanupRestoreArtifacts, err := d.prepareRestoreArtifacts(ctx, snapshotID, req, usedNetworks)
if err != nil {
clearOperation = true
return nil, err
}
if networkAllocationInUse(restoreNetwork, usedNetworks) {
clearOperation = true
return nil, fmt.Errorf("restore network for snapshot %q is still in use on this host (runtime_host=%s tap_device=%s)", snapshotID, restoreNetwork.GuestIP(), restoreNetwork.TapName)
}
defer cleanupRestoreArtifacts()
artifact, err := d.ensureArtifact(ctx, req.Artifact)
if err != nil {
clearOperation = true
return nil, fmt.Errorf("ensure artifact for restore: %w", err)
}
stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID))
restoredArtifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts)
if err != nil {
_ = os.RemoveAll(stagingDir)
clearOperation = true
return nil, fmt.Errorf("download durable snapshot artifacts: %w", err)
}
defer func() { _ = os.RemoveAll(stagingDir) }()
// COW-copy system disk from snapshot to new machine's disk dir.
newSystemDiskPath := d.systemVolumePath(req.MachineID)
if err := os.MkdirAll(filepath.Dir(newSystemDiskPath), 0o755); err != nil {
@ -515,19 +524,59 @@ func restoredUserDiskIndex(name string) (int, bool) {
return index, true
}
func (d *Daemon) resolveRestoreNetwork(ctx context.Context, snapshotID contracthost.SnapshotID, spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) {
if network, err := restoreNetworkFromDurableSpec(spec); err == nil {
return network, nil
func (d *Daemon) prepareRestoreArtifacts(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.RestoreSnapshotRequest, usedNetworks []firecracker.NetworkAllocation) (map[string]restoredSnapshotArtifact, firecracker.NetworkAllocation, func(), error) {
if req.LocalSnapshot != nil {
if req.LocalSnapshot.SnapshotID != "" && req.LocalSnapshot.SnapshotID != snapshotID {
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("local snapshot id mismatch: path=%q payload=%q", snapshotID, req.LocalSnapshot.SnapshotID)
}
snapshot, err := d.store.GetSnapshot(ctx, snapshotID)
if err != nil {
if err == store.ErrNotFound {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, "snapshot is not present on this host")
}
return nil, firecracker.NetworkAllocation{}, func() {}, err
}
restoreNetwork, err := restoreNetworkFromSnapshot(snapshot)
if err != nil {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error())
}
if networkAllocationInUse(restoreNetwork, usedNetworks) {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, fmt.Sprintf("restore network is still in use on this host (runtime_host=%s tap_device=%s)", restoreNetwork.GuestIP(), restoreNetwork.TapName))
}
artifacts, err := localSnapshotArtifacts(snapshot)
if err != nil {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error())
}
return artifacts, restoreNetwork, func() {}, nil
}
snapshot, err := d.store.GetSnapshot(ctx, snapshotID)
if err == nil {
return restoreNetworkFromSnapshot(snapshot)
if req.Snapshot == nil {
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("durable snapshot spec is required")
}
if err != store.ErrNotFound {
return firecracker.NetworkAllocation{}, err
restoreNetwork, err := restoreNetworkFromDurableSpec(*req.Snapshot)
if err != nil {
snapshot, lookupErr := d.store.GetSnapshot(ctx, snapshotID)
if lookupErr == nil {
restoreNetwork, err = restoreNetworkFromSnapshot(snapshot)
} else if lookupErr != store.ErrNotFound {
return nil, firecracker.NetworkAllocation{}, func() {}, lookupErr
}
if err != nil {
return nil, firecracker.NetworkAllocation{}, func() {}, err
}
}
return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot %q is missing restore network metadata", snapshotID)
if networkAllocationInUse(restoreNetwork, usedNetworks) {
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("restore network for snapshot %q is still in use on this host (runtime_host=%s tap_device=%s)", snapshotID, restoreNetwork.GuestIP(), restoreNetwork.TapName)
}
stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID))
artifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts)
if err != nil {
_ = os.RemoveAll(stagingDir)
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("download durable snapshot artifacts: %w", err)
}
return artifacts, restoreNetwork, func() {
_ = os.RemoveAll(stagingDir)
}, nil
}
func restoreNetworkFromDurableSpec(spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) {
@ -555,6 +604,32 @@ func restoreNetworkFromSnapshot(snap *model.SnapshotRecord) (firecracker.Network
return network, nil
}
// localSnapshotArtifacts builds the set of restore artifacts for a snapshot
// whose files live on this host. The result is keyed by artifact name and each
// entry points at the artifact's recorded local path.
//
// It returns an error when snapshot is nil, when an artifact has no local
// path, when the path cannot be stat'ed, or when the path is a directory.
func localSnapshotArtifacts(snapshot *model.SnapshotRecord) (map[string]restoredSnapshotArtifact, error) {
	if snapshot == nil {
		return nil, fmt.Errorf("snapshot is required")
	}
	restored := make(map[string]restoredSnapshotArtifact, len(snapshot.Artifacts))
	for _, artifact := range snapshot.Artifacts {
		if strings.TrimSpace(artifact.LocalPath) == "" {
			return nil, fmt.Errorf("snapshot %q artifact %q is missing a local path", snapshot.ID, artifact.ID)
		}
		info, err := os.Stat(artifact.LocalPath)
		if err != nil {
			return nil, fmt.Errorf("snapshot %q artifact %q is unavailable at %q: %w", snapshot.ID, artifact.ID, artifact.LocalPath, err)
		}
		// os.Stat also succeeds for directories; a directory can never be a
		// usable artifact file, so reject it here instead of failing later in
		// the restore.
		if info.IsDir() {
			return nil, fmt.Errorf("snapshot %q artifact %q is unavailable at %q: path is a directory", snapshot.ID, artifact.ID, artifact.LocalPath)
		}
		restored[artifact.Name] = restoredSnapshotArtifact{
			Artifact: contracthost.SnapshotArtifact{
				ID:        artifact.ID,
				Kind:      artifact.Kind,
				Name:      artifact.Name,
				SizeBytes: artifact.SizeBytes,
				SHA256Hex: artifact.SHA256Hex,
			},
			LocalPath: artifact.LocalPath,
		}
	}
	return restored, nil
}
func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecracker.NetworkAllocation) bool {
targetTap := strings.TrimSpace(target.TapName)
for _, network := range used {
@ -568,6 +643,14 @@ func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecra
return false
}
// localSnapshotRestoreUnavailable returns an error whose text starts with
// localSnapshotRestoreUnavailablePrefix, followed by the snapshot ID and the
// given reason. A blank reason is replaced with a generic one.
func localSnapshotRestoreUnavailable(snapshotID contracthost.SnapshotID, message string) error {
	reason := strings.TrimSpace(message)
	if len(reason) == 0 {
		reason = "local restore is unavailable"
	}
	return fmt.Errorf("%s: snapshot %q %s", localSnapshotRestoreUnavailablePrefix, snapshotID, reason)
}
// moveFile copies src to dst then removes src. Works across filesystem boundaries
// unlike os.Rename, which is needed when moving files out of /proc/<pid>/root/.
func moveFile(src, dst string) error {