diff --git a/contract/snapshots.go b/contract/snapshots.go index 1192d6f..6408724 100644 --- a/contract/snapshots.go +++ b/contract/snapshots.go @@ -81,10 +81,15 @@ type UploadSnapshotResponse struct { } type RestoreSnapshotRequest struct { - MachineID MachineID `json:"machine_id"` - Artifact ArtifactRef `json:"artifact"` - Snapshot DurableSnapshotSpec `json:"snapshot"` - GuestConfig *GuestConfig `json:"guest_config,omitempty"` + MachineID MachineID `json:"machine_id"` + Artifact ArtifactRef `json:"artifact"` + LocalSnapshot *LocalSnapshotSpec `json:"local_snapshot,omitempty"` + Snapshot *DurableSnapshotSpec `json:"snapshot,omitempty"` + GuestConfig *GuestConfig `json:"guest_config,omitempty"` +} + +type LocalSnapshotSpec struct { + SnapshotID SnapshotID `json:"snapshot_id"` } type DurableSnapshotSpec struct { diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index c0eb297..8e9446b 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -422,7 +422,19 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) { stubGuestSSHPublicKeyReader(hostDaemon) hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil } - artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"} + server := newRestoreArtifactServer(t, map[string][]byte{ + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + "/memory": []byte("mem"), + "/vmstate": []byte("state"), + "/system": []byte("disk"), + }) + defer server.Close() + + artifactRef := contracthost.ArtifactRef{ + KernelImageURL: server.URL + "/kernel", + RootFSURL: server.URL + "/rootfs", + } kernelPath := filepath.Join(root, "artifact-kernel") if err := os.WriteFile(kernelPath, []byte("kernel"), 0o644); err != nil { t.Fatalf("write kernel: %v", err) @@ -452,22 +464,13 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) { t.Fatalf("create snapshot: %v", err) } - server := newRestoreArtifactServer(t, map[string][]byte{ - "/kernel": []byte("kernel"), - "/rootfs": []byte("rootfs"), - "/memory": []byte("mem"), - "/vmstate": []byte("state"), - "/system": []byte("disk"), - }) - defer server.Close() - response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{ MachineID: "restored", Artifact: contracthost.ArtifactRef{ KernelImageURL: server.URL + "/kernel", RootFSURL: server.URL + "/rootfs", }, - Snapshot: contracthost.DurableSnapshotSpec{ + Snapshot: &contracthost.DurableSnapshotSpec{ SnapshotID: "snap1", MachineID: "source", ImageID: "image-1", @@ -510,6 +513,134 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) { } } +func TestRestoreSnapshotUsesLocalSnapshotArtifacts(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + sshListener := listenTestPort(t, int(defaultSSHPort)) + defer func() { _ = sshListener.Close() }() + vncListener := listenTestPort(t, int(defaultVNCPort)) + defer func() { _ = vncListener.Close() }() + + startedAt := time.Unix(1700000199, 0).UTC() + runtime := &fakeRuntime{ + bootState: firecracker.MachineState{ + ID: "restored-local", + Phase: firecracker.PhaseRunning, + PID: 1234, + RuntimeHost: "127.0.0.1", + SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "restored-local", "root", "run", "firecracker.sock"), + TapName: "fctap0", + StartedAt: &startedAt, + }, + } + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + stubGuestSSHPublicKeyReader(hostDaemon) + hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil } + + server := newRestoreArtifactServer(t, map[string][]byte{ + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + }) + defer server.Close() + + artifactRef := contracthost.ArtifactRef{ + KernelImageURL: server.URL + "/kernel", + RootFSURL: server.URL + "/rootfs", + } + artifactDir := filepath.Join(root, "artifact") + if err := os.MkdirAll(artifactDir, 0o755); err != nil { + t.Fatalf("create artifact dir: %v", err) + } + kernelPath := filepath.Join(artifactDir, "vmlinux") + rootFSPath := filepath.Join(artifactDir, "rootfs.ext4") + if err := os.WriteFile(kernelPath, []byte("kernel"), 0o644); err != nil { + t.Fatalf("write kernel: %v", err) + } + if err := os.WriteFile(rootFSPath, []byte("rootfs"), 0o644); err != nil { + t.Fatalf("write rootfs: %v", err) + } + if err := fileStore.PutArtifact(context.Background(), model.ArtifactRecord{ + Ref: artifactRef, + LocalKey: "artifact", + LocalDir: artifactDir, + KernelImagePath: kernelPath, + RootFSPath: rootFSPath, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("put artifact: %v", err) + } + + snapshotDir := filepath.Join(root, "snapshots", "snap-local") + if err := os.MkdirAll(snapshotDir, 0o755); err != nil { + t.Fatalf("create snapshot dir: %v", err) + } + memoryPath := filepath.Join(snapshotDir, "memory.bin") + vmstatePath := filepath.Join(snapshotDir, "vmstate.bin") + systemPath := filepath.Join(snapshotDir, "system.img") + if err := os.WriteFile(memoryPath, []byte("mem"), 0o644); err != nil { + t.Fatalf("write memory: %v", err) + } + if err := os.WriteFile(vmstatePath, []byte("state"), 0o644); err != nil { + t.Fatalf("write vmstate: %v", err) + } + if err := os.WriteFile(systemPath, []byte("disk"), 0o644); err != nil { + t.Fatalf("write system disk: %v", err) + } + if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ + ID: "snap-local", + MachineID: "source", + Artifact: artifactRef, + MemFilePath: memoryPath, + StateFilePath: vmstatePath, + DiskPaths: []string{systemPath}, + Artifacts: []model.SnapshotArtifactRecord{ + {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", LocalPath: memoryPath, SizeBytes: 3}, + {ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", LocalPath: vmstatePath, SizeBytes: 5}, + {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", LocalPath: systemPath, SizeBytes: 4}, + }, + SourceRuntimeHost: "172.16.0.2", + SourceTapDevice: "fctap0", + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create snapshot: %v", err) + } + + response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap-local", contracthost.RestoreSnapshotRequest{ + MachineID: "restored-local", + Artifact: artifactRef, + LocalSnapshot: &contracthost.LocalSnapshotSpec{ + SnapshotID: "snap-local", + }, + GuestConfig: &contracthost.GuestConfig{Hostname: "restored-local-shell"}, + }) + if err != nil { + t.Fatalf("restore snapshot: %v", err) + } + if response.Machine.ID != "restored-local" { + t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID) + } + if runtime.restoreCalls != 1 { + t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls) + } + if runtime.lastLoadSpec.Network == nil { + t.Fatalf("restore boot should preserve local snapshot network") + } + if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" { + t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2") + } + if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" { + t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0") + } +} + func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { root := t.TempDir() cfg := testConfig(root) @@ -565,7 +696,7 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { KernelImageURL: server.URL + "/kernel", RootFSURL: server.URL + "/rootfs", }, - Snapshot: contracthost.DurableSnapshotSpec{ + Snapshot: &contracthost.DurableSnapshotSpec{ SnapshotID: "snap1", MachineID: "source", ImageID: "image-1", @@ -666,7 +797,7 @@ func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) { KernelImageURL: "https://example.com/kernel", RootFSURL: "https://example.com/rootfs", }, - Snapshot: contracthost.DurableSnapshotSpec{ + Snapshot: &contracthost.DurableSnapshotSpec{ SnapshotID: "snap1", MachineID: "source", ImageID: "image-1", diff --git a/internal/daemon/review_regressions_test.go b/internal/daemon/review_regressions_test.go index 3badab2..e49e59b 100644 --- a/internal/daemon/review_regressions_test.go +++ b/internal/daemon/review_regressions_test.go @@ -14,10 +14,10 @@ import ( "testing" "time" - contracthost "github.com/getcompanion-ai/computer-host/contract" "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/model" hoststore "github.com/getcompanion-ai/computer-host/internal/store" + contracthost "github.com/getcompanion-ai/computer-host/contract" ) type blockingPublishedPortStore struct { @@ -509,7 +509,7 @@ func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T KernelImageURL: server.URL + "/kernel", RootFSURL: server.URL + "/rootfs", }, - Snapshot: contracthost.DurableSnapshotSpec{ + Snapshot: &contracthost.DurableSnapshotSpec{ SnapshotID: "snap-exhausted", MachineID: "source", ImageID: "image-1", @@ -708,7 +708,7 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterSuccess(t *testing.T) { KernelImageURL: server.URL + "/kernel", RootFSURL: server.URL + "/rootfs", }, - Snapshot: contracthost.DurableSnapshotSpec{ + Snapshot: &contracthost.DurableSnapshotSpec{ SnapshotID: "snap-clean", MachineID: "source", ImageID: "image-1", @@ -759,7 +759,7 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure(t *testing.T) KernelImageURL: server.URL + "/kernel", RootFSURL: server.URL + "/rootfs", }, - Snapshot: contracthost.DurableSnapshotSpec{ + Snapshot: &contracthost.DurableSnapshotSpec{ SnapshotID: "snap-fail-clean", MachineID: "source", ImageID: "image-1", diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go index f14ffd6..0255d0e 100644 --- a/internal/daemon/snapshot.go +++ b/internal/daemon/snapshot.go @@ -11,12 +11,16 @@ import ( "strings" "time" + "golang.org/x/sync/errgroup" + "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/store" contracthost "github.com/getcompanion-ai/computer-host/contract" ) +const localSnapshotRestoreUnavailablePrefix = "local snapshot restore unavailable" + func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error) { unlock := d.lockMachine(machineID) defer unlock() @@ -193,20 +197,31 @@ func (d *Daemon) UploadSnapshot(ctx context.Context, snapshotID contracthost.Sna response := &contracthost.UploadSnapshotResponse{ Artifacts: make([]contracthost.UploadedSnapshotArtifact, 0, len(req.Artifacts)), } - for _, upload := range req.Artifacts { - artifact, ok := artifactIndex[upload.ArtifactID] - if !ok { - return nil, fmt.Errorf("snapshot %q artifact %q not found", snapshotID, upload.ArtifactID) - } - completedParts, err := uploadSnapshotArtifact(ctx, artifact.LocalPath, upload.Parts) - if err != nil { - return nil, fmt.Errorf("upload snapshot artifact %q: %w", upload.ArtifactID, err) - } - response.Artifacts = append(response.Artifacts, contracthost.UploadedSnapshotArtifact{ - ArtifactID: upload.ArtifactID, - CompletedParts: completedParts, + uploads := make([]contracthost.UploadedSnapshotArtifact, len(req.Artifacts)) + group, groupCtx := errgroup.WithContext(ctx) + for i, upload := range req.Artifacts { + i := i + upload := upload + group.Go(func() error { + artifact, ok := artifactIndex[upload.ArtifactID] + if !ok { + return fmt.Errorf("snapshot %q artifact %q not found", snapshotID, upload.ArtifactID) + } + completedParts, err := uploadSnapshotArtifact(groupCtx, artifact.LocalPath, upload.Parts) + if err != nil { + return fmt.Errorf("upload snapshot artifact %q: %w", upload.ArtifactID, err) + } + uploads[i] = contracthost.UploadedSnapshotArtifact{ + ArtifactID: upload.ArtifactID, + CompletedParts: completedParts, + } + return nil }) } + if err := group.Wait(); err != nil { + return nil, err + } + response.Artifacts = append(response.Artifacts, uploads...) return response, nil } @@ -215,12 +230,18 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn if err := validateMachineID(req.MachineID); err != nil { return nil, err } - if req.Snapshot.SnapshotID != "" && req.Snapshot.SnapshotID != snapshotID { - return nil, fmt.Errorf("snapshot id mismatch: path=%q payload=%q", snapshotID, req.Snapshot.SnapshotID) - } if err := validateArtifactRef(req.Artifact); err != nil { return nil, err } + if req.LocalSnapshot == nil && req.Snapshot == nil { + return nil, fmt.Errorf("restore request must include local_snapshot or snapshot") + } + if req.LocalSnapshot != nil && req.LocalSnapshot.SnapshotID != "" && req.LocalSnapshot.SnapshotID != snapshotID { + return nil, fmt.Errorf("local snapshot id mismatch: path=%q payload=%q", snapshotID, req.LocalSnapshot.SnapshotID) + } + if req.Snapshot != nil && req.Snapshot.SnapshotID != "" && req.Snapshot.SnapshotID != snapshotID { + return nil, fmt.Errorf("snapshot id mismatch: path=%q payload=%q", snapshotID, req.Snapshot.SnapshotID) + } if err := validateGuestConfig(req.GuestConfig); err != nil { return nil, err } @@ -258,30 +279,18 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn if err != nil { return nil, err } - restoreNetwork, err := d.resolveRestoreNetwork(ctx, snapshotID, req.Snapshot) + restoredArtifacts, restoreNetwork, cleanupRestoreArtifacts, err := d.prepareRestoreArtifacts(ctx, snapshotID, req, usedNetworks) if err != nil { clearOperation = true return nil, err } - if networkAllocationInUse(restoreNetwork, usedNetworks) { - clearOperation = true - return nil, fmt.Errorf("restore network for snapshot %q is still in use on this host (runtime_host=%s tap_device=%s)", snapshotID, restoreNetwork.GuestIP(), restoreNetwork.TapName) - } - + defer cleanupRestoreArtifacts() artifact, err := d.ensureArtifact(ctx, req.Artifact) if err != nil { + clearOperation = true return nil, fmt.Errorf("ensure artifact for restore: %w", err) } - stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID)) - restoredArtifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts) - if err != nil { - _ = os.RemoveAll(stagingDir) - clearOperation = true - return nil, fmt.Errorf("download durable snapshot artifacts: %w", err) - } - defer func() { _ = os.RemoveAll(stagingDir) }() - // COW-copy system disk from snapshot to new machine's disk dir. newSystemDiskPath := d.systemVolumePath(req.MachineID) if err := os.MkdirAll(filepath.Dir(newSystemDiskPath), 0o755); err != nil { @@ -515,19 +524,59 @@ func restoredUserDiskIndex(name string) (int, bool) { return index, true } -func (d *Daemon) resolveRestoreNetwork(ctx context.Context, snapshotID contracthost.SnapshotID, spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) { - if network, err := restoreNetworkFromDurableSpec(spec); err == nil { - return network, nil +func (d *Daemon) prepareRestoreArtifacts(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.RestoreSnapshotRequest, usedNetworks []firecracker.NetworkAllocation) (map[string]restoredSnapshotArtifact, firecracker.NetworkAllocation, func(), error) { + if req.LocalSnapshot != nil { + if req.LocalSnapshot.SnapshotID != "" && req.LocalSnapshot.SnapshotID != snapshotID { + return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("local snapshot id mismatch: path=%q payload=%q", snapshotID, req.LocalSnapshot.SnapshotID) + } + snapshot, err := d.store.GetSnapshot(ctx, snapshotID) + if err != nil { + if err == store.ErrNotFound { + return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, "snapshot is not present on this host") + } + return nil, firecracker.NetworkAllocation{}, func() {}, err + } + restoreNetwork, err := restoreNetworkFromSnapshot(snapshot) + if err != nil { + return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error()) + } + if networkAllocationInUse(restoreNetwork, usedNetworks) { + return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, fmt.Sprintf("restore network is still in use on this host (runtime_host=%s tap_device=%s)", restoreNetwork.GuestIP(), restoreNetwork.TapName)) + } + artifacts, err := localSnapshotArtifacts(snapshot) + if err != nil { + return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error()) + } + return artifacts, restoreNetwork, func() {}, nil } - snapshot, err := d.store.GetSnapshot(ctx, snapshotID) - if err == nil { - return restoreNetworkFromSnapshot(snapshot) + if req.Snapshot == nil { + return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("durable snapshot spec is required") } - if err != store.ErrNotFound { - return firecracker.NetworkAllocation{}, err + restoreNetwork, err := restoreNetworkFromDurableSpec(*req.Snapshot) + if err != nil { + snapshot, lookupErr := d.store.GetSnapshot(ctx, snapshotID) + if lookupErr == nil { + restoreNetwork, err = restoreNetworkFromSnapshot(snapshot) + } else if lookupErr != store.ErrNotFound { + return nil, firecracker.NetworkAllocation{}, func() {}, lookupErr + } + if err != nil { + return nil, firecracker.NetworkAllocation{}, func() {}, err + } } - return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot %q is missing restore network metadata", snapshotID) + if networkAllocationInUse(restoreNetwork, usedNetworks) { + return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("restore network for snapshot %q is still in use on this host (runtime_host=%s tap_device=%s)", snapshotID, restoreNetwork.GuestIP(), restoreNetwork.TapName) + } + stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID)) + artifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts) + if err != nil { + _ = os.RemoveAll(stagingDir) + return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("download durable snapshot artifacts: %w", err) + } + return artifacts, restoreNetwork, func() { + _ = os.RemoveAll(stagingDir) + }, nil } func restoreNetworkFromDurableSpec(spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) { @@ -555,6 +604,32 @@ func restoreNetworkFromSnapshot(snap *model.SnapshotRecord) (firecracker.Network return network, nil } +func localSnapshotArtifacts(snapshot *model.SnapshotRecord) (map[string]restoredSnapshotArtifact, error) { + if snapshot == nil { + return nil, fmt.Errorf("snapshot is required") + } + restored := make(map[string]restoredSnapshotArtifact, len(snapshot.Artifacts)) + for _, artifact := range snapshot.Artifacts { + if strings.TrimSpace(artifact.LocalPath) == "" { + return nil, fmt.Errorf("snapshot %q artifact %q is missing a local path", snapshot.ID, artifact.ID) + } + if _, err := os.Stat(artifact.LocalPath); err != nil { + return nil, fmt.Errorf("snapshot %q artifact %q is unavailable at %q: %w", snapshot.ID, artifact.ID, artifact.LocalPath, err) + } + restored[artifact.Name] = restoredSnapshotArtifact{ + Artifact: contracthost.SnapshotArtifact{ + ID: artifact.ID, + Kind: artifact.Kind, + Name: artifact.Name, + SizeBytes: artifact.SizeBytes, + SHA256Hex: artifact.SHA256Hex, + }, + LocalPath: artifact.LocalPath, + } + } + return restored, nil +} + func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecracker.NetworkAllocation) bool { targetTap := strings.TrimSpace(target.TapName) for _, network := range used { @@ -568,6 +643,14 @@ func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecra return false } +func localSnapshotRestoreUnavailable(snapshotID contracthost.SnapshotID, message string) error { + message = strings.TrimSpace(message) + if message == "" { + message = "local restore is unavailable" + } + return fmt.Errorf("%s: snapshot %q %s", localSnapshotRestoreUnavailablePrefix, snapshotID, message) +} + // moveFile copies src to dst then removes src. Works across filesystem boundaries // unlike os.Rename, which is needed when moving files out of /proc//root/. func moveFile(src, dst string) error {