From b5c97aef0789b0fef785c232e2e7d0f7a61d317a Mon Sep 17 00:00:00 2001 From: Hari <73809867+harivansh-afk@users.noreply.github.com> Date: Wed, 8 Apr 2026 22:21:46 -0400 Subject: [PATCH] host api alignment (#7) * feat: add Firecracker API client methods for VM pause/resume and snapshots Add PatchVm, GetVm, PutSnapshotCreate, and PutSnapshotLoad methods to the API client, along with supporting types (VmState, SnapshotCreateParams, SnapshotLoadParams, MemBackend). * feat: add snapshot data layer - contract types, model, store, config Add SnapshotID and snapshot contract types, SnapshotRecord model, store interface CRUD methods with file store implementation, snapshot paths helper, SnapshotsDir config, and directory creation. * feat: add runtime methods for VM pause, resume, snapshot, and restore Implement Pause, Resume, CreateSnapshot, and RestoreBoot on the firecracker Runtime. RestoreBoot launches a jailer, stages snapshot files into the chroot, loads the snapshot, and resumes the VM. * feat: add daemon snapshot create, restore, and reconciliation logic Implement CreateSnapshot (pause, snapshot, COW-copy disk, resume), RestoreSnapshot (COW-copy disk, RestoreBoot, wait for guest), GetSnapshot, ListSnapshots, DeleteSnapshotByID, and crash recovery reconciliation for snapshot and restore operations. * feat: add HTTP endpoints for snapshot create, get, list, delete, restore Wire 5 snapshot routes: POST /machines/{id}/snapshots (create), GET /machines/{id}/snapshots (list), GET /snapshots/{id} (get), DELETE /snapshots/{id} (delete), POST /snapshots/{id}/restore (restore). 
* fix: cross-device rename, restore network, and snapshot cleanup - Replace os.Rename with copy+remove for moving snapshot files out of /proc/<pid>/root/ (cross-device link error on Linux) - Reconfigure network interface after snapshot load so the restored VM uses its own tap device instead of the source VM's - Clean partial snapshot dirs immediately on failure instead of only via reconcile - Reject snapshot requests while a machine operation is already pending * fix: test and modify snapshot runtime * feat: snapshot lifecycle update, align runtime issues between host image and daemon --- contract/snapshots.go | 31 +++ internal/config/config.go | 5 + internal/daemon/daemon.go | 20 +- internal/daemon/daemon_test.go | 243 +++++++++++++++++- internal/daemon/guest_identity.go | 59 +++++ internal/daemon/lifecycle.go | 41 +++ internal/daemon/snapshot.go | 407 ++++++++++++++++++++++++++++++ internal/firecracker/api.go | 63 +++++ internal/firecracker/api_test.go | 54 ++++ internal/firecracker/launch.go | 8 + internal/firecracker/paths.go | 15 ++ internal/firecracker/runtime.go | 149 +++++++++++ internal/firecracker/state.go | 21 ++ internal/httpapi/handlers.go | 80 ++++++ internal/model/types.go | 29 ++- internal/store/file_store.go | 78 +++++- internal/store/store.go | 4 + 17 files changed, 1287 insertions(+), 20 deletions(-) create mode 100644 contract/snapshots.go create mode 100644 internal/daemon/guest_identity.go create mode 100644 internal/daemon/snapshot.go create mode 100644 internal/firecracker/api_test.go diff --git a/contract/snapshots.go b/contract/snapshots.go new file mode 100644 index 0000000..96b6d5c --- /dev/null +++ b/contract/snapshots.go @@ -0,0 +1,31 @@ +package host + +import "time" + +type SnapshotID string + +type Snapshot struct { + ID SnapshotID `json:"id"` + MachineID MachineID `json:"machine_id"` + CreatedAt time.Time `json:"created_at"` +} + +type CreateSnapshotResponse struct { + Snapshot Snapshot `json:"snapshot"` +} + +type GetSnapshotResponse 
struct { + Snapshot Snapshot `json:"snapshot"` +} + +type ListSnapshotsResponse struct { + Snapshots []Snapshot `json:"snapshots"` +} + +type RestoreSnapshotRequest struct { + MachineID MachineID `json:"machine_id"` +} + +type RestoreSnapshotResponse struct { + Machine Machine `json:"machine"` +} diff --git a/internal/config/config.go b/internal/config/config.go index e353293..b5a2411 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,6 +18,7 @@ type Config struct { OperationsPath string ArtifactsDir string MachineDisksDir string + SnapshotsDir string RuntimeDir string SocketPath string EgressInterface string @@ -34,6 +35,7 @@ func Load() (Config, error) { OperationsPath: filepath.Join(rootDir, "state", "ops.json"), ArtifactsDir: filepath.Join(rootDir, "artifacts"), MachineDisksDir: filepath.Join(rootDir, "machine-disks"), + SnapshotsDir: filepath.Join(rootDir, "snapshots"), RuntimeDir: filepath.Join(rootDir, "runtime"), SocketPath: filepath.Join(rootDir, defaultSocketName), EgressInterface: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_EGRESS_INTERFACE")), @@ -69,6 +71,9 @@ func (c Config) Validate() error { if strings.TrimSpace(c.MachineDisksDir) == "" { return fmt.Errorf("machine disks dir is required") } + if strings.TrimSpace(c.SnapshotsDir) == "" { + return fmt.Errorf("snapshots dir is required") + } if strings.TrimSpace(c.RuntimeDir) == "" { return fmt.Errorf("runtime dir is required") } diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index f955306..46dd9e5 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -29,6 +29,10 @@ type Runtime interface { Boot(context.Context, firecracker.MachineSpec, []firecracker.NetworkAllocation) (*firecracker.MachineState, error) Inspect(firecracker.MachineState) (*firecracker.MachineState, error) Delete(context.Context, firecracker.MachineState) error + Pause(context.Context, firecracker.MachineState) error + Resume(context.Context, firecracker.MachineState) 
error + CreateSnapshot(context.Context, firecracker.MachineState, firecracker.SnapshotPaths) error + RestoreBoot(context.Context, firecracker.SnapshotLoadSpec, []firecracker.NetworkAllocation) (*firecracker.MachineState, error) } type Daemon struct { @@ -36,6 +40,8 @@ type Daemon struct { store store.Store runtime Runtime + reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID) error + locksMu sync.Mutex machineLocks map[contracthost.MachineID]*sync.Mutex artifactLocks map[string]*sync.Mutex @@ -51,18 +57,20 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err if runtime == nil { return nil, fmt.Errorf("runtime is required") } - for _, dir := range []string{cfg.ArtifactsDir, cfg.MachineDisksDir, cfg.RuntimeDir} { + for _, dir := range []string{cfg.ArtifactsDir, cfg.MachineDisksDir, cfg.SnapshotsDir, cfg.RuntimeDir} { if err := os.MkdirAll(dir, 0o755); err != nil { return nil, fmt.Errorf("create daemon dir %q: %w", dir, err) } } daemon := &Daemon{ - config: cfg, - store: store, - runtime: runtime, - machineLocks: make(map[contracthost.MachineID]*sync.Mutex), - artifactLocks: make(map[string]*sync.Mutex), + config: cfg, + store: store, + runtime: runtime, + reconfigureGuestIdentity: nil, + machineLocks: make(map[contracthost.MachineID]*sync.Mutex), + artifactLocks: make(map[string]*sync.Mutex), } + daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH if err := daemon.ensureBackendSSHKeyPair(); err != nil { return nil, err } diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 7d60f36..e347b69 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -15,15 +15,18 @@ import ( appconfig "github.com/getcompanion-ai/computer-host/internal/config" "github.com/getcompanion-ai/computer-host/internal/firecracker" + "github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/store" contracthost 
"github.com/getcompanion-ai/computer-host/contract" ) type fakeRuntime struct { - bootState firecracker.MachineState - bootCalls int - deleteCalls []firecracker.MachineState - lastSpec firecracker.MachineSpec + bootState firecracker.MachineState + bootCalls int + restoreCalls int + deleteCalls []firecracker.MachineState + lastSpec firecracker.MachineSpec + lastLoadSpec firecracker.SnapshotLoadSpec } func (f *fakeRuntime) Boot(_ context.Context, spec firecracker.MachineSpec, _ []firecracker.NetworkAllocation) (*firecracker.MachineState, error) { @@ -43,6 +46,24 @@ func (f *fakeRuntime) Delete(_ context.Context, state firecracker.MachineState) return nil } +func (f *fakeRuntime) Pause(_ context.Context, _ firecracker.MachineState) error { + return nil +} + +func (f *fakeRuntime) Resume(_ context.Context, _ firecracker.MachineState) error { + return nil +} + +func (f *fakeRuntime) CreateSnapshot(_ context.Context, _ firecracker.MachineState, _ firecracker.SnapshotPaths) error { + return nil +} + +func (f *fakeRuntime) RestoreBoot(_ context.Context, spec firecracker.SnapshotLoadSpec, _ []firecracker.NetworkAllocation) (*firecracker.MachineState, error) { + f.restoreCalls++ + f.lastLoadSpec = spec + return &f.bootState, nil +} + func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { root := t.TempDir() cfg := testConfig(root) @@ -223,6 +244,219 @@ func TestNewEnsuresBackendSSHKeyPair(t *testing.T) { } } +func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + runtime := &fakeRuntime{} + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID) error { return nil } + + artifactRef := 
contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"} + kernelPath := filepath.Join(root, "artifact-kernel") + if err := os.WriteFile(kernelPath, []byte("kernel"), 0o644); err != nil { + t.Fatalf("write kernel: %v", err) + } + if err := fileStore.PutArtifact(context.Background(), model.ArtifactRecord{ + Ref: artifactRef, + LocalKey: "artifact", + LocalDir: filepath.Join(root, "artifact"), + KernelImagePath: kernelPath, + RootFSPath: filepath.Join(root, "artifact-rootfs"), + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("put artifact: %v", err) + } + + if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{ + ID: "source", + Artifact: artifactRef, + SystemVolumeID: "source-system", + RuntimeHost: "172.16.0.2", + TapDevice: "fctap0", + Phase: contracthost.MachinePhaseRunning, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create source machine: %v", err) + } + + snapDisk := filepath.Join(root, "snapshots", "snap1", "system.img") + if err := os.MkdirAll(filepath.Dir(snapDisk), 0o755); err != nil { + t.Fatalf("create snapshot dir: %v", err) + } + if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil { + t.Fatalf("write snapshot disk: %v", err) + } + if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ + ID: "snap1", + MachineID: "source", + Artifact: artifactRef, + MemFilePath: filepath.Join(root, "snapshots", "snap1", "memory.bin"), + StateFilePath: filepath.Join(root, "snapshots", "snap1", "vmstate.bin"), + DiskPaths: []string{snapDisk}, + SourceRuntimeHost: "172.16.0.2", + SourceTapDevice: "fctap0", + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create snapshot: %v", err) + } + + _, err = hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{ + MachineID: "restored", + }) + if err == nil { + t.Fatal("expected restore rejection while source is running") + } + if !strings.Contains(err.Error(), `source machine 
"source" is running`) { + t.Fatalf("unexpected restore error: %v", err) + } + if runtime.restoreCalls != 0 { + t.Fatalf("restore boot should not run when source machine is still running: got %d", runtime.restoreCalls) + } + + ops, err := fileStore.ListOperations(context.Background()) + if err != nil { + t.Fatalf("list operations: %v", err) + } + if len(ops) != 0 { + t.Fatalf("operation journal should be empty after handled restore rejection: got %d entries", len(ops)) + } +} + +func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + sshListener := listenTestPort(t, int(defaultSSHPort)) + defer sshListener.Close() + vncListener := listenTestPort(t, int(defaultVNCPort)) + defer vncListener.Close() + + startedAt := time.Unix(1700000099, 0).UTC() + runtime := &fakeRuntime{ + bootState: firecracker.MachineState{ + ID: "restored", + Phase: firecracker.PhaseRunning, + PID: 1234, + RuntimeHost: "127.0.0.1", + SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "restored", "root", "run", "firecracker.sock"), + TapName: "fctap0", + StartedAt: &startedAt, + }, + } + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + var reconfiguredHost string + var reconfiguredMachine contracthost.MachineID + hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID) error { + reconfiguredHost = host + reconfiguredMachine = machineID + return nil + } + + artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"} + kernelPath := filepath.Join(root, "artifact-kernel") + rootFSPath := filepath.Join(root, "artifact-rootfs") + if err := os.WriteFile(kernelPath, []byte("kernel"), 0o644); err != nil { + t.Fatalf("write kernel: %v", err) + } + if err := 
os.WriteFile(rootFSPath, []byte("rootfs"), 0o644); err != nil { + t.Fatalf("write rootfs: %v", err) + } + if err := fileStore.PutArtifact(context.Background(), model.ArtifactRecord{ + Ref: artifactRef, + LocalKey: "artifact", + LocalDir: filepath.Join(root, "artifact"), + KernelImagePath: kernelPath, + RootFSPath: rootFSPath, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("put artifact: %v", err) + } + + snapDir := filepath.Join(root, "snapshots", "snap1") + if err := os.MkdirAll(snapDir, 0o755); err != nil { + t.Fatalf("create snapshot dir: %v", err) + } + snapDisk := filepath.Join(snapDir, "system.img") + if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil { + t.Fatalf("write snapshot disk: %v", err) + } + if err := os.WriteFile(filepath.Join(snapDir, "memory.bin"), []byte("mem"), 0o644); err != nil { + t.Fatalf("write memory snapshot: %v", err) + } + if err := os.WriteFile(filepath.Join(snapDir, "vmstate.bin"), []byte("state"), 0o644); err != nil { + t.Fatalf("write vmstate snapshot: %v", err) + } + if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ + ID: "snap1", + MachineID: "source", + Artifact: artifactRef, + MemFilePath: filepath.Join(snapDir, "memory.bin"), + StateFilePath: filepath.Join(snapDir, "vmstate.bin"), + DiskPaths: []string{snapDisk}, + SourceRuntimeHost: "172.16.0.2", + SourceTapDevice: "fctap0", + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create snapshot: %v", err) + } + + response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{ + MachineID: "restored", + }) + if err != nil { + t.Fatalf("restore snapshot: %v", err) + } + if response.Machine.ID != "restored" { + t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID) + } + if runtime.restoreCalls != 1 { + t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls) + } + if runtime.lastLoadSpec.Network == nil { + t.Fatal("restore 
boot did not receive snapshot network") + } + if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" { + t.Fatalf("restored guest network mismatch: got %q want %q", got, "172.16.0.2") + } + if runtime.lastLoadSpec.KernelImagePath != kernelPath { + t.Fatalf("restore boot kernel path mismatch: got %q want %q", runtime.lastLoadSpec.KernelImagePath, kernelPath) + } + if reconfiguredHost != "127.0.0.1" || reconfiguredMachine != "restored" { + t.Fatalf("guest identity reconfigure mismatch: host=%q machine=%q", reconfiguredHost, reconfiguredMachine) + } + + machine, err := fileStore.GetMachine(context.Background(), "restored") + if err != nil { + t.Fatalf("get restored machine: %v", err) + } + if machine.Phase != contracthost.MachinePhaseRunning { + t.Fatalf("restored machine phase mismatch: got %q", machine.Phase) + } + + ops, err := fileStore.ListOperations(context.Background()) + if err != nil { + t.Fatalf("list operations: %v", err) + } + if len(ops) != 0 { + t.Fatalf("operation journal should be empty after successful restore: got %d entries", len(ops)) + } +} + func TestCreateMachineRejectsNonHTTPArtifactURLs(t *testing.T) { t.Parallel() @@ -282,6 +516,7 @@ func testConfig(root string) appconfig.Config { OperationsPath: filepath.Join(root, "state", "ops.json"), ArtifactsDir: filepath.Join(root, "artifacts"), MachineDisksDir: filepath.Join(root, "machine-disks"), + SnapshotsDir: filepath.Join(root, "snapshots"), RuntimeDir: filepath.Join(root, "runtime"), SocketPath: filepath.Join(root, "firecracker-host.sock"), EgressInterface: "eth0", diff --git a/internal/daemon/guest_identity.go b/internal/daemon/guest_identity.go new file mode 100644 index 0000000..bcd6afe --- /dev/null +++ b/internal/daemon/guest_identity.go @@ -0,0 +1,59 @@ +package daemon + +import ( + "context" + "fmt" + "os/exec" + "strconv" + "strings" + + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +func (d *Daemon) reconfigureGuestIdentityOverSSH(ctx 
context.Context, runtimeHost string, machineID contracthost.MachineID) error { + runtimeHost = strings.TrimSpace(runtimeHost) + machineName := strings.TrimSpace(string(machineID)) + if runtimeHost == "" { + return fmt.Errorf("guest runtime host is required") + } + if machineName == "" { + return fmt.Errorf("machine id is required") + } + + privateKeyPath := d.backendSSHPrivateKeyPath() + remoteScript := fmt.Sprintf(`set -euo pipefail +machine_name=%s +printf '%%s\n' "$machine_name" >/etc/microagent/machine-name +printf '%%s\n' "$machine_name" >/etc/hostname +cat >/etc/hosts </dev/null 2>&1 || true +`, strconv.Quote(machineName)) + + cmd := exec.CommandContext( + ctx, + "ssh", + "-i", privateKeyPath, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "IdentitiesOnly=yes", + "-o", "BatchMode=yes", + "-p", strconv.Itoa(int(defaultSSHPort)), + "node@"+runtimeHost, + "sudo bash -lc "+shellSingleQuote(remoteScript), + ) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("reconfigure guest identity over ssh: %w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func shellSingleQuote(value string) string { + return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'" +} diff --git a/internal/daemon/lifecycle.go b/internal/daemon/lifecycle.go index c444934..580d0d6 100644 --- a/internal/daemon/lifecycle.go +++ b/internal/daemon/lifecycle.go @@ -128,6 +128,14 @@ func (d *Daemon) Reconcile(ctx context.Context) error { if err := d.reconcileDelete(ctx, operation.MachineID); err != nil { return err } + case model.MachineOperationSnapshot: + if err := d.reconcileSnapshot(ctx, operation); err != nil { + return err + } + case model.MachineOperationRestore: + if err := d.reconcileRestore(ctx, operation); err != nil { + return err + } default: return fmt.Errorf("unsupported operation type %q", operation.Type) } @@ -325,3 +333,36 @@ func (d *Daemon) detachVolumesForMachine(ctx context.Context, machineID 
contract } return nil } + +func (d *Daemon) reconcileSnapshot(ctx context.Context, operation model.OperationRecord) error { + if operation.SnapshotID == nil { + return d.store.DeleteOperation(ctx, operation.MachineID) + } + _, err := d.store.GetSnapshot(ctx, *operation.SnapshotID) + if err == nil { + // Snapshot completed successfully, just clear the journal + return d.store.DeleteOperation(ctx, operation.MachineID) + } + // Snapshot did not complete: clean up partial snapshot directory and resume the machine + snapshotDir := filepath.Join(d.config.SnapshotsDir, string(*operation.SnapshotID)) + _ = os.RemoveAll(snapshotDir) + + // Try to resume the source machine in case it was left paused + record, err := d.store.GetMachine(ctx, operation.MachineID) + if err == nil && record.Phase == contracthost.MachinePhaseRunning && record.PID > 0 { + _ = d.runtime.Resume(ctx, machineToRuntimeState(*record)) + } + return d.store.DeleteOperation(ctx, operation.MachineID) +} + +func (d *Daemon) reconcileRestore(ctx context.Context, operation model.OperationRecord) error { + _, err := d.store.GetMachine(ctx, operation.MachineID) + if err == nil { + // Restore completed, clear journal + return d.store.DeleteOperation(ctx, operation.MachineID) + } + // Restore did not complete: clean up partial machine directory and disk + _ = os.RemoveAll(filepath.Dir(d.systemVolumePath(operation.MachineID))) + _ = os.RemoveAll(d.machineRuntimeBaseDir(operation.MachineID)) + return d.store.DeleteOperation(ctx, operation.MachineID) +} diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go new file mode 100644 index 0000000..a9da445 --- /dev/null +++ b/internal/daemon/snapshot.go @@ -0,0 +1,407 @@ +package daemon + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/getcompanion-ai/computer-host/internal/firecracker" + "github.com/getcompanion-ai/computer-host/internal/model" + 
"github.com/getcompanion-ai/computer-host/internal/store" + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID) (*contracthost.CreateSnapshotResponse, error) { + unlock := d.lockMachine(machineID) + defer unlock() + + record, err := d.store.GetMachine(ctx, machineID) + if err != nil { + return nil, err + } + if record.Phase != contracthost.MachinePhaseRunning { + return nil, fmt.Errorf("machine %q is not running", machineID) + } + + // Reject if an operation is already pending for this machine + if ops, err := d.store.ListOperations(ctx); err == nil { + for _, op := range ops { + if op.MachineID == machineID { + return nil, fmt.Errorf("machine %q has a pending %q operation (started %s)", machineID, op.Type, op.StartedAt.Format(time.RFC3339)) + } + } + } + + snapshotID := contracthost.SnapshotID(generateID()) + + if err := d.store.UpsertOperation(ctx, model.OperationRecord{ + MachineID: machineID, + Type: model.MachineOperationSnapshot, + StartedAt: time.Now().UTC(), + SnapshotID: &snapshotID, + }); err != nil { + return nil, err + } + + clearOperation := false + defer func() { + if clearOperation { + _ = d.store.DeleteOperation(context.Background(), machineID) + } + }() + + snapshotDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID)) + if err := os.MkdirAll(snapshotDir, 0o755); err != nil { + return nil, fmt.Errorf("create snapshot dir: %w", err) + } + + runtimeState := machineToRuntimeState(*record) + + // Pause the VM + if err := d.runtime.Pause(ctx, runtimeState); err != nil { + return nil, fmt.Errorf("pause machine %q: %w", machineID, err) + } + + // Write snapshot inside the chroot (Firecracker can only write there) + // Use jailed paths relative to the chroot root + chrootMemPath := "memory.bin" + chrootStatePath := "vmstate.bin" + + if err := d.runtime.CreateSnapshot(ctx, runtimeState, firecracker.SnapshotPaths{ + MemFilePath: chrootMemPath, + 
StateFilePath: chrootStatePath, + }); err != nil { + _ = d.runtime.Resume(ctx, runtimeState) + _ = os.RemoveAll(snapshotDir) + return nil, fmt.Errorf("create snapshot for %q: %w", machineID, err) + } + + // COW-copy disk files while paused for consistency + var diskPaths []string + systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID) + if err != nil { + _ = d.runtime.Resume(ctx, runtimeState) + _ = os.RemoveAll(snapshotDir) + return nil, fmt.Errorf("get system volume: %w", err) + } + systemDiskTarget := filepath.Join(snapshotDir, "system.img") + if err := cowCopyFile(systemVolume.Path, systemDiskTarget); err != nil { + _ = d.runtime.Resume(ctx, runtimeState) + _ = os.RemoveAll(snapshotDir) + return nil, fmt.Errorf("copy system disk: %w", err) + } + diskPaths = append(diskPaths, systemDiskTarget) + + // Resume the source VM + if err := d.runtime.Resume(ctx, runtimeState); err != nil { + _ = os.RemoveAll(snapshotDir) + return nil, fmt.Errorf("resume machine %q: %w", machineID, err) + } + + // Copy snapshot files from chroot to snapshot directory, then remove originals. + // os.Rename fails across filesystem boundaries (/proc/<pid>/root/ is on procfs). 
+ chrootRoot := filepath.Dir(filepath.Dir(runtimeState.SocketPath)) // strip /run/firecracker.socket + srcMemPath := filepath.Join(chrootRoot, chrootMemPath) + srcStatePath := filepath.Join(chrootRoot, chrootStatePath) + dstMemPath := filepath.Join(snapshotDir, "memory.bin") + dstStatePath := filepath.Join(snapshotDir, "vmstate.bin") + + if err := moveFile(srcMemPath, dstMemPath); err != nil { + _ = os.RemoveAll(snapshotDir) + return nil, fmt.Errorf("move memory file: %w", err) + } + if err := moveFile(srcStatePath, dstStatePath); err != nil { + _ = os.RemoveAll(snapshotDir) + return nil, fmt.Errorf("move vmstate file: %w", err) + } + + now := time.Now().UTC() + snapshotRecord := model.SnapshotRecord{ + ID: snapshotID, + MachineID: machineID, + Artifact: record.Artifact, + MemFilePath: dstMemPath, + StateFilePath: dstStatePath, + DiskPaths: diskPaths, + SourceRuntimeHost: record.RuntimeHost, + SourceTapDevice: record.TapDevice, + CreatedAt: now, + } + if err := d.store.CreateSnapshot(ctx, snapshotRecord); err != nil { + _ = os.RemoveAll(snapshotDir) + return nil, err + } + + clearOperation = true + return &contracthost.CreateSnapshotResponse{ + Snapshot: snapshotToContract(snapshotRecord), + }, nil +} + +func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.RestoreSnapshotRequest) (*contracthost.RestoreSnapshotResponse, error) { + if err := validateMachineID(req.MachineID); err != nil { + return nil, err + } + + unlock := d.lockMachine(req.MachineID) + defer unlock() + + snap, err := d.store.GetSnapshot(ctx, snapshotID) + if err != nil { + return nil, err + } + + if _, err := d.store.GetMachine(ctx, req.MachineID); err == nil { + return nil, fmt.Errorf("machine %q already exists", req.MachineID) + } + + if err := d.store.UpsertOperation(ctx, model.OperationRecord{ + MachineID: req.MachineID, + Type: model.MachineOperationRestore, + StartedAt: time.Now().UTC(), + SnapshotID: &snapshotID, + }); err != nil { + 
return nil, err + } + + clearOperation := false + defer func() { + if clearOperation { + _ = d.store.DeleteOperation(context.Background(), req.MachineID) + } + }() + + sourceMachine, err := d.store.GetMachine(ctx, snap.MachineID) + switch { + case err == nil && sourceMachine.Phase == contracthost.MachinePhaseRunning: + clearOperation = true + return nil, fmt.Errorf("restore from snapshot %q while source machine %q is running is not supported yet", snapshotID, snap.MachineID) + case err != nil && err != store.ErrNotFound: + return nil, fmt.Errorf("get source machine for restore: %w", err) + } + + usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID) + if err != nil { + return nil, err + } + restoreNetwork, err := restoreNetworkFromSnapshot(snap) + if err != nil { + clearOperation = true + return nil, err + } + if networkAllocationInUse(restoreNetwork, usedNetworks) { + clearOperation = true + return nil, fmt.Errorf("snapshot %q restore network %q (%s) is already in use", snapshotID, restoreNetwork.TapName, restoreNetwork.GuestIP()) + } + + artifact, err := d.store.GetArtifact(ctx, snap.Artifact) + if err != nil { + return nil, fmt.Errorf("get artifact for restore: %w", err) + } + + // COW-copy system disk from snapshot to new machine's disk dir. 
+ newSystemDiskPath := d.systemVolumePath(req.MachineID) + if err := os.MkdirAll(filepath.Dir(newSystemDiskPath), 0o755); err != nil { + return nil, fmt.Errorf("create machine disk dir: %w", err) + } + if len(snap.DiskPaths) < 1 { + clearOperation = true + return nil, fmt.Errorf("snapshot %q has no disk paths", snapshotID) + } + if err := cowCopyFile(snap.DiskPaths[0], newSystemDiskPath); err != nil { + clearOperation = true + return nil, fmt.Errorf("copy system disk for restore: %w", err) + } + + loadSpec := firecracker.SnapshotLoadSpec{ + ID: firecracker.MachineID(req.MachineID), + SnapshotPath: snap.StateFilePath, + MemFilePath: snap.MemFilePath, + RootFSPath: newSystemDiskPath, + KernelImagePath: artifact.KernelImagePath, + DiskPaths: map[string]string{}, + Network: &restoreNetwork, + } + + machineState, err := d.runtime.RestoreBoot(ctx, loadSpec, usedNetworks) + if err != nil { + _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) + clearOperation = true + return nil, fmt.Errorf("restore boot: %w", err) + } + + // Wait for guest to become ready + if err := waitForGuestReady(ctx, machineState.RuntimeHost, defaultMachinePorts()); err != nil { + _ = d.runtime.Delete(ctx, *machineState) + _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) + clearOperation = true + return nil, fmt.Errorf("wait for restored guest ready: %w", err) + } + if err := d.reconfigureGuestIdentity(ctx, machineState.RuntimeHost, req.MachineID); err != nil { + _ = d.runtime.Delete(ctx, *machineState) + _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) + clearOperation = true + return nil, fmt.Errorf("reconfigure restored guest identity: %w", err) + } + + systemVolumeID := d.systemVolumeID(req.MachineID) + now := time.Now().UTC() + + if err := d.store.CreateVolume(ctx, model.VolumeRecord{ + ID: systemVolumeID, + Kind: contracthost.VolumeKindSystem, + AttachedMachineID: machineIDPtr(req.MachineID), + SourceArtifact: &snap.Artifact, + Pool: model.StoragePoolMachineDisks, + Path: newSystemDiskPath, + 
CreatedAt: now, + }); err != nil { + return nil, err + } + + machineRecord := model.MachineRecord{ + ID: req.MachineID, + Artifact: snap.Artifact, + SystemVolumeID: systemVolumeID, + RuntimeHost: machineState.RuntimeHost, + TapDevice: machineState.TapName, + Ports: defaultMachinePorts(), + Phase: contracthost.MachinePhaseRunning, + PID: machineState.PID, + SocketPath: machineState.SocketPath, + CreatedAt: now, + StartedAt: machineState.StartedAt, + } + if err := d.store.CreateMachine(ctx, machineRecord); err != nil { + return nil, err + } + + clearOperation = true + return &contracthost.RestoreSnapshotResponse{ + Machine: machineToContract(machineRecord), + }, nil +} + +func (d *Daemon) GetSnapshot(ctx context.Context, snapshotID contracthost.SnapshotID) (*contracthost.GetSnapshotResponse, error) { + snap, err := d.store.GetSnapshot(ctx, snapshotID) + if err != nil { + return nil, err + } + return &contracthost.GetSnapshotResponse{Snapshot: snapshotToContract(*snap)}, nil +} + +func (d *Daemon) ListSnapshots(ctx context.Context, machineID contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error) { + records, err := d.store.ListSnapshotsByMachine(ctx, machineID) + if err != nil { + return nil, err + } + snapshots := make([]contracthost.Snapshot, 0, len(records)) + for _, r := range records { + snapshots = append(snapshots, snapshotToContract(r)) + } + return &contracthost.ListSnapshotsResponse{Snapshots: snapshots}, nil +} + +func (d *Daemon) DeleteSnapshotByID(ctx context.Context, snapshotID contracthost.SnapshotID) error { + snap, err := d.store.GetSnapshot(ctx, snapshotID) + if err != nil { + return err + } + snapshotDir := filepath.Dir(snap.MemFilePath) + if err := os.RemoveAll(snapshotDir); err != nil { + return fmt.Errorf("remove snapshot dir %q: %w", snapshotDir, err) + } + return d.store.DeleteSnapshot(ctx, snapshotID) +} + +func snapshotToContract(record model.SnapshotRecord) contracthost.Snapshot { + return contracthost.Snapshot{ + ID: 
record.ID, + MachineID: record.MachineID, + CreatedAt: record.CreatedAt, + } +} + +func restoreNetworkFromSnapshot(snap *model.SnapshotRecord) (firecracker.NetworkAllocation, error) { + if snap == nil { + return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot is required") + } + if strings.TrimSpace(snap.SourceRuntimeHost) == "" || strings.TrimSpace(snap.SourceTapDevice) == "" { + return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot %q is missing restore network metadata", snap.ID) + } + network, err := firecracker.AllocationFromGuestIP(snap.SourceRuntimeHost, snap.SourceTapDevice) + if err != nil { + return firecracker.NetworkAllocation{}, fmt.Errorf("reconstruct snapshot %q network: %w", snap.ID, err) + } + return network, nil +} + +func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecracker.NetworkAllocation) bool { + targetTap := strings.TrimSpace(target.TapName) + for _, network := range used { + if network.GuestIP() == target.GuestIP() { + return true + } + if targetTap != "" && strings.TrimSpace(network.TapName) == targetTap { + return true + } + } + return false +} + +func generateID() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + panic(fmt.Sprintf("generate id: %v", err)) + } + return hex.EncodeToString(b) +} + +// moveFile copies src to dst then removes src. Works across filesystem boundaries +// unlike os.Rename, which is needed when moving files out of /proc/<pid>/root/. 
+func moveFile(src, dst string) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + + out, err := os.Create(dst) + if err != nil { + return err + } + + if _, err := io.Copy(out, in); err != nil { + out.Close() + _ = os.Remove(dst) + return err + } + if err := out.Close(); err != nil { + _ = os.Remove(dst) + return err + } + return os.Remove(src) +} + +func cowCopyFile(source string, target string) error { + if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { + return fmt.Errorf("create target dir for %q: %w", target, err) + } + cmd := exec.Command("cp", "--reflink=auto", "--sparse=always", source, target) + output, err := cmd.CombinedOutput() + if err != nil { + if cloneErr := cloneFile(source, target); cloneErr == nil { + return nil + } else { + return fmt.Errorf("cow copy %q to %q: cp failed: %w: %s; clone fallback failed: %w", source, target, err, string(output), cloneErr) + } + } + return nil +} diff --git a/internal/firecracker/api.go b/internal/firecracker/api.go index 90fb593..14b7674 100644 --- a/internal/firecracker/api.go +++ b/internal/firecracker/api.go @@ -146,6 +146,69 @@ func (c *apiClient) PutVsock(ctx context.Context, spec VsockSpec) error { return c.do(ctx, http.MethodPut, "/vsock", body, nil, http.StatusNoContent) } +type VmState string + +const ( + VmStatePaused VmState = "Paused" + VmStateResumed VmState = "Resumed" +) + +type vmRequest struct { + State VmState `json:"state"` +} + +type vmResponse struct { + State string `json:"state"` +} + +type SnapshotCreateParams struct { + MemFilePath string `json:"mem_file_path"` + SnapshotPath string `json:"snapshot_path"` + SnapshotType string `json:"snapshot_type"` +} + +type SnapshotLoadParams struct { + SnapshotPath string `json:"snapshot_path"` + MemBackend *MemBackend `json:"mem_backend,omitempty"` + ResumeVm bool `json:"resume_vm"` + NetworkOverrides []NetworkOverride `json:"network_overrides,omitempty"` + VsockOverride *VsockOverride 
`json:"vsock_override,omitempty"` +} + +type MemBackend struct { + BackendType string `json:"backend_type"` + BackendPath string `json:"backend_path"` +} + +type NetworkOverride struct { + IfaceID string `json:"iface_id"` + HostDevName string `json:"host_dev_name"` +} + +type VsockOverride struct { + UDSPath string `json:"uds_path"` +} + +func (c *apiClient) PatchVm(ctx context.Context, state VmState) error { + return c.do(ctx, http.MethodPatch, "/vm", vmRequest{State: state}, nil, http.StatusNoContent) +} + +func (c *apiClient) GetVm(ctx context.Context) (*vmResponse, error) { + var response vmResponse + if err := c.do(ctx, http.MethodGet, "/vm", nil, &response, http.StatusOK); err != nil { + return nil, err + } + return &response, nil +} + +func (c *apiClient) PutSnapshotCreate(ctx context.Context, params SnapshotCreateParams) error { + return c.do(ctx, http.MethodPut, "/snapshot/create", params, nil, http.StatusNoContent) +} + +func (c *apiClient) PutSnapshotLoad(ctx context.Context, params SnapshotLoadParams) error { + return c.do(ctx, http.MethodPut, "/snapshot/load", params, nil, http.StatusNoContent) +} + func (c *apiClient) do(ctx context.Context, method string, endpoint string, input any, output any, wantStatus int) error { var body io.Reader if input != nil { diff --git a/internal/firecracker/api_test.go b/internal/firecracker/api_test.go new file mode 100644 index 0000000..e336bc2 --- /dev/null +++ b/internal/firecracker/api_test.go @@ -0,0 +1,54 @@ +package firecracker + +import ( + "context" + "io" + "net/http" + "testing" +) + +func TestPutSnapshotLoadIncludesNetworkOverrides(t *testing.T) { + var ( + gotPath string + gotBody string + ) + + socketPath, shutdown := startUnixSocketServer(t, func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read request body: %v", err) + } + gotPath = r.URL.Path + gotBody = string(body) + w.WriteHeader(http.StatusNoContent) + }) + defer shutdown() + + client := 
newAPIClient(socketPath) + err := client.PutSnapshotLoad(context.Background(), SnapshotLoadParams{ + SnapshotPath: "vmstate.bin", + MemBackend: &MemBackend{ + BackendType: "File", + BackendPath: "memory.bin", + }, + ResumeVm: false, + NetworkOverrides: []NetworkOverride{ + { + IfaceID: "net0", + HostDevName: "fctap7", + }, + }, + }) + if err != nil { + t.Fatalf("put snapshot load: %v", err) + } + + if gotPath != "/snapshot/load" { + t.Fatalf("request path mismatch: got %q want %q", gotPath, "/snapshot/load") + } + + want := "{\"snapshot_path\":\"vmstate.bin\",\"mem_backend\":{\"backend_type\":\"File\",\"backend_path\":\"memory.bin\"},\"resume_vm\":false,\"network_overrides\":[{\"iface_id\":\"net0\",\"host_dev_name\":\"fctap7\"}]}" + if gotBody != want { + t.Fatalf("request body mismatch:\n got: %s\nwant: %s", gotBody, want) + } +} diff --git a/internal/firecracker/launch.go b/internal/firecracker/launch.go index 641d78f..d682cd2 100644 --- a/internal/firecracker/launch.go +++ b/internal/firecracker/launch.go @@ -272,3 +272,11 @@ func stagedFileName(filePath string) (string, error) { } return name, nil } + +func stageSnapshotFile(sourcePath string, chrootRootDir string, name string) (string, error) { + target := filepath.Join(chrootRootDir, name) + if err := linkMachineFile(sourcePath, target); err != nil { + return "", err + } + return name, nil +} diff --git a/internal/firecracker/paths.go b/internal/firecracker/paths.go index 65119cb..0f02261 100644 --- a/internal/firecracker/paths.go +++ b/internal/firecracker/paths.go @@ -67,3 +67,18 @@ func buildMachinePaths(rootDir string, id MachineID, firecrackerBinaryPath strin func procSocketPath(pid int) string { return filepath.Join("/proc", strconv.Itoa(pid), "root", defaultFirecrackerSocketDir, defaultFirecrackerSocketName) } + +type snapshotPaths struct { + BaseDir string + MemFilePath string + StateFilePath string +} + +func buildSnapshotPaths(rootDir string, id string) snapshotPaths { + baseDir := 
filepath.Join(rootDir, "snapshots", id) + return snapshotPaths{ + BaseDir: baseDir, + MemFilePath: filepath.Join(baseDir, "memory.bin"), + StateFilePath: filepath.Join(baseDir, "vmstate.bin"), + } +} diff --git a/internal/firecracker/runtime.go b/internal/firecracker/runtime.go index 9c78c82..a4d3c86 100644 --- a/internal/firecracker/runtime.go +++ b/internal/firecracker/runtime.go @@ -220,6 +220,155 @@ func (r *Runtime) Delete(ctx context.Context, state MachineState) error { return nil } +func (r *Runtime) Pause(ctx context.Context, state MachineState) error { + client := newAPIClient(state.SocketPath) + return client.PatchVm(ctx, VmStatePaused) +} + +func (r *Runtime) Resume(ctx context.Context, state MachineState) error { + client := newAPIClient(state.SocketPath) + return client.PatchVm(ctx, VmStateResumed) +} + +func (r *Runtime) CreateSnapshot(ctx context.Context, state MachineState, paths SnapshotPaths) error { + client := newAPIClient(state.SocketPath) + return client.PutSnapshotCreate(ctx, SnapshotCreateParams{ + MemFilePath: paths.MemFilePath, + SnapshotPath: paths.StateFilePath, + SnapshotType: "Full", + }) +} + +func (r *Runtime) RestoreBoot(ctx context.Context, loadSpec SnapshotLoadSpec, usedNetworks []NetworkAllocation) (*MachineState, error) { + cleanup := func(network NetworkAllocation, paths machinePaths, command *exec.Cmd, firecrackerPID int) { + if preserveFailureArtifacts() { + return + } + cleanupRunningProcess(firecrackerPID) + cleanupStartedProcess(command) + _ = r.networkProvisioner.Remove(context.Background(), network) + if paths.BaseDir != "" { + _ = os.RemoveAll(paths.BaseDir) + } + } + + var network NetworkAllocation + if loadSpec.Network != nil { + network = *loadSpec.Network + } else { + var err error + network, err = r.networkAllocator.Allocate(usedNetworks) + if err != nil { + return nil, err + } + } + + paths, err := buildMachinePaths(r.rootDir, loadSpec.ID, r.firecrackerBinaryPath) + if err != nil { + cleanup(network, 
machinePaths{}, nil, 0) + return nil, err + } + if err := os.MkdirAll(paths.LogDir, 0o755); err != nil { + cleanup(network, paths, nil, 0) + return nil, fmt.Errorf("create machine log dir %q: %w", paths.LogDir, err) + } + if err := r.networkProvisioner.Ensure(ctx, network); err != nil { + cleanup(network, paths, nil, 0) + return nil, err + } + + command, err := launchJailedFirecracker(paths, loadSpec.ID, r.firecrackerBinaryPath, r.jailerBinaryPath) + if err != nil { + cleanup(network, paths, nil, 0) + return nil, err + } + firecrackerPID, err := waitForPIDFile(ctx, paths.PIDFilePath) + if err != nil { + cleanup(network, paths, command, 0) + return nil, fmt.Errorf("wait for firecracker pid: %w", err) + } + + socketPath := procSocketPath(firecrackerPID) + client := newAPIClient(socketPath) + if err := waitForSocket(ctx, client, socketPath); err != nil { + cleanup(network, paths, command, firecrackerPID) + return nil, fmt.Errorf("wait for firecracker socket: %w", err) + } + + // Stage snapshot files and disk images into the chroot + chrootMemPath, err := stageSnapshotFile(loadSpec.MemFilePath, paths.ChrootRootDir, "memory.bin") + if err != nil { + cleanup(network, paths, command, firecrackerPID) + return nil, fmt.Errorf("stage memory file: %w", err) + } + chrootStatePath, err := stageSnapshotFile(loadSpec.SnapshotPath, paths.ChrootRootDir, "vmstate.bin") + if err != nil { + cleanup(network, paths, command, firecrackerPID) + return nil, fmt.Errorf("stage vmstate file: %w", err) + } + + // Stage root filesystem + rootFSName, err := stagedFileName(loadSpec.RootFSPath) + if err != nil { + cleanup(network, paths, command, firecrackerPID) + return nil, fmt.Errorf("rootfs path: %w", err) + } + if err := linkMachineFile(loadSpec.RootFSPath, filepath.Join(paths.ChrootRootDir, rootFSName)); err != nil { + cleanup(network, paths, command, firecrackerPID) + return nil, fmt.Errorf("link rootfs into jail: %w", err) + } + + // Stage additional drives + for driveID, drivePath := 
range loadSpec.DiskPaths { + driveName, err := stagedFileName(drivePath) + if err != nil { + cleanup(network, paths, command, firecrackerPID) + return nil, fmt.Errorf("drive %q path: %w", driveID, err) + } + if err := linkMachineFile(drivePath, filepath.Join(paths.ChrootRootDir, driveName)); err != nil { + cleanup(network, paths, command, firecrackerPID) + return nil, fmt.Errorf("link drive %q into jail: %w", driveID, err) + } + } + + // Load snapshot (replaces the full configure+start sequence) + if err := client.PutSnapshotLoad(ctx, SnapshotLoadParams{ + SnapshotPath: chrootStatePath, + MemBackend: &MemBackend{ + BackendType: "File", + BackendPath: chrootMemPath, + }, + ResumeVm: false, + NetworkOverrides: []NetworkOverride{ + { + IfaceID: network.InterfaceID, + HostDevName: network.TapName, + }, + }, + }); err != nil { + cleanup(network, paths, command, firecrackerPID) + return nil, fmt.Errorf("load snapshot: %w", err) + } + + // Resume the restored VM + if err := client.PatchVm(ctx, VmStateResumed); err != nil { + cleanup(network, paths, command, firecrackerPID) + return nil, fmt.Errorf("resume restored vm: %w", err) + } + + now := time.Now().UTC() + state := MachineState{ + ID: loadSpec.ID, + Phase: PhaseRunning, + PID: firecrackerPID, + RuntimeHost: network.GuestIP().String(), + SocketPath: socketPath, + TapName: network.TapName, + StartedAt: &now, + } + return &state, nil +} + func processExists(pid int) bool { if pid < 1 { return false diff --git a/internal/firecracker/state.go b/internal/firecracker/state.go index da37de7..d879c85 100644 --- a/internal/firecracker/state.go +++ b/internal/firecracker/state.go @@ -5,6 +5,27 @@ import "time" // Phase represents the lifecycle phase of a local microVM. type Phase string +// SnapshotPaths holds the file paths for a VM snapshot. +type SnapshotPaths struct { + MemFilePath string + StateFilePath string +} + +// SnapshotLoadSpec describes what is needed to restore a VM from a snapshot. 
+type SnapshotLoadSpec struct { + ID MachineID + SnapshotPath string + MemFilePath string + DiskPaths map[string]string // drive ID -> host path + RootFSPath string + KernelImagePath string + VCPUs int64 + MemoryMiB int64 + KernelArgs string + Vsock *VsockSpec + Network *NetworkAllocation +} + // MachineState describes the current host local state for a machine. type MachineState struct { ID MachineID diff --git a/internal/httpapi/handlers.go b/internal/httpapi/handlers.go index baa9df9..fc6ba7b 100644 --- a/internal/httpapi/handlers.go +++ b/internal/httpapi/handlers.go @@ -17,6 +17,11 @@ type Service interface { StopMachine(context.Context, contracthost.MachineID) error DeleteMachine(context.Context, contracthost.MachineID) error Health(context.Context) (*contracthost.HealthResponse, error) + CreateSnapshot(context.Context, contracthost.MachineID) (*contracthost.CreateSnapshotResponse, error) + ListSnapshots(context.Context, contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error) + GetSnapshot(context.Context, contracthost.SnapshotID) (*contracthost.GetSnapshotResponse, error) + DeleteSnapshotByID(context.Context, contracthost.SnapshotID) error + RestoreSnapshot(context.Context, contracthost.SnapshotID, contracthost.RestoreSnapshotRequest) (*contracthost.RestoreSnapshotResponse, error) } type Handler struct { @@ -35,6 +40,7 @@ func (h *Handler) Routes() http.Handler { mux.HandleFunc("/health", h.handleHealth) mux.HandleFunc("/machines", h.handleMachines) mux.HandleFunc("/machines/", h.handleMachine) + mux.HandleFunc("/snapshots/", h.handleSnapshot) return mux } @@ -120,6 +126,80 @@ func (h *Handler) handleMachine(w http.ResponseWriter, r *http.Request) { return } + if len(parts) == 2 && parts[1] == "snapshots" { + switch r.Method { + case http.MethodGet: + response, err := h.service.ListSnapshots(r.Context(), machineID) + if err != nil { + writeError(w, statusForError(err), err) + return + } + writeJSON(w, http.StatusOK, response) + case 
http.MethodPost: + response, err := h.service.CreateSnapshot(r.Context(), machineID) + if err != nil { + writeError(w, statusForError(err), err) + return + } + writeJSON(w, http.StatusCreated, response) + default: + writeMethodNotAllowed(w) + } + return + } + + writeError(w, http.StatusNotFound, fmt.Errorf("route not found")) +} + +func (h *Handler) handleSnapshot(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/snapshots/") + if path == "" { + writeError(w, http.StatusNotFound, fmt.Errorf("snapshot id is required")) + return + } + parts := strings.Split(path, "/") + snapshotID := contracthost.SnapshotID(parts[0]) + + if len(parts) == 1 { + switch r.Method { + case http.MethodGet: + response, err := h.service.GetSnapshot(r.Context(), snapshotID) + if err != nil { + writeError(w, statusForError(err), err) + return + } + writeJSON(w, http.StatusOK, response) + case http.MethodDelete: + if err := h.service.DeleteSnapshotByID(r.Context(), snapshotID); err != nil { + writeError(w, statusForError(err), err) + return + } + w.WriteHeader(http.StatusNoContent) + default: + writeMethodNotAllowed(w) + } + return + } + + if len(parts) == 2 && parts[1] == "restore" { + if r.Method != http.MethodPost { + writeMethodNotAllowed(w) + return + } + var req contracthost.RestoreSnapshotRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + response, err := h.service.RestoreSnapshot(r.Context(), snapshotID, req) + if err != nil { + writeError(w, statusForError(err), err) + return + } + writeJSON(w, http.StatusCreated, response) + return + } + writeError(w, http.StatusNotFound, fmt.Errorf("route not found")) } diff --git a/internal/model/types.go b/internal/model/types.go index 8454b3d..a473af6 100644 --- a/internal/model/types.go +++ b/internal/model/types.go @@ -53,13 +53,28 @@ type VolumeRecord struct { type MachineOperation string const ( - MachineOperationCreate 
MachineOperation = "create" - MachineOperationStop MachineOperation = "stop" - MachineOperationDelete MachineOperation = "delete" + MachineOperationCreate MachineOperation = "create" + MachineOperationStop MachineOperation = "stop" + MachineOperationDelete MachineOperation = "delete" + MachineOperationSnapshot MachineOperation = "snapshot" + MachineOperationRestore MachineOperation = "restore" ) -type OperationRecord struct { - MachineID contracthost.MachineID - Type MachineOperation - StartedAt time.Time +type SnapshotRecord struct { + ID contracthost.SnapshotID + MachineID contracthost.MachineID + Artifact contracthost.ArtifactRef + MemFilePath string + StateFilePath string + DiskPaths []string + SourceRuntimeHost string + SourceTapDevice string + CreatedAt time.Time +} + +type OperationRecord struct { + MachineID contracthost.MachineID + Type MachineOperation + StartedAt time.Time + SnapshotID *contracthost.SnapshotID `json:"snapshot_id,omitempty"` } diff --git a/internal/store/file_store.go b/internal/store/file_store.go index bac36f1..aa00feb 100644 --- a/internal/store/file_store.go +++ b/internal/store/file_store.go @@ -23,9 +23,10 @@ type persistedOperations struct { } type persistedState struct { - Artifacts []model.ArtifactRecord `json:"artifacts"` - Machines []model.MachineRecord `json:"machines"` - Volumes []model.VolumeRecord `json:"volumes"` + Artifacts []model.ArtifactRecord `json:"artifacts"` + Machines []model.MachineRecord `json:"machines"` + Volumes []model.VolumeRecord `json:"volumes"` + Snapshots []model.SnapshotRecord `json:"snapshots"` } func NewFileStore(statePath string, operationsPath string) (*FileStore, error) { @@ -274,6 +275,73 @@ func (s *FileStore) DeleteOperation(_ context.Context, machineID contracthost.Ma }) } +func (s *FileStore) CreateSnapshot(_ context.Context, record model.SnapshotRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateState(func(state *persistedState) error { + for _, snap := range 
state.Snapshots { + if snap.ID == record.ID { + return fmt.Errorf("store: snapshot %q already exists", record.ID) + } + } + state.Snapshots = append(state.Snapshots, record) + return nil + }) +} + +func (s *FileStore) GetSnapshot(_ context.Context, id contracthost.SnapshotID) (*model.SnapshotRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + state, err := s.readState() + if err != nil { + return nil, err + } + for i := range state.Snapshots { + if state.Snapshots[i].ID == id { + record := state.Snapshots[i] + return &record, nil + } + } + return nil, ErrNotFound +} + +func (s *FileStore) ListSnapshotsByMachine(_ context.Context, machineID contracthost.MachineID) ([]model.SnapshotRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + state, err := s.readState() + if err != nil { + return nil, err + } + var result []model.SnapshotRecord + for _, snap := range state.Snapshots { + if snap.MachineID == machineID { + result = append(result, snap) + } + } + if result == nil { + result = []model.SnapshotRecord{} + } + return result, nil +} + +func (s *FileStore) DeleteSnapshot(_ context.Context, id contracthost.SnapshotID) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateState(func(state *persistedState) error { + for i := range state.Snapshots { + if state.Snapshots[i].ID == id { + state.Snapshots = append(state.Snapshots[:i], state.Snapshots[i+1:]...) 
+ return nil + } + } + return ErrNotFound + }) +} + func (s *FileStore) readOperations() (*persistedOperations, error) { var operations persistedOperations if err := readJSONFile(s.operationsPath, &operations); err != nil { @@ -387,6 +455,7 @@ func emptyPersistedState() persistedState { Artifacts: []model.ArtifactRecord{}, Machines: []model.MachineRecord{}, Volumes: []model.VolumeRecord{}, + Snapshots: []model.SnapshotRecord{}, } } @@ -404,6 +473,9 @@ func normalizeState(state *persistedState) { if state.Volumes == nil { state.Volumes = []model.VolumeRecord{} } + if state.Snapshots == nil { + state.Snapshots = []model.SnapshotRecord{} + } } func normalizeOperations(operations *persistedOperations) { diff --git a/internal/store/store.go b/internal/store/store.go index 542e6dd..f8e5fdf 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -27,4 +27,8 @@ type Store interface { UpsertOperation(context.Context, model.OperationRecord) error ListOperations(context.Context) ([]model.OperationRecord, error) DeleteOperation(context.Context, contracthost.MachineID) error + CreateSnapshot(context.Context, model.SnapshotRecord) error + GetSnapshot(context.Context, contracthost.SnapshotID) (*model.SnapshotRecord, error) + ListSnapshotsByMachine(context.Context, contracthost.MachineID) ([]model.SnapshotRecord, error) + DeleteSnapshot(context.Context, contracthost.SnapshotID) error }