diff --git a/contract/machines.go b/contract/machines.go index 43c7432..aec0ce0 100644 --- a/contract/machines.go +++ b/contract/machines.go @@ -17,6 +17,7 @@ type Machine struct { } type GuestConfig struct { + Hostname string `json:"hostname,omitempty"` AuthorizedKeys []string `json:"authorized_keys,omitempty"` TrustedUserCAKeys []string `json:"trusted_user_ca_keys,omitempty"` LoginWebhook *GuestLoginWebhook `json:"login_webhook,omitempty"` diff --git a/contract/snapshots.go b/contract/snapshots.go index ea1377a..1192d6f 100644 --- a/contract/snapshots.go +++ b/contract/snapshots.go @@ -4,10 +4,31 @@ import "time" type SnapshotID string +type SnapshotArtifactKind string + +const ( + SnapshotArtifactKindMemory SnapshotArtifactKind = "memory" + SnapshotArtifactKindVMState SnapshotArtifactKind = "vmstate" + SnapshotArtifactKindDisk SnapshotArtifactKind = "disk" + SnapshotArtifactKindManifest SnapshotArtifactKind = "manifest" +) + type Snapshot struct { - ID SnapshotID `json:"id"` - MachineID MachineID `json:"machine_id"` - CreatedAt time.Time `json:"created_at"` + ID SnapshotID `json:"id"` + MachineID MachineID `json:"machine_id"` + SourceRuntimeHost string `json:"source_runtime_host,omitempty"` + SourceTapDevice string `json:"source_tap_device,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +type SnapshotArtifact struct { + ID string `json:"id"` + Kind SnapshotArtifactKind `json:"kind"` + Name string `json:"name"` + SizeBytes int64 `json:"size_bytes"` + SHA256Hex string `json:"sha256_hex,omitempty"` + ObjectKey string `json:"object_key,omitempty"` + DownloadURL string `json:"download_url,omitempty"` } type CreateSnapshotRequest struct { @@ -15,7 +36,8 @@ type CreateSnapshotRequest struct { } type CreateSnapshotResponse struct { - Snapshot Snapshot `json:"snapshot"` + Snapshot Snapshot `json:"snapshot"` + Artifacts []SnapshotArtifact `json:"artifacts,omitempty"` } type GetSnapshotResponse struct { @@ -26,8 +48,52 @@ type ListSnapshotsResponse struct { 
Snapshots []Snapshot `json:"snapshots"` } +type SnapshotUploadPart struct { + PartNumber int32 `json:"part_number"` + OffsetBytes int64 `json:"offset_bytes"` + SizeBytes int64 `json:"size_bytes"` + UploadURL string `json:"upload_url"` +} + +type SnapshotArtifactUploadSession struct { + ArtifactID string `json:"artifact_id"` + ObjectKey string `json:"object_key"` + UploadID string `json:"upload_id"` + Parts []SnapshotUploadPart `json:"parts"` +} + +type UploadSnapshotRequest struct { + Artifacts []SnapshotArtifactUploadSession `json:"artifacts"` +} + +type UploadedSnapshotPart struct { + PartNumber int32 `json:"part_number"` + ETag string `json:"etag"` +} + +type UploadedSnapshotArtifact struct { + ArtifactID string `json:"artifact_id"` + CompletedParts []UploadedSnapshotPart `json:"completed_parts"` +} + +type UploadSnapshotResponse struct { + Artifacts []UploadedSnapshotArtifact `json:"artifacts"` +} + type RestoreSnapshotRequest struct { - MachineID MachineID `json:"machine_id"` + MachineID MachineID `json:"machine_id"` + Artifact ArtifactRef `json:"artifact"` + Snapshot DurableSnapshotSpec `json:"snapshot"` + GuestConfig *GuestConfig `json:"guest_config,omitempty"` +} + +type DurableSnapshotSpec struct { + SnapshotID SnapshotID `json:"snapshot_id"` + MachineID MachineID `json:"machine_id"` + ImageID string `json:"image_id"` + SourceRuntimeHost string `json:"source_runtime_host,omitempty"` + SourceTapDevice string `json:"source_tap_device,omitempty"` + Artifacts []SnapshotArtifact `json:"artifacts"` } type RestoreSnapshotResponse struct { diff --git a/contract/types.go b/contract/types.go index b13ed42..690a82a 100644 --- a/contract/types.go +++ b/contract/types.go @@ -9,9 +9,10 @@ type VolumeID string type VolumeKind string const ( - MachinePhaseRunning MachinePhase = "running" - MachinePhaseStopped MachinePhase = "stopped" - MachinePhaseFailed MachinePhase = "failed" + MachinePhaseStarting MachinePhase = "starting" + MachinePhaseRunning MachinePhase = "running" 
+ MachinePhaseStopped MachinePhase = "stopped" + MachinePhaseFailed MachinePhase = "failed" ) const ( diff --git a/internal/daemon/create.go b/internal/daemon/create.go index 33e89c5..9449970 100644 --- a/internal/daemon/create.go +++ b/internal/daemon/create.go @@ -66,7 +66,7 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi if err := os.MkdirAll(filepath.Dir(systemVolumePath), 0o755); err != nil { return nil, fmt.Errorf("create system volume dir for %q: %w", req.MachineID, err) } - if err := cloneFile(artifact.RootFSPath, systemVolumePath); err != nil { + if err := cowCopyFile(artifact.RootFSPath, systemVolumePath); err != nil { return nil, err } removeSystemVolumeOnFailure := true @@ -77,14 +77,8 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi _ = os.Remove(systemVolumePath) _ = os.RemoveAll(filepath.Dir(systemVolumePath)) }() - if err := injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil { - return nil, err - } - if err := injectMachineIdentity(ctx, systemVolumePath, req.MachineID); err != nil { - return nil, err - } - spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath) + spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig) if err != nil { return nil, err } @@ -98,17 +92,6 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi return nil, err } - ports := defaultMachinePorts() - if err := waitForGuestReady(ctx, state.RuntimeHost, ports); err != nil { - _ = d.runtime.Delete(context.Background(), *state) - return nil, err - } - guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost) - if err != nil { - _ = d.runtime.Delete(context.Background(), *state) - return nil, err - } - now := time.Now().UTC() systemVolumeRecord := model.VolumeRecord{ ID: d.systemVolumeID(req.MachineID), @@ -143,44 +126,20 @@ func (d *Daemon) CreateMachine(ctx context.Context, 
req contracthost.CreateMachi } record := model.MachineRecord{ - ID: req.MachineID, - Artifact: req.Artifact, - SystemVolumeID: systemVolumeRecord.ID, - UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...), - RuntimeHost: state.RuntimeHost, - TapDevice: state.TapName, - Ports: ports, - GuestSSHPublicKey: guestSSHPublicKey, - Phase: contracthost.MachinePhaseRunning, - PID: state.PID, - SocketPath: state.SocketPath, - CreatedAt: now, - StartedAt: state.StartedAt, + ID: req.MachineID, + Artifact: req.Artifact, + GuestConfig: cloneGuestConfig(guestConfig), + SystemVolumeID: systemVolumeRecord.ID, + UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...), + RuntimeHost: state.RuntimeHost, + TapDevice: state.TapName, + Ports: defaultMachinePorts(), + Phase: contracthost.MachinePhaseStarting, + PID: state.PID, + SocketPath: state.SocketPath, + CreatedAt: now, + StartedAt: state.StartedAt, } - d.relayAllocMu.Lock() - sshRelayPort, err := d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameSSH, record.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort) - var vncRelayPort uint16 - if err == nil { - vncRelayPort, err = d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameVNC, record.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort) - } - d.relayAllocMu.Unlock() - if err != nil { - d.stopMachineRelays(record.ID) - for _, volume := range userVolumes { - volume.AttachedMachineID = nil - _ = d.store.UpdateVolume(context.Background(), volume) - } - _ = d.store.DeleteVolume(context.Background(), systemVolumeRecord.ID) - _ = d.runtime.Delete(context.Background(), *state) - return nil, err - } - record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort) - startedRelays := true - defer func() { - if startedRelays { - d.stopMachineRelays(record.ID) - } - }() if err := d.store.CreateMachine(ctx, record); err != nil { for _, volume := range userVolumes { 
volume.AttachedMachineID = nil @@ -192,12 +151,11 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi } removeSystemVolumeOnFailure = false - startedRelays = false clearOperation = true return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil } -func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string) (firecracker.MachineSpec, error) { +func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string, guestConfig *contracthost.GuestConfig) (firecracker.MachineSpec, error) { drives := make([]firecracker.DriveSpec, 0, len(userVolumes)) for i, volume := range userVolumes { drives = append(drives, firecracker.DriveSpec{ @@ -207,14 +165,25 @@ func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *mo }) } + mmds, err := d.guestMetadataSpec(machineID, guestConfig) + if err != nil { + return firecracker.MachineSpec{}, err + } spec := firecracker.MachineSpec{ ID: firecracker.MachineID(machineID), VCPUs: defaultGuestVCPUs, MemoryMiB: defaultGuestMemoryMiB, KernelImagePath: artifact.KernelImagePath, RootFSPath: systemVolumePath, - KernelArgs: defaultGuestKernelArgs, - Drives: drives, + RootDrive: firecracker.DriveSpec{ + ID: "root_drive", + Path: systemVolumePath, + CacheType: firecracker.DriveCacheTypeUnsafe, + IOEngine: firecracker.DriveIOEngineSync, + }, + KernelArgs: defaultGuestKernelArgs, + Drives: drives, + MMDS: mmds, } if err := spec.Validate(); err != nil { return firecracker.MachineSpec{}, err diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 11c7273..45bdb00 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -43,6 +43,7 @@ type Daemon struct { reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID) error 
readGuestSSHPublicKey func(context.Context, string) (string, error) + syncGuestFilesystem func(context.Context, string) error locksMu sync.Mutex machineLocks map[contracthost.MachineID]*sync.Mutex @@ -84,6 +85,7 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err } daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH daemon.readGuestSSHPublicKey = readGuestSSHPublicKey + daemon.syncGuestFilesystem = daemon.syncGuestFilesystemOverSSH if err := daemon.ensureBackendSSHKeyPair(); err != nil { return nil, err } diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 6b26694..05b329e 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -141,7 +141,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { t.Fatalf("create machine: %v", err) } - if response.Machine.Phase != contracthost.MachinePhaseRunning { + if response.Machine.Phase != contracthost.MachinePhaseStarting { t.Fatalf("machine phase mismatch: got %q", response.Machine.Phase) } if response.Machine.RuntimeHost != "127.0.0.1" { @@ -169,29 +169,25 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { if err != nil { t.Fatalf("read backend ssh public key: %v", err) } - authorizedKeys, err := readExt4File(runtime.lastSpec.RootFSPath, "/etc/microagent/authorized_keys") - if err != nil { - t.Fatalf("read injected authorized_keys: %v", err) + if runtime.lastSpec.MMDS == nil { + t.Fatalf("expected MMDS configuration on machine spec") } + if runtime.lastSpec.MMDS.Version != firecracker.MMDSVersionV2 { + t.Fatalf("mmds version mismatch: got %q", runtime.lastSpec.MMDS.Version) + } + payload, ok := runtime.lastSpec.MMDS.Data.(guestMetadataEnvelope) + if !ok { + t.Fatalf("mmds payload type mismatch: got %T", runtime.lastSpec.MMDS.Data) + } + if payload.Latest.MetaData.Hostname != "vm-1" { + t.Fatalf("mmds hostname mismatch: got %q", payload.Latest.MetaData.Hostname) + } + 
authorizedKeys := strings.Join(payload.Latest.MetaData.AuthorizedKeys, "\n") if !strings.Contains(authorizedKeys, strings.TrimSpace(string(hostAuthorizedKeyBytes))) { - t.Fatalf("authorized_keys missing backend ssh key: %q", authorizedKeys) + t.Fatalf("mmds authorized_keys missing backend ssh key: %q", authorizedKeys) } if !strings.Contains(authorizedKeys, "daemon-test") { - t.Fatalf("authorized_keys missing request override key: %q", authorizedKeys) - } - machineName, err := readExt4File(runtime.lastSpec.RootFSPath, "/etc/microagent/machine-name") - if err != nil { - t.Fatalf("read injected machine-name: %v", err) - } - if machineName != "vm-1\n" { - t.Fatalf("machine-name mismatch: got %q want %q", machineName, "vm-1\n") - } - hosts, err := readExt4File(runtime.lastSpec.RootFSPath, "/etc/hosts") - if err != nil { - t.Fatalf("read injected hosts: %v", err) - } - if !strings.Contains(hosts, "127.0.1.1 vm-1") { - t.Fatalf("hosts missing machine identity: %q", hosts) + t.Fatalf("mmds authorized_keys missing request override key: %q", authorizedKeys) } artifact, err := fileStore.GetArtifact(context.Background(), response.Machine.Artifact) @@ -214,6 +210,12 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { if machine.SystemVolumeID != "vm-1-system" { t.Fatalf("system volume mismatch: got %q", machine.SystemVolumeID) } + if machine.Phase != contracthost.MachinePhaseStarting { + t.Fatalf("stored machine phase mismatch: got %q", machine.Phase) + } + if machine.GuestConfig == nil || len(machine.GuestConfig.AuthorizedKeys) == 0 { + t.Fatalf("stored guest config missing authorized keys: %#v", machine.GuestConfig) + } operations, err := fileStore.ListOperations(context.Background()) if err != nil { @@ -224,6 +226,65 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { } } +func TestStopMachineSyncsGuestFilesystemBeforeDelete(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := 
store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + runtime := &fakeRuntime{} + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + + var syncedHost string + hostDaemon.syncGuestFilesystem = func(_ context.Context, runtimeHost string) error { + syncedHost = runtimeHost + return nil + } + + now := time.Now().UTC() + if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{ + ID: "vm-stop", + SystemVolumeID: "vm-stop-system", + RuntimeHost: "172.16.0.2", + TapDevice: "fctap-stop", + Phase: contracthost.MachinePhaseRunning, + PID: 1234, + SocketPath: filepath.Join(root, "runtime", "vm-stop.sock"), + Ports: defaultMachinePorts(), + CreatedAt: now, + StartedAt: &now, + }); err != nil { + t.Fatalf("create machine: %v", err) + } + + if err := hostDaemon.StopMachine(context.Background(), "vm-stop"); err != nil { + t.Fatalf("stop machine: %v", err) + } + + if syncedHost != "172.16.0.2" { + t.Fatalf("sync host mismatch: got %q want %q", syncedHost, "172.16.0.2") + } + if len(runtime.deleteCalls) != 1 { + t.Fatalf("runtime delete call count mismatch: got %d want 1", len(runtime.deleteCalls)) + } + + stopped, err := fileStore.GetMachine(context.Background(), "vm-stop") + if err != nil { + t.Fatalf("get stopped machine: %v", err) + } + if stopped.Phase != contracthost.MachinePhaseStopped { + t.Fatalf("machine phase mismatch: got %q want %q", stopped.Phase, contracthost.MachinePhaseStopped) + } + if stopped.RuntimeHost != "" { + t.Fatalf("runtime host should be cleared after stop, got %q", stopped.RuntimeHost) + } +} + func TestNewEnsuresBackendSSHKeyPair(t *testing.T) { root := t.TempDir() cfg := testConfig(root) @@ -249,7 +310,7 @@ func TestNewEnsuresBackendSSHKeyPair(t *testing.T) { } } -func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) { +func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) { root := 
t.TempDir() cfg := testConfig(root) fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) @@ -257,7 +318,23 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) { t.Fatalf("create file store: %v", err) } - runtime := &fakeRuntime{} + sshListener := listenTestPort(t, int(defaultSSHPort)) + defer func() { _ = sshListener.Close() }() + vncListener := listenTestPort(t, int(defaultVNCPort)) + defer func() { _ = vncListener.Close() }() + + startedAt := time.Unix(1700000099, 0).UTC() + runtime := &fakeRuntime{ + bootState: firecracker.MachineState{ + ID: "restored", + Phase: firecracker.PhaseRunning, + PID: 1234, + RuntimeHost: "127.0.0.1", + SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "restored", "root", "run", "firecracker.sock"), + TapName: "fctap0", + StartedAt: &startedAt, + }, + } hostDaemon, err := New(cfg, fileStore, runtime) if err != nil { t.Fatalf("create daemon: %v", err) @@ -281,32 +358,13 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) { t.Fatalf("put artifact: %v", err) } - if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{ - ID: "source", - Artifact: artifactRef, - SystemVolumeID: "source-system", - RuntimeHost: "172.16.0.2", - TapDevice: "fctap0", - Phase: contracthost.MachinePhaseRunning, - CreatedAt: time.Now().UTC(), - }); err != nil { - t.Fatalf("create source machine: %v", err) - } - - snapDisk := filepath.Join(root, "snapshots", "snap1", "system.img") - if err := os.MkdirAll(filepath.Dir(snapDisk), 0o755); err != nil { - t.Fatalf("create snapshot dir: %v", err) - } - if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil { - t.Fatalf("write snapshot disk: %v", err) - } if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ ID: "snap1", MachineID: "source", Artifact: artifactRef, MemFilePath: filepath.Join(root, "snapshots", "snap1", "memory.bin"), StateFilePath: filepath.Join(root, "snapshots", "snap1", "vmstate.bin"), - 
DiskPaths: []string{snapDisk}, + DiskPaths: []string{filepath.Join(root, "snapshots", "snap1", "system.img")}, SourceRuntimeHost: "172.16.0.2", SourceTapDevice: "fctap0", CreatedAt: time.Now().UTC(), @@ -314,17 +372,49 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) { t.Fatalf("create snapshot: %v", err) } - _, err = hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{ - MachineID: "restored", + server := newRestoreArtifactServer(t, map[string][]byte{ + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + "/memory": []byte("mem"), + "/vmstate": []byte("state"), + "/system": []byte("disk"), }) - if err == nil { - t.Fatal("expected restore rejection while source is running") + defer server.Close() + + response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{ + MachineID: "restored", + Artifact: contracthost.ArtifactRef{ + KernelImageURL: server.URL + "/kernel", + RootFSURL: server.URL + "/rootfs", + }, + Snapshot: contracthost.DurableSnapshotSpec{ + SnapshotID: "snap1", + MachineID: "source", + ImageID: "image-1", + Artifacts: []contracthost.SnapshotArtifact{ + {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"}, + {ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"}, + {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"}, + }, + }, + }) + if err != nil { + t.Fatalf("restore snapshot: %v", err) } - if !strings.Contains(err.Error(), `source machine "source" is running`) { - t.Fatalf("unexpected restore error: %v", err) + if response.Machine.ID != "restored" { + t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID) } - if runtime.restoreCalls != 0 { - t.Fatalf("restore boot should not run when source machine is still running: 
got %d", runtime.restoreCalls) + if runtime.restoreCalls != 1 { + t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls) + } + if runtime.lastLoadSpec.Network == nil { + t.Fatalf("restore boot should preserve snapshot network") + } + if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" { + t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2") + } + if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" { + t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0") } ops, err := fileStore.ListOperations(context.Background()) @@ -332,11 +422,11 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) { t.Fatalf("list operations: %v", err) } if len(ops) != 0 { - t.Fatalf("operation journal should be empty after handled restore rejection: got %d entries", len(ops)) + t.Fatalf("operation journal should be empty after successful restore: got %d entries", len(ops)) } } -func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) { +func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { root := t.TempDir() cfg := testConfig(root) fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) @@ -378,56 +468,35 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) { return nil } - artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"} - kernelPath := filepath.Join(root, "artifact-kernel") - rootFSPath := filepath.Join(root, "artifact-rootfs") - if err := os.WriteFile(kernelPath, []byte("kernel"), 0o644); err != nil { - t.Fatalf("write kernel: %v", err) - } - if err := os.WriteFile(rootFSPath, []byte("rootfs"), 0o644); err != nil { - t.Fatalf("write rootfs: %v", err) - } - if err := fileStore.PutArtifact(context.Background(), model.ArtifactRecord{ - Ref: artifactRef, - LocalKey: "artifact", - LocalDir: filepath.Join(root, "artifact"), - KernelImagePath: kernelPath, - RootFSPath: 
rootFSPath, - CreatedAt: time.Now().UTC(), - }); err != nil { - t.Fatalf("put artifact: %v", err) - } - - snapDir := filepath.Join(root, "snapshots", "snap1") - if err := os.MkdirAll(snapDir, 0o755); err != nil { - t.Fatalf("create snapshot dir: %v", err) - } - snapDisk := filepath.Join(snapDir, "system.img") - if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil { - t.Fatalf("write snapshot disk: %v", err) - } - if err := os.WriteFile(filepath.Join(snapDir, "memory.bin"), []byte("mem"), 0o644); err != nil { - t.Fatalf("write memory snapshot: %v", err) - } - if err := os.WriteFile(filepath.Join(snapDir, "vmstate.bin"), []byte("state"), 0o644); err != nil { - t.Fatalf("write vmstate snapshot: %v", err) - } - if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ - ID: "snap1", - MachineID: "source", - Artifact: artifactRef, - MemFilePath: filepath.Join(snapDir, "memory.bin"), - StateFilePath: filepath.Join(snapDir, "vmstate.bin"), - DiskPaths: []string{snapDisk}, - SourceRuntimeHost: "172.16.0.2", - SourceTapDevice: "fctap0", - CreatedAt: time.Now().UTC(), - }); err != nil { - t.Fatalf("create snapshot: %v", err) - } + server := newRestoreArtifactServer(t, map[string][]byte{ + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + "/memory": []byte("mem"), + "/vmstate": []byte("state"), + "/system": []byte("disk"), + "/user-0": []byte("user-disk"), + }) + defer server.Close() response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{ MachineID: "restored", + Artifact: contracthost.ArtifactRef{ + KernelImageURL: server.URL + "/kernel", + RootFSURL: server.URL + "/rootfs", + }, + Snapshot: contracthost.DurableSnapshotSpec{ + SnapshotID: "snap1", + MachineID: "source", + ImageID: "image-1", + SourceRuntimeHost: "172.16.0.2", + SourceTapDevice: "fctap0", + Artifacts: []contracthost.SnapshotArtifact{ + {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: 
"memory.bin", DownloadURL: server.URL + "/memory"}, + {ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"}, + {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"}, + {ID: "disk-user-0", Kind: contracthost.SnapshotArtifactKindDisk, Name: "user-0.img", DownloadURL: server.URL + "/user-0"}, + }, + }, }) if err != nil { t.Fatalf("restore snapshot: %v", err) @@ -439,13 +508,19 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) { t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls) } if runtime.lastLoadSpec.Network == nil { - t.Fatal("restore boot did not receive snapshot network") + t.Fatalf("restore boot should preserve durable snapshot network") } if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" { - t.Fatalf("restored guest network mismatch: got %q want %q", got, "172.16.0.2") + t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2") } - if runtime.lastLoadSpec.KernelImagePath != kernelPath { - t.Fatalf("restore boot kernel path mismatch: got %q want %q", runtime.lastLoadSpec.KernelImagePath, kernelPath) + if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" { + t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0") + } + if !strings.Contains(runtime.lastLoadSpec.KernelImagePath, filepath.Join("artifacts", artifactKey(contracthost.ArtifactRef{ + KernelImageURL: server.URL + "/kernel", + RootFSURL: server.URL + "/rootfs", + }), "kernel")) { + t.Fatalf("restore boot kernel path mismatch: got %q", runtime.lastLoadSpec.KernelImagePath) } if reconfiguredHost != "127.0.0.1" || reconfiguredMachine != "restored" { t.Fatalf("guest identity reconfigure mismatch: host=%q machine=%q", reconfiguredHost, reconfiguredMachine) @@ -458,6 +533,12 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t 
*testing.T) { if machine.Phase != contracthost.MachinePhaseRunning { t.Fatalf("restored machine phase mismatch: got %q", machine.Phase) } + if len(machine.UserVolumeIDs) != 1 { + t.Fatalf("restored machine user volumes mismatch: got %#v", machine.UserVolumeIDs) + } + if _, err := os.Stat(filepath.Join(cfg.MachineDisksDir, "restored", "user-0.img")); err != nil { + t.Fatalf("restored user disk missing: %v", err) + } ops, err := fileStore.ListOperations(context.Background()) if err != nil { @@ -468,6 +549,67 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) { } } +func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + runtime := &fakeRuntime{} + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + + if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{ + ID: "source", + Artifact: contracthost.ArtifactRef{KernelImageURL: "https://example.com/kernel", RootFSURL: "https://example.com/rootfs"}, + SystemVolumeID: "source-system", + RuntimeHost: "172.16.0.2", + TapDevice: "fctap0", + Phase: contracthost.MachinePhaseRunning, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create running source machine: %v", err) + } + + _, err = hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{ + MachineID: "restored", + Artifact: contracthost.ArtifactRef{ + KernelImageURL: "https://example.com/kernel", + RootFSURL: "https://example.com/rootfs", + }, + Snapshot: contracthost.DurableSnapshotSpec{ + SnapshotID: "snap1", + MachineID: "source", + ImageID: "image-1", + SourceRuntimeHost: "172.16.0.2", + SourceTapDevice: "fctap0", + }, + }) + if err == nil || !strings.Contains(err.Error(), "still in use on this host") { + 
t.Fatalf("restore snapshot error = %v, want restore network in-use failure", err) + } + if runtime.restoreCalls != 0 { + t.Fatalf("restore boot should not be attempted, got %d calls", runtime.restoreCalls) + } +} + +func newRestoreArtifactServer(t *testing.T, payloads map[string][]byte) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + payload, ok := payloads[r.URL.Path] + if !ok { + http.NotFound(w, r) + return + } + _, _ = w.Write(payload) + })) +} + func TestCreateMachineRejectsNonHTTPArtifactURLs(t *testing.T) { t.Parallel() diff --git a/internal/daemon/files.go b/internal/daemon/files.go index 9c6b9e5..4adfac6 100644 --- a/internal/daemon/files.go +++ b/internal/daemon/files.go @@ -414,6 +414,8 @@ func publishedPortToContract(record model.PublishedPortRecord) contracthost.Publ func machineToRuntimeState(record model.MachineRecord) firecracker.MachineState { phase := firecracker.PhaseStopped switch record.Phase { + case contracthost.MachinePhaseStarting: + phase = firecracker.PhaseRunning case contracthost.MachinePhaseRunning: phase = firecracker.PhaseRunning case contracthost.MachinePhaseFailed: diff --git a/internal/daemon/guest_identity.go b/internal/daemon/guest_identity.go index bcd6afe..59d98b1 100644 --- a/internal/daemon/guest_identity.go +++ b/internal/daemon/guest_identity.go @@ -54,6 +54,31 @@ hostname "$machine_name" >/dev/null 2>&1 || true return nil } +func (d *Daemon) syncGuestFilesystemOverSSH(ctx context.Context, runtimeHost string) error { + runtimeHost = strings.TrimSpace(runtimeHost) + if runtimeHost == "" { + return fmt.Errorf("guest runtime host is required") + } + + cmd := exec.CommandContext( + ctx, + "ssh", + "-i", d.backendSSHPrivateKeyPath(), + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "IdentitiesOnly=yes", + "-o", "BatchMode=yes", + "-p", strconv.Itoa(int(defaultSSHPort)), + "node@"+runtimeHost, + "sudo bash -lc 
"+shellSingleQuote("sync"), + ) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("sync guest filesystem over ssh: %w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + func shellSingleQuote(value string) string { return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'" } diff --git a/internal/daemon/guest_metadata.go b/internal/daemon/guest_metadata.go new file mode 100644 index 0000000..08c4ad3 --- /dev/null +++ b/internal/daemon/guest_metadata.go @@ -0,0 +1,80 @@ +package daemon + +import ( + "fmt" + "strings" + + "github.com/getcompanion-ai/computer-host/internal/firecracker" + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +const ( + defaultMMDSIPv4Address = "169.254.170.2" + defaultMMDSPayloadVersion = "v1" +) + +type guestMetadataEnvelope struct { + Latest guestMetadataRoot `json:"latest"` +} + +type guestMetadataRoot struct { + MetaData guestMetadataPayload `json:"meta-data"` +} + +type guestMetadataPayload struct { + Version string `json:"version"` + MachineID string `json:"machine_id"` + Hostname string `json:"hostname"` + AuthorizedKeys []string `json:"authorized_keys,omitempty"` + TrustedUserCAKeys []string `json:"trusted_user_ca_keys,omitempty"` + LoginWebhook *contracthost.GuestLoginWebhook `json:"login_webhook,omitempty"` +} + +func cloneGuestConfig(config *contracthost.GuestConfig) *contracthost.GuestConfig { + if config == nil { + return nil + } + cloned := &contracthost.GuestConfig{ + AuthorizedKeys: append([]string(nil), config.AuthorizedKeys...), + TrustedUserCAKeys: append([]string(nil), config.TrustedUserCAKeys...), + } + if config.LoginWebhook != nil { + copy := *config.LoginWebhook + cloned.LoginWebhook = © + } + return cloned +} + +func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) (*firecracker.MMDSSpec, error) { + name := strings.TrimSpace(string(machineID)) + if name == "" { + return nil, fmt.Errorf("machine id is 
required") + } + + payload := guestMetadataEnvelope{ + Latest: guestMetadataRoot{ + MetaData: guestMetadataPayload{ + Version: defaultMMDSPayloadVersion, + MachineID: name, + Hostname: name, + AuthorizedKeys: nil, + TrustedUserCAKeys: nil, + }, + }, + } + if guestConfig != nil { + payload.Latest.MetaData.AuthorizedKeys = append([]string(nil), guestConfig.AuthorizedKeys...) + payload.Latest.MetaData.TrustedUserCAKeys = append([]string(nil), guestConfig.TrustedUserCAKeys...) + if guestConfig.LoginWebhook != nil { + loginWebhook := *guestConfig.LoginWebhook + payload.Latest.MetaData.LoginWebhook = &loginWebhook + } + } + + return &firecracker.MMDSSpec{ + NetworkInterfaces: []string{"net0"}, + Version: firecracker.MMDSVersionV2, + IPv4Address: defaultMMDSIPv4Address, + Data: payload, + }, nil +} diff --git a/internal/daemon/lifecycle.go b/internal/daemon/lifecycle.go index dcfffb7..6b54126 100644 --- a/internal/daemon/lifecycle.go +++ b/internal/daemon/lifecycle.go @@ -48,9 +48,13 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (* if err != nil { return nil, err } + previousRecord := *record if record.Phase == contracthost.MachinePhaseRunning { return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil } + if record.Phase == contracthost.MachinePhaseStarting { + return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil + } if record.Phase != contracthost.MachinePhaseStopped { return nil, fmt.Errorf("machine %q is not startable from phase %q", id, record.Phase) } @@ -82,7 +86,7 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (* if err != nil { return nil, err } - spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path) + spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path, record.GuestConfig) if err != nil { return nil, err } @@ -94,39 +98,18 @@ func (d *Daemon) StartMachine(ctx context.Context, id 
contracthost.MachineID) (* if err != nil { return nil, err } - ports := defaultMachinePorts() - if err := waitForGuestReady(ctx, state.RuntimeHost, ports); err != nil { - _ = d.runtime.Delete(context.Background(), *state) - return nil, err - } - guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost) - if err != nil { - _ = d.runtime.Delete(context.Background(), *state) - return nil, err - } - record.RuntimeHost = state.RuntimeHost record.TapDevice = state.TapName - record.Ports = ports - record.GuestSSHPublicKey = guestSSHPublicKey - record.Phase = contracthost.MachinePhaseRunning + record.Ports = defaultMachinePorts() + record.GuestSSHPublicKey = "" + record.Phase = contracthost.MachinePhaseStarting record.Error = "" record.PID = state.PID record.SocketPath = state.SocketPath record.StartedAt = state.StartedAt if err := d.store.UpdateMachine(ctx, *record); err != nil { _ = d.runtime.Delete(context.Background(), *state) - return nil, err - } - if err := d.ensureMachineRelays(ctx, record); err != nil { - d.stopMachineRelays(id) - _ = d.runtime.Delete(context.Background(), *state) - return nil, err - } - if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil { - d.stopMachineRelays(id) - d.stopPublishedPortsForMachine(id) - _ = d.runtime.Delete(context.Background(), *state) + _ = d.store.UpdateMachine(context.Background(), previousRecord) return nil, err } @@ -272,7 +255,10 @@ func (d *Daemon) listRunningNetworks(ctx context.Context, ignore contracthost.Ma networks := make([]firecracker.NetworkAllocation, 0, len(records)) for _, record := range records { - if record.ID == ignore || record.Phase != contracthost.MachinePhaseRunning { + if record.ID == ignore { + continue + } + if record.Phase != contracthost.MachinePhaseRunning && record.Phase != contracthost.MachinePhaseStarting { continue } if strings.TrimSpace(record.RuntimeHost) == "" || strings.TrimSpace(record.TapDevice) == "" { @@ -337,11 +323,8 @@ func (d *Daemon) 
reconcileStart(ctx context.Context, machineID contracthost.Mach if err != nil { return err } - if record.Phase == contracthost.MachinePhaseRunning { - if err := d.ensureMachineRelays(ctx, record); err != nil { - return err - } - if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil { + if record.Phase == contracthost.MachinePhaseRunning || record.Phase == contracthost.MachinePhaseStarting { + if _, err := d.reconcileMachine(ctx, machineID); err != nil { return err } return d.store.DeleteOperation(ctx, machineID) @@ -385,7 +368,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma if err != nil { return nil, err } - if record.Phase != contracthost.MachinePhaseRunning { + if record.Phase != contracthost.MachinePhaseRunning && record.Phase != contracthost.MachinePhaseStarting { return record, nil } @@ -393,6 +376,42 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma if err != nil { return nil, err } + if record.Phase == contracthost.MachinePhaseStarting { + if state.Phase != firecracker.PhaseRunning { + return d.failMachineStartup(ctx, record, state.Error) + } + ready, err := guestPortsReady(ctx, state.RuntimeHost, defaultMachinePorts()) + if err != nil { + return nil, err + } + if !ready { + return record, nil + } + guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost) + if err != nil { + return d.failMachineStartup(ctx, record, err.Error()) + } + record.RuntimeHost = state.RuntimeHost + record.TapDevice = state.TapName + record.Ports = defaultMachinePorts() + record.GuestSSHPublicKey = guestSSHPublicKey + record.Phase = contracthost.MachinePhaseRunning + record.Error = "" + record.PID = state.PID + record.SocketPath = state.SocketPath + record.StartedAt = state.StartedAt + if err := d.store.UpdateMachine(ctx, *record); err != nil { + return nil, err + } + if err := d.ensureMachineRelays(ctx, record); err != nil { + return d.failMachineStartup(ctx, record, err.Error()) 
+ } + if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil { + d.stopMachineRelays(record.ID) + return d.failMachineStartup(ctx, record, err.Error()) + } + return record, nil + } if state.Phase == firecracker.PhaseRunning { if err := d.ensureMachineRelays(ctx, record); err != nil { return nil, err @@ -418,6 +437,28 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma return record, nil } +func (d *Daemon) failMachineStartup(ctx context.Context, record *model.MachineRecord, failureReason string) (*model.MachineRecord, error) { + if record == nil { + return nil, fmt.Errorf("machine record is required") + } + _ = d.runtime.Delete(ctx, machineToRuntimeState(*record)) + d.stopMachineRelays(record.ID) + d.stopPublishedPortsForMachine(record.ID) + record.Phase = contracthost.MachinePhaseFailed + record.Error = strings.TrimSpace(failureReason) + record.Ports = defaultMachinePorts() + record.GuestSSHPublicKey = "" + record.PID = 0 + record.SocketPath = "" + record.RuntimeHost = "" + record.TapDevice = "" + record.StartedAt = nil + if err := d.store.UpdateMachine(ctx, *record); err != nil { + return nil, err + } + return record, nil +} + func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error { d.stopMachineRelays(record.ID) d.stopPublishedPortsForMachine(record.ID) @@ -450,6 +491,11 @@ func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineR } func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error { + if record.Phase == contracthost.MachinePhaseRunning && strings.TrimSpace(record.RuntimeHost) != "" { + if err := d.syncGuestFilesystem(ctx, record.RuntimeHost); err != nil { + fmt.Fprintf(os.Stderr, "warning: sync guest filesystem for %q failed before stop: %v\n", record.ID, err) + } + } d.stopMachineRelays(record.ID) d.stopPublishedPortsForMachine(record.ID) if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil { 
diff --git a/internal/daemon/machine_relays.go b/internal/daemon/machine_relays.go index 91fc499..cb8801c 100644 --- a/internal/daemon/machine_relays.go +++ b/internal/daemon/machine_relays.go @@ -56,6 +56,9 @@ func (d *Daemon) usedMachineRelayPorts(ctx context.Context, machineID contractho if record.ID == machineID { continue } + if record.Phase != contracthost.MachinePhaseRunning { + continue + } if port := machineRelayHostPort(record, name); port != 0 { used[port] = struct{}{} } diff --git a/internal/daemon/readiness.go b/internal/daemon/readiness.go index 1a67f86..63644b9 100644 --- a/internal/daemon/readiness.go +++ b/internal/daemon/readiness.go @@ -30,15 +30,15 @@ func waitForGuestReady(ctx context.Context, host string, ports []contracthost.Ma func waitForGuestPort(ctx context.Context, host string, port contracthost.MachinePort) error { address := net.JoinHostPort(host, strconv.Itoa(int(port.Port))) - dialer := net.Dialer{Timeout: defaultGuestDialTimeout} ticker := time.NewTicker(defaultGuestReadyPollInterval) defer ticker.Stop() var lastErr error for { - connection, err := dialer.DialContext(ctx, string(port.Protocol), address) - if err == nil { - _ = connection.Close() + probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout) + ready, err := guestPortReady(probeCtx, host, port) + cancel() + if err == nil && ready { return nil } lastErr = err @@ -50,3 +50,38 @@ func waitForGuestPort(ctx context.Context, host string, port contracthost.Machin } } } + +func guestPortsReady(ctx context.Context, host string, ports []contracthost.MachinePort) (bool, error) { + host = strings.TrimSpace(host) + if host == "" { + return false, fmt.Errorf("guest runtime host is required") + } + + for _, port := range ports { + probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout) + ready, err := guestPortReady(probeCtx, host, port) + cancel() + if err != nil { + return false, err + } + if !ready { + return false, nil + } + } + return true, nil +} + +func 
guestPortReady(ctx context.Context, host string, port contracthost.MachinePort) (bool, error) { + address := net.JoinHostPort(host, strconv.Itoa(int(port.Port))) + dialer := net.Dialer{Timeout: defaultGuestDialTimeout} + + connection, err := dialer.DialContext(ctx, string(port.Protocol), address) + if err == nil { + _ = connection.Close() + return true, nil + } + if ctx.Err() != nil { + return false, nil + } + return false, nil +} diff --git a/internal/daemon/review_regressions_test.go b/internal/daemon/review_regressions_test.go index 6d3bf2f..f297533 100644 --- a/internal/daemon/review_regressions_test.go +++ b/internal/daemon/review_regressions_test.go @@ -2,7 +2,10 @@ package daemon import ( "context" + "crypto/sha256" + "encoding/hex" "errors" + "fmt" "net" "os" "path/filepath" @@ -54,6 +57,19 @@ func (s machineLookupErrorStore) GetMachine(context.Context, contracthost.Machin return nil, s.err } +type relayExhaustionStore struct { + hoststore.Store + extraMachines []model.MachineRecord +} + +func (s relayExhaustionStore) ListMachines(ctx context.Context) ([]model.MachineRecord, error) { + machines, err := s.Store.ListMachines(ctx) + if err != nil { + return nil, err + } + return append(machines, s.extraMachines...), nil +} + type publishedPortResult struct { response *contracthost.CreatePublishedPortResponse err error @@ -283,6 +299,249 @@ func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T) assertOperationCount(t, baseStore, 1) } +func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + exhaustedStore := relayExhaustionStore{ + Store: baseStore, + extraMachines: exhaustedMachineRelayRecords(), + } + + sshListener := listenTestPort(t, int(defaultSSHPort)) + defer func() { + _ = sshListener.Close() + }() + vncListener := 
listenTestPort(t, int(defaultVNCPort)) + defer func() { + _ = vncListener.Close() + }() + + startedAt := time.Unix(1700000200, 0).UTC() + runtime := &fakeRuntime{ + bootState: firecracker.MachineState{ + ID: "vm-start", + Phase: firecracker.PhaseRunning, + PID: 9999, + RuntimeHost: "127.0.0.1", + SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "vm-start", "root", "run", "firecracker.sock"), + TapName: "fctap-start", + StartedAt: &startedAt, + }, + } + + hostDaemon, err := New(cfg, exhaustedStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + stubGuestSSHPublicKeyReader(hostDaemon) + + artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"} + kernelPath := filepath.Join(root, "artifact-kernel") + rootFSPath := filepath.Join(root, "artifact-rootfs") + systemVolumePath := filepath.Join(root, "machine-disks", "vm-start", "rootfs.ext4") + for _, file := range []string{kernelPath, rootFSPath, systemVolumePath} { + if err := os.MkdirAll(filepath.Dir(file), 0o755); err != nil { + t.Fatalf("mkdir for %q: %v", file, err) + } + if err := os.WriteFile(file, []byte("payload"), 0o644); err != nil { + t.Fatalf("write file %q: %v", file, err) + } + } + if err := baseStore.PutArtifact(context.Background(), model.ArtifactRecord{ + Ref: artifactRef, + LocalKey: "artifact", + LocalDir: filepath.Join(root, "artifact"), + KernelImagePath: kernelPath, + RootFSPath: rootFSPath, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("put artifact: %v", err) + } + if err := baseStore.CreateVolume(context.Background(), model.VolumeRecord{ + ID: "vm-start-system", + Kind: contracthost.VolumeKindSystem, + Pool: model.StoragePoolMachineDisks, + Path: systemVolumePath, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create system volume: %v", err) + } + if err := baseStore.CreateMachine(context.Background(), model.MachineRecord{ + ID: "vm-start", + Artifact: artifactRef, + SystemVolumeID: "vm-start-system", + Ports: 
defaultMachinePorts(), + Phase: contracthost.MachinePhaseStopped, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create machine: %v", err) + } + + response, err := hostDaemon.StartMachine(context.Background(), "vm-start") + if err != nil { + t.Fatalf("StartMachine error = %v", err) + } + if response.Machine.Phase != contracthost.MachinePhaseStarting { + t.Fatalf("response machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting) + } + + machine, err := baseStore.GetMachine(context.Background(), "vm-start") + if err != nil { + t.Fatalf("get machine: %v", err) + } + if machine.Phase != contracthost.MachinePhaseStarting { + t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseStarting) + } + if machine.RuntimeHost != "127.0.0.1" || machine.TapDevice != "fctap-start" { + t.Fatalf("machine runtime state mismatch, got runtime_host=%q tap=%q", machine.RuntimeHost, machine.TapDevice) + } + if machine.PID != 9999 || machine.SocketPath == "" || machine.StartedAt == nil { + t.Fatalf("machine process state missing: pid=%d socket=%q started_at=%v", machine.PID, machine.SocketPath, machine.StartedAt) + } + if len(runtime.deleteCalls) != 0 { + t.Fatalf("runtime delete calls = %d, want 0", len(runtime.deleteCalls)) + } +} + +func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + exhaustedStore := relayExhaustionStore{ + Store: baseStore, + extraMachines: exhaustedMachineRelayRecords(), + } + + sshListener := listenTestPort(t, int(defaultSSHPort)) + defer func() { + _ = sshListener.Close() + }() + vncListener := listenTestPort(t, int(defaultVNCPort)) + defer func() { + _ = vncListener.Close() + }() + + startedAt := time.Unix(1700000300, 0).UTC() + runtime := &fakeRuntime{ + bootState: 
firecracker.MachineState{ + ID: "restored-exhausted", + Phase: firecracker.PhaseRunning, + PID: 8888, + RuntimeHost: "127.0.0.1", + SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "restored-exhausted", "root", "run", "firecracker.sock"), + TapName: "fctap-restore", + StartedAt: &startedAt, + }, + } + + hostDaemon, err := New(cfg, exhaustedStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + stubGuestSSHPublicKeyReader(hostDaemon) + hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID) error { return nil } + + artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"} + kernelPath := filepath.Join(root, "artifact-kernel") + rootFSPath := filepath.Join(root, "artifact-rootfs") + if err := os.WriteFile(kernelPath, []byte("kernel"), 0o644); err != nil { + t.Fatalf("write kernel: %v", err) + } + if err := os.WriteFile(rootFSPath, []byte("rootfs"), 0o644); err != nil { + t.Fatalf("write rootfs: %v", err) + } + if err := baseStore.PutArtifact(context.Background(), model.ArtifactRecord{ + Ref: artifactRef, + LocalKey: "artifact", + LocalDir: filepath.Join(root, "artifact"), + KernelImagePath: kernelPath, + RootFSPath: rootFSPath, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("put artifact: %v", err) + } + + snapDir := filepath.Join(root, "snapshots", "snap-exhausted") + if err := os.MkdirAll(snapDir, 0o755); err != nil { + t.Fatalf("create snapshot dir: %v", err) + } + snapDisk := filepath.Join(snapDir, "system.img") + if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil { + t.Fatalf("write snapshot disk: %v", err) + } + memPath := filepath.Join(snapDir, "memory.bin") + if err := os.WriteFile(memPath, []byte("mem"), 0o644); err != nil { + t.Fatalf("write memory snapshot: %v", err) + } + statePath := filepath.Join(snapDir, "vmstate.bin") + if err := os.WriteFile(statePath, []byte("state"), 0o644); err != nil { + t.Fatalf("write vmstate snapshot: %v", 
err) + } + if err := baseStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ + ID: "snap-exhausted", + MachineID: "source", + Artifact: artifactRef, + MemFilePath: memPath, + StateFilePath: statePath, + DiskPaths: []string{snapDisk}, + SourceRuntimeHost: "172.16.0.2", + SourceTapDevice: "fctap0", + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create snapshot: %v", err) + } + + server := newRestoreArtifactServer(t, map[string][]byte{ + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + "/memory": []byte("mem"), + "/vmstate": []byte("state"), + "/system": []byte("disk"), + }) + defer server.Close() + + _, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-exhausted", contracthost.RestoreSnapshotRequest{ + MachineID: "restored-exhausted", + Artifact: contracthost.ArtifactRef{ + KernelImageURL: server.URL + "/kernel", + RootFSURL: server.URL + "/rootfs", + }, + Snapshot: contracthost.DurableSnapshotSpec{ + SnapshotID: "snap-exhausted", + MachineID: "source", + ImageID: "image-1", + Artifacts: []contracthost.SnapshotArtifact{ + {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"}, + {ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"}, + {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"}, + }, + }, + }) + if err == nil || !strings.Contains(err.Error(), "allocate relay ports for restored machine") { + t.Fatalf("RestoreSnapshot error = %v, want relay allocation failure", err) + } + + if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); !errors.Is(err, hoststore.ErrNotFound) { + t.Fatalf("restored system volume record should be deleted, get err = %v", err) + } + if _, err := os.Stat(hostDaemon.systemVolumePath("restored-exhausted")); !os.IsNotExist(err) { + t.Fatalf("restored system disk 
should be removed, stat err = %v", err) + } + if len(runtime.deleteCalls) != 1 { + t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls)) + } + assertOperationCount(t, baseStore, 0) +} + func TestCreateSnapshotRejectsDuplicateSnapshotIDWithoutTouchingExistingArtifacts(t *testing.T) { root := t.TempDir() cfg := testConfig(root) @@ -339,6 +598,193 @@ func TestCreateSnapshotRejectsDuplicateSnapshotIDWithoutTouchingExistingArtifact assertOperationCount(t, fileStore, 0) } +func TestStopMachineContinuesWhenGuestSyncFails(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + runtime := &fakeRuntime{} + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + stubGuestSSHPublicKeyReader(hostDaemon) + hostDaemon.syncGuestFilesystem = func(context.Context, string) error { + return errors.New("guest sync failed") + } + + now := time.Now().UTC() + if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{ + ID: "vm-stop-fail", + SystemVolumeID: "vm-stop-fail-system", + RuntimeHost: "172.16.0.2", + TapDevice: "fctap-stop-fail", + Phase: contracthost.MachinePhaseRunning, + PID: 1234, + SocketPath: filepath.Join(root, "runtime", "vm-stop-fail.sock"), + Ports: defaultMachinePorts(), + CreatedAt: now, + StartedAt: &now, + }); err != nil { + t.Fatalf("create machine: %v", err) + } + + if err := hostDaemon.StopMachine(context.Background(), "vm-stop-fail"); err != nil { + t.Fatalf("StopMachine returned error despite sync failure: %v", err) + } + if len(runtime.deleteCalls) != 1 { + t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls)) + } + + machine, err := fileStore.GetMachine(context.Background(), "vm-stop-fail") + if err != nil { + t.Fatalf("get machine: %v", err) + } + if machine.Phase != contracthost.MachinePhaseStopped { + 
t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseStopped) + } +} + +func TestOrderedRestoredUserDiskArtifactsSortsByDriveIndex(t *testing.T) { + ordered := orderedRestoredUserDiskArtifacts(map[string]restoredSnapshotArtifact{ + "user-10.img": {Artifact: contracthost.SnapshotArtifact{Name: "user-10.img"}}, + "user-2.img": {Artifact: contracthost.SnapshotArtifact{Name: "user-2.img"}}, + "user-1.img": {Artifact: contracthost.SnapshotArtifact{Name: "user-1.img"}}, + "system.img": {Artifact: contracthost.SnapshotArtifact{Name: "system.img"}}, + }) + + names := make([]string, 0, len(ordered)) + for _, artifact := range ordered { + names = append(names, artifact.Artifact.Name) + } + if got, want := strings.Join(names, ","), "user-1.img,user-2.img,user-10.img"; got != want { + t.Fatalf("ordered restored artifacts = %q, want %q", got, want) + } +} + +func TestRestoreSnapshotCleansStagingArtifactsAfterSuccess(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + sshListener := listenTestPort(t, int(defaultSSHPort)) + defer func() { _ = sshListener.Close() }() + vncListener := listenTestPort(t, int(defaultVNCPort)) + defer func() { _ = vncListener.Close() }() + + startedAt := time.Unix(1700000400, 0).UTC() + runtime := &fakeRuntime{ + bootState: firecracker.MachineState{ + ID: "restored-clean", + Phase: firecracker.PhaseRunning, + PID: 7777, + RuntimeHost: "127.0.0.1", + SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "restored-clean", "root", "run", "firecracker.sock"), + TapName: "fctap-clean", + StartedAt: &startedAt, + }, + } + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + stubGuestSSHPublicKeyReader(hostDaemon) + hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID) error { return 
nil } + + server := newRestoreArtifactServer(t, map[string][]byte{ + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + "/memory": []byte("mem"), + "/vmstate": []byte("state"), + "/system": []byte("disk"), + }) + defer server.Close() + + _, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-clean", contracthost.RestoreSnapshotRequest{ + MachineID: "restored-clean", + Artifact: contracthost.ArtifactRef{ + KernelImageURL: server.URL + "/kernel", + RootFSURL: server.URL + "/rootfs", + }, + Snapshot: contracthost.DurableSnapshotSpec{ + SnapshotID: "snap-clean", + MachineID: "source", + ImageID: "image-1", + SourceRuntimeHost: "172.16.0.2", + SourceTapDevice: "fctap0", + Artifacts: []contracthost.SnapshotArtifact{ + {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory", SHA256Hex: mustSHA256Hex(t, []byte("mem"))}, + {ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate", SHA256Hex: mustSHA256Hex(t, []byte("state"))}, + {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system", SHA256Hex: mustSHA256Hex(t, []byte("disk"))}, + }, + }, + }) + if err != nil { + t.Fatalf("RestoreSnapshot returned error: %v", err) + } + + stagingDir := filepath.Join(cfg.SnapshotsDir, "snap-clean", "restores", "restored-clean") + if _, statErr := os.Stat(stagingDir); !os.IsNotExist(statErr) { + t.Fatalf("restore staging dir should be cleaned up, stat err = %v", statErr) + } +} + +func TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + runtime := &fakeRuntime{} + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } 
+ stubGuestSSHPublicKeyReader(hostDaemon) + + server := newRestoreArtifactServer(t, map[string][]byte{ + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + "/memory": []byte("mem"), + }) + defer server.Close() + + _, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-fail-clean", contracthost.RestoreSnapshotRequest{ + MachineID: "restored-fail-clean", + Artifact: contracthost.ArtifactRef{ + KernelImageURL: server.URL + "/kernel", + RootFSURL: server.URL + "/rootfs", + }, + Snapshot: contracthost.DurableSnapshotSpec{ + SnapshotID: "snap-fail-clean", + MachineID: "source", + ImageID: "image-1", + SourceRuntimeHost: "172.16.0.2", + SourceTapDevice: "fctap0", + Artifacts: []contracthost.SnapshotArtifact{ + {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory", SHA256Hex: mustSHA256Hex(t, []byte("mem"))}, + {ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/missing"}, + }, + }, + }) + if err == nil || !strings.Contains(err.Error(), "download durable snapshot artifacts") { + t.Fatalf("RestoreSnapshot error = %v, want durable artifact download failure", err) + } + + stagingDir := filepath.Join(cfg.SnapshotsDir, "snap-fail-clean", "restores", "restored-fail-clean") + if _, statErr := os.Stat(stagingDir); !os.IsNotExist(statErr) { + t.Fatalf("restore staging dir should be cleaned up after download failure, stat err = %v", statErr) + } +} + func TestReconcileUsesReconciledMachineStateForPublishedPorts(t *testing.T) { root := t.TempDir() cfg := testConfig(root) @@ -423,6 +869,26 @@ func waitPublishedPortResult(t *testing.T, ch <-chan publishedPortResult) publis } } +func exhaustedMachineRelayRecords() []model.MachineRecord { + count := int(maxMachineSSHRelayPort-minMachineSSHRelayPort) + 1 + machines := make([]model.MachineRecord, 0, count) + for i := 0; i < count; i++ { + machines = append(machines, model.MachineRecord{ + ID: 
contracthost.MachineID(fmt.Sprintf("relay-exhausted-%d", i)), + Ports: buildMachinePorts(minMachineSSHRelayPort+uint16(i), minMachineVNCRelayPort+uint16(i)), + Phase: contracthost.MachinePhaseRunning, + }) + } + return machines +} + +func mustSHA256Hex(t *testing.T, payload []byte) string { + t.Helper() + + sum := sha256.Sum256(payload) + return hex.EncodeToString(sum[:]) +} + func assertOperationCount(t *testing.T, store hoststore.Store, want int) { t.Helper() diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go index d8329ac..bb77800 100644 --- a/internal/daemon/snapshot.go +++ b/internal/daemon/snapshot.go @@ -7,6 +7,8 @@ import ( "os" "os/exec" "path/filepath" + "sort" + "strconv" "strings" "time" @@ -108,6 +110,22 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach return nil, fmt.Errorf("copy system disk: %w", err) } diskPaths = append(diskPaths, systemDiskTarget) + for i, volumeID := range record.UserVolumeIDs { + volume, err := d.store.GetVolume(ctx, volumeID) + if err != nil { + _ = d.runtime.Resume(ctx, runtimeState) + _ = os.RemoveAll(snapshotDir) + return nil, fmt.Errorf("get attached volume %q: %w", volumeID, err) + } + driveID := fmt.Sprintf("user-%d", i) + targetPath := filepath.Join(snapshotDir, driveID+".img") + if err := cowCopyFile(volume.Path, targetPath); err != nil { + _ = d.runtime.Resume(ctx, runtimeState) + _ = os.RemoveAll(snapshotDir) + return nil, fmt.Errorf("copy attached volume %q: %w", volumeID, err) + } + diskPaths = append(diskPaths, targetPath) + } // Resume the source VM if err := d.runtime.Resume(ctx, runtimeState); err != nil { @@ -132,6 +150,12 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach return nil, fmt.Errorf("move vmstate file: %w", err) } + artifacts, err := buildSnapshotArtifacts(dstMemPath, dstStatePath, diskPaths) + if err != nil { + _ = os.RemoveAll(snapshotDir) + return nil, fmt.Errorf("build snapshot artifacts: %w", err) + } + now 
:= time.Now().UTC() snapshotRecord := model.SnapshotRecord{ ID: snapshotID, @@ -140,6 +164,7 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach MemFilePath: dstMemPath, StateFilePath: dstStatePath, DiskPaths: diskPaths, + Artifacts: artifacts, SourceRuntimeHost: record.RuntimeHost, SourceTapDevice: record.TapDevice, CreatedAt: now, @@ -151,25 +176,60 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach clearOperation = true return &contracthost.CreateSnapshotResponse{ - Snapshot: snapshotToContract(snapshotRecord), + Snapshot: snapshotToContract(snapshotRecord), + Artifacts: snapshotArtifactsToContract(snapshotRecord.Artifacts), }, nil } +func (d *Daemon) UploadSnapshot(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.UploadSnapshotRequest) (*contracthost.UploadSnapshotResponse, error) { + snapshot, err := d.store.GetSnapshot(ctx, snapshotID) + if err != nil { + return nil, err + } + artifactIndex := make(map[string]model.SnapshotArtifactRecord, len(snapshot.Artifacts)) + for _, artifact := range snapshot.Artifacts { + artifactIndex[artifact.ID] = artifact + } + + response := &contracthost.UploadSnapshotResponse{ + Artifacts: make([]contracthost.UploadedSnapshotArtifact, 0, len(req.Artifacts)), + } + for _, upload := range req.Artifacts { + artifact, ok := artifactIndex[upload.ArtifactID] + if !ok { + return nil, fmt.Errorf("snapshot %q artifact %q not found", snapshotID, upload.ArtifactID) + } + completedParts, err := uploadSnapshotArtifact(ctx, artifact.LocalPath, upload.Parts) + if err != nil { + return nil, fmt.Errorf("upload snapshot artifact %q: %w", upload.ArtifactID, err) + } + response.Artifacts = append(response.Artifacts, contracthost.UploadedSnapshotArtifact{ + ArtifactID: upload.ArtifactID, + CompletedParts: completedParts, + }) + } + + return response, nil +} + func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.SnapshotID, req 
contracthost.RestoreSnapshotRequest) (*contracthost.RestoreSnapshotResponse, error) { if err := validateMachineID(req.MachineID); err != nil { return nil, err } + if req.Snapshot.SnapshotID != "" && req.Snapshot.SnapshotID != snapshotID { + return nil, fmt.Errorf("snapshot id mismatch: path=%q payload=%q", snapshotID, req.Snapshot.SnapshotID) + } + if err := validateArtifactRef(req.Artifact); err != nil { + return nil, err + } unlock := d.lockMachine(req.MachineID) defer unlock() - snap, err := d.store.GetSnapshot(ctx, snapshotID) - if err != nil { - return nil, err - } - if _, err := d.store.GetMachine(ctx, req.MachineID); err == nil { return nil, fmt.Errorf("machine %q already exists", req.MachineID) + } else if err != nil && err != store.ErrNotFound { + return nil, err } if err := d.store.UpsertOperation(ctx, model.OperationRecord{ @@ -188,55 +248,90 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn } }() - sourceMachine, err := d.store.GetMachine(ctx, snap.MachineID) - switch { - case err == nil && sourceMachine.Phase == contracthost.MachinePhaseRunning: - clearOperation = true - return nil, fmt.Errorf("restore from snapshot %q while source machine %q is running is not supported yet", snapshotID, snap.MachineID) - case err != nil && err != store.ErrNotFound: - return nil, fmt.Errorf("get source machine for restore: %w", err) - } - usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID) if err != nil { return nil, err } - restoreNetwork, err := restoreNetworkFromSnapshot(snap) + restoreNetwork, err := d.resolveRestoreNetwork(ctx, snapshotID, req.Snapshot) if err != nil { clearOperation = true return nil, err } if networkAllocationInUse(restoreNetwork, usedNetworks) { clearOperation = true - return nil, fmt.Errorf("snapshot %q restore network %q (%s) is already in use", snapshotID, restoreNetwork.TapName, restoreNetwork.GuestIP()) + return nil, fmt.Errorf("restore network for snapshot %q is still in use on this host 
(runtime_host=%s tap_device=%s)", snapshotID, restoreNetwork.GuestIP(), restoreNetwork.TapName) } - artifact, err := d.store.GetArtifact(ctx, snap.Artifact) + artifact, err := d.ensureArtifact(ctx, req.Artifact) if err != nil { - return nil, fmt.Errorf("get artifact for restore: %w", err) + return nil, fmt.Errorf("ensure artifact for restore: %w", err) } + stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID)) + restoredArtifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts) + if err != nil { + _ = os.RemoveAll(stagingDir) + clearOperation = true + return nil, fmt.Errorf("download durable snapshot artifacts: %w", err) + } + defer func() { _ = os.RemoveAll(stagingDir) }() + // COW-copy system disk from snapshot to new machine's disk dir. newSystemDiskPath := d.systemVolumePath(req.MachineID) if err := os.MkdirAll(filepath.Dir(newSystemDiskPath), 0o755); err != nil { return nil, fmt.Errorf("create machine disk dir: %w", err) } - if len(snap.DiskPaths) < 1 { + systemDiskPath, ok := restoredArtifacts["system.img"] + if !ok { clearOperation = true - return nil, fmt.Errorf("snapshot %q has no disk paths", snapshotID) + return nil, fmt.Errorf("snapshot %q is missing system disk artifact", snapshotID) } - if err := cowCopyFile(snap.DiskPaths[0], newSystemDiskPath); err != nil { + memoryArtifact, ok := restoredArtifacts["memory.bin"] + if !ok { + clearOperation = true + return nil, fmt.Errorf("snapshot %q is missing memory artifact", snapshotID) + } + vmstateArtifact, ok := restoredArtifacts["vmstate.bin"] + if !ok { + clearOperation = true + return nil, fmt.Errorf("snapshot %q is missing vmstate artifact", snapshotID) + } + if err := cowCopyFile(systemDiskPath.LocalPath, newSystemDiskPath); err != nil { clearOperation = true return nil, fmt.Errorf("copy system disk for restore: %w", err) } + type restoredUserVolume struct { + ID contracthost.VolumeID + Path string + DriveID string + } + 
restoredUserVolumes := make([]restoredUserVolume, 0) + restoredDrivePaths := make(map[string]string) + for _, restored := range orderedRestoredUserDiskArtifacts(restoredArtifacts) { + name := restored.Artifact.Name + driveID := strings.TrimSuffix(name, filepath.Ext(name)) + volumeID := contracthost.VolumeID(fmt.Sprintf("%s-%s", req.MachineID, driveID)) + volumePath := filepath.Join(d.config.MachineDisksDir, string(req.MachineID), name) + if err := cowCopyFile(restored.LocalPath, volumePath); err != nil { + clearOperation = true + return nil, fmt.Errorf("copy restored drive %q: %w", driveID, err) + } + restoredUserVolumes = append(restoredUserVolumes, restoredUserVolume{ + ID: volumeID, + Path: volumePath, + DriveID: driveID, + }) + restoredDrivePaths[driveID] = volumePath + } + loadSpec := firecracker.SnapshotLoadSpec{ ID: firecracker.MachineID(req.MachineID), - SnapshotPath: snap.StateFilePath, - MemFilePath: snap.MemFilePath, + SnapshotPath: vmstateArtifact.LocalPath, + MemFilePath: memoryArtifact.LocalPath, RootFSPath: newSystemDiskPath, KernelImagePath: artifact.KernelImagePath, - DiskPaths: map[string]string{}, + DiskPaths: restoredDrivePaths, Network: &restoreNetwork, } @@ -275,18 +370,44 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn ID: systemVolumeID, Kind: contracthost.VolumeKindSystem, AttachedMachineID: machineIDPtr(req.MachineID), - SourceArtifact: &snap.Artifact, + SourceArtifact: &req.Artifact, Pool: model.StoragePoolMachineDisks, Path: newSystemDiskPath, CreatedAt: now, }); err != nil { - return nil, err + _ = d.runtime.Delete(ctx, *machineState) + _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) + clearOperation = true + return nil, fmt.Errorf("create system volume record for restore: %w", err) + } + restoredUserVolumeIDs := make([]contracthost.VolumeID, 0, len(restoredUserVolumes)) + for _, volume := range restoredUserVolumes { + if err := d.store.CreateVolume(ctx, model.VolumeRecord{ + ID: volume.ID, + Kind: 
contracthost.VolumeKindUser, + AttachedMachineID: machineIDPtr(req.MachineID), + SourceArtifact: &req.Artifact, + Pool: model.StoragePoolMachineDisks, + Path: volume.Path, + CreatedAt: now, + }); err != nil { + for _, restoredVolumeID := range restoredUserVolumeIDs { + _ = d.store.DeleteVolume(context.Background(), restoredVolumeID) + } + _ = d.store.DeleteVolume(context.Background(), systemVolumeID) + _ = d.runtime.Delete(ctx, *machineState) + _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) + clearOperation = true + return nil, fmt.Errorf("create restored user volume record %q: %w", volume.ID, err) + } + restoredUserVolumeIDs = append(restoredUserVolumeIDs, volume.ID) } machineRecord := model.MachineRecord{ ID: req.MachineID, - Artifact: snap.Artifact, + Artifact: req.Artifact, SystemVolumeID: systemVolumeID, + UserVolumeIDs: restoredUserVolumeIDs, RuntimeHost: machineState.RuntimeHost, TapDevice: machineState.TapName, Ports: defaultMachinePorts(), @@ -306,7 +427,14 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn d.relayAllocMu.Unlock() if err != nil { d.stopMachineRelays(machineRecord.ID) - return nil, err + for _, restoredVolumeID := range restoredUserVolumeIDs { + _ = d.store.DeleteVolume(context.Background(), restoredVolumeID) + } + _ = d.store.DeleteVolume(context.Background(), systemVolumeID) + _ = d.runtime.Delete(ctx, *machineState) + _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) + clearOperation = true + return nil, fmt.Errorf("allocate relay ports for restored machine: %w", err) } machineRecord.Ports = buildMachinePorts(sshRelayPort, vncRelayPort) startedRelays := true @@ -316,6 +444,13 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn } }() if err := d.store.CreateMachine(ctx, machineRecord); err != nil { + for _, restoredVolumeID := range restoredUserVolumeIDs { + _ = d.store.DeleteVolume(context.Background(), restoredVolumeID) + } + _ = 
d.store.DeleteVolume(context.Background(), systemVolumeID) + _ = d.runtime.Delete(ctx, *machineState) + _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) + clearOperation = true return nil, err } @@ -360,12 +495,89 @@ func (d *Daemon) DeleteSnapshotByID(ctx context.Context, snapshotID contracthost func snapshotToContract(record model.SnapshotRecord) contracthost.Snapshot { return contracthost.Snapshot{ - ID: record.ID, - MachineID: record.MachineID, - CreatedAt: record.CreatedAt, + ID: record.ID, + MachineID: record.MachineID, + SourceRuntimeHost: record.SourceRuntimeHost, + SourceTapDevice: record.SourceTapDevice, + CreatedAt: record.CreatedAt, } } +func snapshotArtifactsToContract(artifacts []model.SnapshotArtifactRecord) []contracthost.SnapshotArtifact { + converted := make([]contracthost.SnapshotArtifact, 0, len(artifacts)) + for _, artifact := range artifacts { + converted = append(converted, contracthost.SnapshotArtifact{ + ID: artifact.ID, + Kind: artifact.Kind, + Name: artifact.Name, + SizeBytes: artifact.SizeBytes, + SHA256Hex: artifact.SHA256Hex, + }) + } + return converted +} + +func orderedRestoredUserDiskArtifacts(artifacts map[string]restoredSnapshotArtifact) []restoredSnapshotArtifact { + ordered := make([]restoredSnapshotArtifact, 0, len(artifacts)) + for name, artifact := range artifacts { + if !strings.HasPrefix(name, "user-") || filepath.Ext(name) != ".img" { + continue + } + ordered = append(ordered, artifact) + } + sort.Slice(ordered, func(i, j int) bool { + iIdx, iOK := restoredUserDiskIndex(ordered[i].Artifact.Name) + jIdx, jOK := restoredUserDiskIndex(ordered[j].Artifact.Name) + switch { + case iOK && jOK && iIdx != jIdx: + return iIdx < jIdx + case iOK != jOK: + return iOK + default: + return ordered[i].Artifact.Name < ordered[j].Artifact.Name + } + }) + return ordered +} + +func restoredUserDiskIndex(name string) (int, bool) { + if !strings.HasPrefix(name, "user-") || filepath.Ext(name) != ".img" { + return 0, false + } + value := 
strings.TrimSuffix(strings.TrimPrefix(name, "user-"), filepath.Ext(name)) + index, err := strconv.Atoi(value) + if err != nil { + return 0, false + } + return index, true +} + +func (d *Daemon) resolveRestoreNetwork(ctx context.Context, snapshotID contracthost.SnapshotID, spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) { + if network, err := restoreNetworkFromDurableSpec(spec); err == nil { + return network, nil + } + + snapshot, err := d.store.GetSnapshot(ctx, snapshotID) + if err == nil { + return restoreNetworkFromSnapshot(snapshot) + } + if err != store.ErrNotFound { + return firecracker.NetworkAllocation{}, err + } + return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot %q is missing restore network metadata", snapshotID) +} + +func restoreNetworkFromDurableSpec(spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) { + if strings.TrimSpace(spec.SourceRuntimeHost) == "" || strings.TrimSpace(spec.SourceTapDevice) == "" { + return firecracker.NetworkAllocation{}, fmt.Errorf("durable snapshot spec is missing restore network metadata") + } + network, err := firecracker.AllocationFromGuestIP(spec.SourceRuntimeHost, spec.SourceTapDevice) + if err != nil { + return firecracker.NetworkAllocation{}, fmt.Errorf("reconstruct durable snapshot %q network: %w", spec.SnapshotID, err) + } + return network, nil +} + func restoreNetworkFromSnapshot(snap *model.SnapshotRecord) (firecracker.NetworkAllocation, error) { if snap == nil { return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot is required") diff --git a/internal/daemon/snapshot_transfer.go b/internal/daemon/snapshot_transfer.go new file mode 100644 index 0000000..6656b4d --- /dev/null +++ b/internal/daemon/snapshot_transfer.go @@ -0,0 +1,194 @@ +package daemon + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/getcompanion-ai/computer-host/internal/model" + 
contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +type restoredSnapshotArtifact struct { + Artifact contracthost.SnapshotArtifact + LocalPath string +} + +func buildSnapshotArtifacts(memoryPath, vmstatePath string, diskPaths []string) ([]model.SnapshotArtifactRecord, error) { + artifacts := make([]model.SnapshotArtifactRecord, 0, len(diskPaths)+2) + + memoryArtifact, err := snapshotArtifactRecord("memory", contracthost.SnapshotArtifactKindMemory, filepath.Base(memoryPath), memoryPath) + if err != nil { + return nil, err + } + artifacts = append(artifacts, memoryArtifact) + + vmstateArtifact, err := snapshotArtifactRecord("vmstate", contracthost.SnapshotArtifactKindVMState, filepath.Base(vmstatePath), vmstatePath) + if err != nil { + return nil, err + } + artifacts = append(artifacts, vmstateArtifact) + + for _, diskPath := range diskPaths { + base := filepath.Base(diskPath) + diskArtifact, err := snapshotArtifactRecord("disk-"+strings.TrimSuffix(base, filepath.Ext(base)), contracthost.SnapshotArtifactKindDisk, base, diskPath) + if err != nil { + return nil, err + } + artifacts = append(artifacts, diskArtifact) + } + + sort.Slice(artifacts, func(i, j int) bool { + return artifacts[i].ID < artifacts[j].ID + }) + return artifacts, nil +} + +func snapshotArtifactRecord(id string, kind contracthost.SnapshotArtifactKind, name, path string) (model.SnapshotArtifactRecord, error) { + size, err := fileSize(path) + if err != nil { + return model.SnapshotArtifactRecord{}, err + } + sum, err := sha256File(path) + if err != nil { + return model.SnapshotArtifactRecord{}, err + } + return model.SnapshotArtifactRecord{ + ID: id, + Kind: kind, + Name: name, + LocalPath: path, + SizeBytes: size, + SHA256Hex: sum, + }, nil +} + +func sha256File(path string) (string, error) { + file, err := os.Open(path) + if err != nil { + return "", fmt.Errorf("open %q for sha256: %w", path, err) + } + defer func() { _ = file.Close() }() + + hash := sha256.New() + if _, err := 
io.Copy(hash, file); err != nil { + return "", fmt.Errorf("hash %q: %w", path, err) + } + return hex.EncodeToString(hash.Sum(nil)), nil +} + +func uploadSnapshotArtifact(ctx context.Context, localPath string, parts []contracthost.SnapshotUploadPart) ([]contracthost.UploadedSnapshotPart, error) { + if len(parts) == 0 { + return nil, fmt.Errorf("upload session has no parts") + } + + file, err := os.Open(localPath) + if err != nil { + return nil, fmt.Errorf("open artifact %q: %w", localPath, err) + } + defer func() { _ = file.Close() }() + + client := &http.Client{} + completed := make([]contracthost.UploadedSnapshotPart, 0, len(parts)) + for _, part := range parts { + reader := io.NewSectionReader(file, part.OffsetBytes, part.SizeBytes) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, part.UploadURL, io.NopCloser(reader)) + if err != nil { + return nil, fmt.Errorf("build upload part %d: %w", part.PartNumber, err) + } + req.ContentLength = part.SizeBytes + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("upload part %d: %w", part.PartNumber, err) + } + _ = resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("upload part %d returned %d", part.PartNumber, resp.StatusCode) + } + etag := strings.TrimSpace(resp.Header.Get("ETag")) + if etag == "" { + return nil, fmt.Errorf("upload part %d returned empty etag", part.PartNumber) + } + completed = append(completed, contracthost.UploadedSnapshotPart{ + PartNumber: part.PartNumber, + ETag: etag, + }) + } + sort.Slice(completed, func(i, j int) bool { + return completed[i].PartNumber < completed[j].PartNumber + }) + return completed, nil +} + +func downloadDurableSnapshotArtifacts(ctx context.Context, root string, artifacts []contracthost.SnapshotArtifact) (map[string]restoredSnapshotArtifact, error) { + if len(artifacts) == 0 { + return nil, fmt.Errorf("restore snapshot is missing artifacts") + } + if err := os.MkdirAll(root, 0o755); err != nil { 
+ return nil, fmt.Errorf("create restore staging dir %q: %w", root, err) + } + + client := &http.Client{} + restored := make(map[string]restoredSnapshotArtifact, len(artifacts)) + for _, artifact := range artifacts { + if strings.TrimSpace(artifact.DownloadURL) == "" { + return nil, fmt.Errorf("artifact %q is missing download url", artifact.ID) + } + localPath := filepath.Join(root, artifact.Name) + if err := downloadSnapshotArtifact(ctx, client, artifact.DownloadURL, localPath); err != nil { + return nil, err + } + if expectedSHA := strings.TrimSpace(artifact.SHA256Hex); expectedSHA != "" { + actualSHA, err := sha256File(localPath) + if err != nil { + return nil, err + } + if !strings.EqualFold(actualSHA, expectedSHA) { + _ = os.Remove(localPath) + return nil, fmt.Errorf("restore artifact %q sha256 mismatch: got %s want %s", artifact.Name, actualSHA, expectedSHA) + } + } + restored[artifact.Name] = restoredSnapshotArtifact{ + Artifact: artifact, + LocalPath: localPath, + } + } + return restored, nil +} + +func downloadSnapshotArtifact(ctx context.Context, client *http.Client, sourceURL, targetPath string) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil) + if err != nil { + return fmt.Errorf("build restore download request: %w", err) + } + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("download durable snapshot artifact: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("download durable snapshot artifact returned %d", resp.StatusCode) + } + if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil { + return fmt.Errorf("create restore artifact dir %q: %w", filepath.Dir(targetPath), err) + } + out, err := os.Create(targetPath) + if err != nil { + return fmt.Errorf("create restore artifact %q: %w", targetPath, err) + } + defer func() { _ = out.Close() }() + + if _, err := io.Copy(out, resp.Body); err != nil { + return 
fmt.Errorf("write restore artifact %q: %w", targetPath, err) + } + return nil +} diff --git a/internal/daemon/snapshot_transfer_test.go b/internal/daemon/snapshot_transfer_test.go new file mode 100644 index 0000000..8f6c2e3 --- /dev/null +++ b/internal/daemon/snapshot_transfer_test.go @@ -0,0 +1,60 @@ +package daemon + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +func TestUploadSnapshotArtifactRejectsEmptyETag(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + t.Errorf("unexpected method %q", r.Method) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + artifactPath := filepath.Join(t.TempDir(), "artifact.bin") + if err := os.WriteFile(artifactPath, []byte("payload"), 0o644); err != nil { + t.Fatalf("write artifact: %v", err) + } + + _, err := uploadSnapshotArtifact(context.Background(), artifactPath, []contracthost.SnapshotUploadPart{{ + PartNumber: 1, + OffsetBytes: 0, + SizeBytes: int64(len("payload")), + UploadURL: server.URL, + }}) + if err == nil || !strings.Contains(err.Error(), "empty etag") { + t.Fatalf("uploadSnapshotArtifact error = %v, want empty etag failure", err) + } +} + +func TestDownloadDurableSnapshotArtifactsRejectsSHA256Mismatch(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("payload")) + })) + defer server.Close() + + root := t.TempDir() + _, err := downloadDurableSnapshotArtifacts(context.Background(), root, []contracthost.SnapshotArtifact{{ + ID: "memory", + Kind: contracthost.SnapshotArtifactKindMemory, + Name: "memory.bin", + DownloadURL: server.URL, + SHA256Hex: strings.Repeat("0", 64), + }}) + if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") { + t.Fatalf("downloadDurableSnapshotArtifacts 
error = %v, want sha256 mismatch", err) + } + if _, statErr := os.Stat(filepath.Join(root, "memory.bin")); !os.IsNotExist(statErr) { + t.Fatalf("corrupt artifact should be removed, stat err = %v", statErr) + } +} diff --git a/internal/daemon/storage_report.go b/internal/daemon/storage_report.go index bdf5817..8609696 100644 --- a/internal/daemon/storage_report.go +++ b/internal/daemon/storage_report.go @@ -52,19 +52,27 @@ func (d *Daemon) GetStorageReport(ctx context.Context) (*contracthost.GetStorage } } - machineUsage := make([]contracthost.MachineStorageUsage, 0, len(volumes)) + machineUsageByID := make(map[contracthost.MachineID]contracthost.MachineStorageUsage) for _, volume := range volumes { - if volume.AttachedMachineID == nil || volume.Kind != contracthost.VolumeKindSystem { + if volume.AttachedMachineID == nil { continue } bytes, err := fileSize(volume.Path) if err != nil { return nil, err } - machineUsage = append(machineUsage, contracthost.MachineStorageUsage{ - MachineID: *volume.AttachedMachineID, - SystemBytes: bytes, - }) + usage := machineUsageByID[*volume.AttachedMachineID] + usage.MachineID = *volume.AttachedMachineID + if volume.Kind == contracthost.VolumeKindSystem { + usage.SystemBytes += bytes + } else { + usage.UserBytes += bytes + } + machineUsageByID[*volume.AttachedMachineID] = usage + } + machineUsage := make([]contracthost.MachineStorageUsage, 0, len(machineUsageByID)) + for _, usage := range machineUsageByID { + machineUsage = append(machineUsage, usage) } snapshotUsage := make([]contracthost.SnapshotStorageUsage, 0, len(snapshots)) diff --git a/internal/firecracker/api.go b/internal/firecracker/api.go index d3d96b1..53b1ddb 100644 --- a/internal/firecracker/api.go +++ b/internal/firecracker/api.go @@ -27,10 +27,12 @@ type bootSourceRequest struct { } type driveRequest struct { - DriveID string `json:"drive_id"` - IsReadOnly bool `json:"is_read_only"` - IsRootDevice bool `json:"is_root_device"` - PathOnHost string `json:"path_on_host"` 
+ DriveID string `json:"drive_id"` + IsReadOnly bool `json:"is_read_only"` + IsRootDevice bool `json:"is_root_device"` + PathOnHost string `json:"path_on_host"` + CacheType DriveCacheType `json:"cache_type,omitempty"` + IOEngine DriveIOEngine `json:"io_engine,omitempty"` } type entropyRequest struct{} @@ -58,6 +60,13 @@ type networkInterfaceRequest struct { IfaceID string `json:"iface_id"` } +type mmdsConfigRequest struct { + IPv4Address string `json:"ipv4_address,omitempty"` + NetworkInterfaces []string `json:"network_interfaces"` + Version MMDSVersion `json:"version,omitempty"` + IMDSCompat bool `json:"imds_compat,omitempty"` +} + type serialRequest struct { SerialOutPath string `json:"serial_out_path"` } @@ -127,6 +136,24 @@ func (c *apiClient) PutNetworkInterface(ctx context.Context, network NetworkAllo return c.do(ctx, http.MethodPut, endpoint, body, nil, http.StatusNoContent) } +func (c *apiClient) PutMMDSConfig(ctx context.Context, spec MMDSSpec) error { + body := mmdsConfigRequest{ + IPv4Address: strings.TrimSpace(spec.IPv4Address), + NetworkInterfaces: append([]string(nil), spec.NetworkInterfaces...), + Version: spec.Version, + IMDSCompat: spec.IMDSCompat, + } + return c.do(ctx, http.MethodPut, "/mmds/config", body, nil, http.StatusNoContent) +} + +func (c *apiClient) PutMMDS(ctx context.Context, data any) error { + return c.do(ctx, http.MethodPut, "/mmds", data, nil, http.StatusNoContent) +} + +func (c *apiClient) PatchMMDS(ctx context.Context, data any) error { + return c.do(ctx, http.MethodPatch, "/mmds", data, nil, http.StatusNoContent) +} + func (c *apiClient) PutSerial(ctx context.Context, serialOutPath string) error { return c.do( ctx, diff --git a/internal/firecracker/configure_test.go b/internal/firecracker/configure_test.go index 26904f3..7fe00a5 100644 --- a/internal/firecracker/configure_test.go +++ b/internal/firecracker/configure_test.go @@ -83,6 +83,91 @@ func TestConfigureMachineEnablesEntropyAndSerialBeforeStart(t *testing.T) { } } +func 
TestConfigureMachineConfiguresMMDSBeforeStart(t *testing.T) { + var requests []capturedRequest + + socketPath, shutdown := startUnixSocketServer(t, func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read request body: %v", err) + } + requests = append(requests, capturedRequest{ + Method: r.Method, + Path: r.URL.Path, + Body: string(body), + }) + w.WriteHeader(http.StatusNoContent) + }) + defer shutdown() + + client := newAPIClient(socketPath) + spec := MachineSpec{ + ID: "vm-2", + VCPUs: 1, + MemoryMiB: 512, + KernelImagePath: "/kernel", + RootFSPath: "/rootfs", + RootDrive: DriveSpec{ + ID: "root_drive", + Path: "/rootfs", + CacheType: DriveCacheTypeUnsafe, + IOEngine: DriveIOEngineSync, + }, + MMDS: &MMDSSpec{ + NetworkInterfaces: []string{"net0"}, + Version: MMDSVersionV2, + IPv4Address: "169.254.169.254", + Data: map[string]any{ + "latest": map[string]any{ + "meta-data": map[string]any{ + "microagent": map[string]any{"hostname": "vm-2"}, + }, + }, + }, + }, + } + paths := machinePaths{JailedSerialLogPath: "/logs/serial.log"} + network := NetworkAllocation{ + InterfaceID: defaultInterfaceID, + TapName: "fctap0", + GuestMAC: "06:00:ac:10:00:02", + } + + if err := configureMachine(context.Background(), client, paths, spec, network); err != nil { + t.Fatalf("configure machine: %v", err) + } + + gotPaths := make([]string, 0, len(requests)) + for _, request := range requests { + gotPaths = append(gotPaths, request.Path) + } + wantPaths := []string{ + "/machine-config", + "/boot-source", + "/drives/root_drive", + "/network-interfaces/net0", + "/mmds/config", + "/mmds", + "/entropy", + "/serial", + "/actions", + } + if len(gotPaths) != len(wantPaths) { + t.Fatalf("request count mismatch: got %d want %d (%v)", len(gotPaths), len(wantPaths), gotPaths) + } + for i := range wantPaths { + if gotPaths[i] != wantPaths[i] { + t.Fatalf("request %d mismatch: got %q want %q", i, gotPaths[i], wantPaths[i]) + } + } + if 
requests[2].Body != "{\"drive_id\":\"root_drive\",\"is_read_only\":false,\"is_root_device\":true,\"path_on_host\":\"/rootfs\",\"cache_type\":\"Unsafe\",\"io_engine\":\"Sync\"}" { + t.Fatalf("root drive body mismatch: got %q", requests[2].Body) + } + if requests[4].Body != "{\"ipv4_address\":\"169.254.169.254\",\"network_interfaces\":[\"net0\"],\"version\":\"V2\"}" { + t.Fatalf("mmds config body mismatch: got %q", requests[4].Body) + } +} + func startUnixSocketServer(t *testing.T, handler http.HandlerFunc) (string, func()) { t.Helper() diff --git a/internal/firecracker/launch.go b/internal/firecracker/launch.go index 3b57c87..a01a743 100644 --- a/internal/firecracker/launch.go +++ b/internal/firecracker/launch.go @@ -39,6 +39,16 @@ func configureMachine(ctx context.Context, client *apiClient, paths machinePaths if err := client.PutNetworkInterface(ctx, network); err != nil { return fmt.Errorf("put network interface: %w", err) } + if spec.MMDS != nil { + if err := client.PutMMDSConfig(ctx, *spec.MMDS); err != nil { + return fmt.Errorf("put mmds config: %w", err) + } + if spec.MMDS.Data != nil { + if err := client.PutMMDS(ctx, spec.MMDS.Data); err != nil { + return fmt.Errorf("put mmds payload: %w", err) + } + } + } if err := client.PutEntropy(ctx); err != nil { return fmt.Errorf("put entropy device: %w", err) } @@ -97,12 +107,14 @@ func stageMachineFiles(spec MachineSpec, paths machinePaths) (MachineSpec, error rootFSPath, err := stagedFileName(spec.RootFSPath) if err != nil { - return MachineSpec{}, fmt.Errorf("rootfs path: %w", err) + return MachineSpec{}, fmt.Errorf("root drive path: %w", err) } if err := linkMachineFile(spec.RootFSPath, filepath.Join(paths.ChrootRootDir, rootFSPath)); err != nil { - return MachineSpec{}, fmt.Errorf("link rootfs into jail: %w", err) + return MachineSpec{}, fmt.Errorf("link root drive into jail: %w", err) } staged.RootFSPath = rootFSPath + staged.RootDrive = spec.rootDrive() + staged.RootDrive.Path = rootFSPath staged.Drives = 
make([]DriveSpec, len(spec.Drives)) for i, drive := range spec.Drives { @@ -174,6 +186,8 @@ func additionalDriveRequests(spec MachineSpec) []driveRequest { IsReadOnly: drive.ReadOnly, IsRootDevice: false, PathOnHost: drive.Path, + CacheType: drive.CacheType, + IOEngine: drive.IOEngine, }) } return requests @@ -249,11 +263,14 @@ func linkMachineFile(source string, target string) error { } func rootDriveRequest(spec MachineSpec) driveRequest { + root := spec.rootDrive() return driveRequest{ - DriveID: defaultRootDriveID, - IsReadOnly: false, + DriveID: root.ID, + IsReadOnly: root.ReadOnly, IsRootDevice: true, - PathOnHost: spec.RootFSPath, + PathOnHost: root.Path, + CacheType: root.CacheType, + IOEngine: root.IOEngine, } } diff --git a/internal/firecracker/spec.go b/internal/firecracker/spec.go index ba57552..6118668 100644 --- a/internal/firecracker/spec.go +++ b/internal/firecracker/spec.go @@ -16,16 +16,50 @@ type MachineSpec struct { MemoryMiB int64 KernelImagePath string RootFSPath string + RootDrive DriveSpec KernelArgs string Drives []DriveSpec + MMDS *MMDSSpec Vsock *VsockSpec } // DriveSpec describes an additional guest block device. type DriveSpec struct { - ID string - Path string - ReadOnly bool + ID string + Path string + ReadOnly bool + CacheType DriveCacheType + IOEngine DriveIOEngine +} + +type DriveCacheType string + +const ( + DriveCacheTypeUnsafe DriveCacheType = "Unsafe" + DriveCacheTypeWriteback DriveCacheType = "Writeback" +) + +type DriveIOEngine string + +const ( + DriveIOEngineSync DriveIOEngine = "Sync" + DriveIOEngineAsync DriveIOEngine = "Async" +) + +type MMDSVersion string + +const ( + MMDSVersionV1 MMDSVersion = "V1" + MMDSVersionV2 MMDSVersion = "V2" +) + +// MMDSSpec describes the MMDS network configuration and initial payload. +type MMDSSpec struct { + NetworkInterfaces []string + Version MMDSVersion + IPv4Address string + IMDSCompat bool + Data any } // VsockSpec describes a single host-guest vsock device. 
@@ -49,17 +83,22 @@ func (s MachineSpec) Validate() error { if strings.TrimSpace(s.KernelImagePath) == "" { return fmt.Errorf("machine kernel image path is required") } - if strings.TrimSpace(s.RootFSPath) == "" { - return fmt.Errorf("machine rootfs path is required") - } if filepath.Base(strings.TrimSpace(string(s.ID))) != strings.TrimSpace(string(s.ID)) { return fmt.Errorf("machine id %q must not contain path separators", s.ID) } + if err := s.rootDrive().Validate(); err != nil { + return fmt.Errorf("root drive: %w", err) + } for i, drive := range s.Drives { if err := drive.Validate(); err != nil { return fmt.Errorf("drive %d: %w", i, err) } } + if s.MMDS != nil { + if err := s.MMDS.Validate(); err != nil { + return fmt.Errorf("mmds: %w", err) + } + } if s.Vsock != nil { if err := s.Vsock.Validate(); err != nil { return fmt.Errorf("vsock: %w", err) @@ -70,11 +109,39 @@ func (s MachineSpec) Validate() error { // Validate reports whether the drive specification is usable. func (d DriveSpec) Validate() error { + if strings.TrimSpace(d.Path) == "" { + return fmt.Errorf("drive path is required") + } if strings.TrimSpace(d.ID) == "" { return fmt.Errorf("drive id is required") } - if strings.TrimSpace(d.Path) == "" { - return fmt.Errorf("drive path is required") + switch d.CacheType { + case "", DriveCacheTypeUnsafe, DriveCacheTypeWriteback: + default: + return fmt.Errorf("unsupported drive cache type %q", d.CacheType) + } + switch d.IOEngine { + case "", DriveIOEngineSync, DriveIOEngineAsync: + default: + return fmt.Errorf("unsupported drive io engine %q", d.IOEngine) + } + return nil +} + +// Validate reports whether the MMDS configuration is usable. 
+func (m MMDSSpec) Validate() error { + if len(m.NetworkInterfaces) == 0 { + return fmt.Errorf("mmds network interfaces are required") + } + switch m.Version { + case "", MMDSVersionV1, MMDSVersionV2: + default: + return fmt.Errorf("unsupported mmds version %q", m.Version) + } + for i, iface := range m.NetworkInterfaces { + if strings.TrimSpace(iface) == "" { + return fmt.Errorf("mmds network_interfaces[%d] is required", i) + } } return nil } @@ -92,3 +159,14 @@ func (v VsockSpec) Validate() error { } return nil } + +func (s MachineSpec) rootDrive() DriveSpec { + root := s.RootDrive + if strings.TrimSpace(root.ID) == "" { + root.ID = defaultRootDriveID + } + if strings.TrimSpace(root.Path) == "" { + root.Path = s.RootFSPath + } + return root +} diff --git a/internal/httpapi/handlers.go b/internal/httpapi/handlers.go index 49c71db..8ac3498 100644 --- a/internal/httpapi/handlers.go +++ b/internal/httpapi/handlers.go @@ -20,6 +20,7 @@ type Service interface { Health(context.Context) (*contracthost.HealthResponse, error) GetStorageReport(context.Context) (*contracthost.GetStorageReportResponse, error) CreateSnapshot(context.Context, contracthost.MachineID, contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error) + UploadSnapshot(context.Context, contracthost.SnapshotID, contracthost.UploadSnapshotRequest) (*contracthost.UploadSnapshotResponse, error) ListSnapshots(context.Context, contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error) GetSnapshot(context.Context, contracthost.SnapshotID) (*contracthost.GetSnapshotResponse, error) DeleteSnapshotByID(context.Context, contracthost.SnapshotID) error @@ -278,6 +279,25 @@ func (h *Handler) handleSnapshot(w http.ResponseWriter, r *http.Request) { return } + if len(parts) == 2 && parts[1] == "upload" { + if r.Method != http.MethodPost { + writeMethodNotAllowed(w) + return + } + var req contracthost.UploadSnapshotRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + 
writeError(w, http.StatusBadRequest, err) + return + } + response, err := h.service.UploadSnapshot(r.Context(), snapshotID, req) + if err != nil { + writeError(w, statusForError(err), err) + return + } + writeJSON(w, http.StatusOK, response) + return + } + writeError(w, http.StatusNotFound, fmt.Errorf("route not found")) } diff --git a/internal/model/types.go b/internal/model/types.go index e93b625..ad4dd9f 100644 --- a/internal/model/types.go +++ b/internal/model/types.go @@ -29,6 +29,7 @@ type ArtifactRecord struct { type MachineRecord struct { ID contracthost.MachineID Artifact contracthost.ArtifactRef + GuestConfig *contracthost.GuestConfig SystemVolumeID contracthost.VolumeID UserVolumeIDs []contracthost.VolumeID RuntimeHost string @@ -71,11 +72,21 @@ type SnapshotRecord struct { MemFilePath string StateFilePath string DiskPaths []string + Artifacts []SnapshotArtifactRecord SourceRuntimeHost string SourceTapDevice string CreatedAt time.Time } +type SnapshotArtifactRecord struct { + ID string + Kind contracthost.SnapshotArtifactKind + Name string + LocalPath string + SizeBytes int64 + SHA256Hex string +} + type PublishedPortRecord struct { ID contracthost.PublishedPortID MachineID contracthost.MachineID