feat: firecracker mmds identity

This commit is contained in:
Harivansh Rathi 2026-04-10 00:53:47 +00:00
parent 500354cd9b
commit 3eb610b703
23 changed files with 1813 additions and 263 deletions

View file

@ -17,6 +17,7 @@ type Machine struct {
}
type GuestConfig struct {
Hostname string `json:"hostname,omitempty"`
AuthorizedKeys []string `json:"authorized_keys,omitempty"`
TrustedUserCAKeys []string `json:"trusted_user_ca_keys,omitempty"`
LoginWebhook *GuestLoginWebhook `json:"login_webhook,omitempty"`

View file

@ -4,10 +4,31 @@ import "time"
type SnapshotID string
type SnapshotArtifactKind string
const (
SnapshotArtifactKindMemory SnapshotArtifactKind = "memory"
SnapshotArtifactKindVMState SnapshotArtifactKind = "vmstate"
SnapshotArtifactKindDisk SnapshotArtifactKind = "disk"
SnapshotArtifactKindManifest SnapshotArtifactKind = "manifest"
)
type Snapshot struct {
ID SnapshotID `json:"id"`
MachineID MachineID `json:"machine_id"`
CreatedAt time.Time `json:"created_at"`
ID SnapshotID `json:"id"`
MachineID MachineID `json:"machine_id"`
SourceRuntimeHost string `json:"source_runtime_host,omitempty"`
SourceTapDevice string `json:"source_tap_device,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
type SnapshotArtifact struct {
ID string `json:"id"`
Kind SnapshotArtifactKind `json:"kind"`
Name string `json:"name"`
SizeBytes int64 `json:"size_bytes"`
SHA256Hex string `json:"sha256_hex,omitempty"`
ObjectKey string `json:"object_key,omitempty"`
DownloadURL string `json:"download_url,omitempty"`
}
type CreateSnapshotRequest struct {
@ -15,7 +36,8 @@ type CreateSnapshotRequest struct {
}
type CreateSnapshotResponse struct {
Snapshot Snapshot `json:"snapshot"`
Snapshot Snapshot `json:"snapshot"`
Artifacts []SnapshotArtifact `json:"artifacts,omitempty"`
}
type GetSnapshotResponse struct {
@ -26,8 +48,52 @@ type ListSnapshotsResponse struct {
Snapshots []Snapshot `json:"snapshots"`
}
// SnapshotUploadPart identifies one multipart-upload slot for a snapshot
// artifact: the part number, the byte range of the artifact it covers, and
// the URL the part body is uploaded to (presumably pre-signed by the
// control plane — confirm against the issuer).
type SnapshotUploadPart struct {
	PartNumber  int32  `json:"part_number"`
	OffsetBytes int64  `json:"offset_bytes"`
	SizeBytes   int64  `json:"size_bytes"`
	UploadURL   string `json:"upload_url"`
}

// SnapshotArtifactUploadSession groups the upload slots for a single
// artifact under its multipart upload ID and destination object key.
type SnapshotArtifactUploadSession struct {
	ArtifactID string               `json:"artifact_id"`
	ObjectKey  string               `json:"object_key"`
	UploadID   string               `json:"upload_id"`
	Parts      []SnapshotUploadPart `json:"parts"`
}

// UploadSnapshotRequest asks the daemon to push snapshot artifacts using
// the provided multipart upload sessions.
type UploadSnapshotRequest struct {
	Artifacts []SnapshotArtifactUploadSession `json:"artifacts"`
}

// UploadedSnapshotPart reports the ETag the object store returned for one
// completed upload part; callers need it to finalize the multipart upload.
type UploadedSnapshotPart struct {
	PartNumber int32  `json:"part_number"`
	ETag       string `json:"etag"`
}

// UploadedSnapshotArtifact lists the completed parts for one artifact.
type UploadedSnapshotArtifact struct {
	ArtifactID     string                 `json:"artifact_id"`
	CompletedParts []UploadedSnapshotPart `json:"completed_parts"`
}

// UploadSnapshotResponse returns the per-artifact completion records so the
// caller can finish (commit) each multipart upload.
type UploadSnapshotResponse struct {
	Artifacts []UploadedSnapshotArtifact `json:"artifacts"`
}
type RestoreSnapshotRequest struct {
MachineID MachineID `json:"machine_id"`
MachineID MachineID `json:"machine_id"`
Artifact ArtifactRef `json:"artifact"`
Snapshot DurableSnapshotSpec `json:"snapshot"`
GuestConfig *GuestConfig `json:"guest_config,omitempty"`
}
// DurableSnapshotSpec carries everything a host needs to restore a snapshot
// without any local record of it: the snapshot and source-machine identity,
// the image it was built from, the source network placement (runtime host
// and tap device, used to re-create the guest network), and the artifact
// list with download locations.
type DurableSnapshotSpec struct {
	SnapshotID        SnapshotID         `json:"snapshot_id"`
	MachineID         MachineID          `json:"machine_id"`
	ImageID           string             `json:"image_id"`
	SourceRuntimeHost string             `json:"source_runtime_host,omitempty"`
	SourceTapDevice   string             `json:"source_tap_device,omitempty"`
	Artifacts         []SnapshotArtifact `json:"artifacts"`
}
type RestoreSnapshotResponse struct {

View file

@ -9,9 +9,10 @@ type VolumeID string
type VolumeKind string
const (
MachinePhaseRunning MachinePhase = "running"
MachinePhaseStopped MachinePhase = "stopped"
MachinePhaseFailed MachinePhase = "failed"
MachinePhaseStarting MachinePhase = "starting"
MachinePhaseRunning MachinePhase = "running"
MachinePhaseStopped MachinePhase = "stopped"
MachinePhaseFailed MachinePhase = "failed"
)
const (

View file

@ -66,7 +66,7 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
if err := os.MkdirAll(filepath.Dir(systemVolumePath), 0o755); err != nil {
return nil, fmt.Errorf("create system volume dir for %q: %w", req.MachineID, err)
}
if err := cloneFile(artifact.RootFSPath, systemVolumePath); err != nil {
if err := cowCopyFile(artifact.RootFSPath, systemVolumePath); err != nil {
return nil, err
}
removeSystemVolumeOnFailure := true
@ -77,14 +77,8 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
_ = os.Remove(systemVolumePath)
_ = os.RemoveAll(filepath.Dir(systemVolumePath))
}()
if err := injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil {
return nil, err
}
if err := injectMachineIdentity(ctx, systemVolumePath, req.MachineID); err != nil {
return nil, err
}
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath)
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig)
if err != nil {
return nil, err
}
@ -98,17 +92,6 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
return nil, err
}
ports := defaultMachinePorts()
if err := waitForGuestReady(ctx, state.RuntimeHost, ports); err != nil {
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
if err != nil {
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
now := time.Now().UTC()
systemVolumeRecord := model.VolumeRecord{
ID: d.systemVolumeID(req.MachineID),
@ -143,44 +126,20 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
}
record := model.MachineRecord{
ID: req.MachineID,
Artifact: req.Artifact,
SystemVolumeID: systemVolumeRecord.ID,
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
RuntimeHost: state.RuntimeHost,
TapDevice: state.TapName,
Ports: ports,
GuestSSHPublicKey: guestSSHPublicKey,
Phase: contracthost.MachinePhaseRunning,
PID: state.PID,
SocketPath: state.SocketPath,
CreatedAt: now,
StartedAt: state.StartedAt,
ID: req.MachineID,
Artifact: req.Artifact,
GuestConfig: cloneGuestConfig(guestConfig),
SystemVolumeID: systemVolumeRecord.ID,
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
RuntimeHost: state.RuntimeHost,
TapDevice: state.TapName,
Ports: defaultMachinePorts(),
Phase: contracthost.MachinePhaseStarting,
PID: state.PID,
SocketPath: state.SocketPath,
CreatedAt: now,
StartedAt: state.StartedAt,
}
d.relayAllocMu.Lock()
sshRelayPort, err := d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameSSH, record.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort)
var vncRelayPort uint16
if err == nil {
vncRelayPort, err = d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameVNC, record.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort)
}
d.relayAllocMu.Unlock()
if err != nil {
d.stopMachineRelays(record.ID)
for _, volume := range userVolumes {
volume.AttachedMachineID = nil
_ = d.store.UpdateVolume(context.Background(), volume)
}
_ = d.store.DeleteVolume(context.Background(), systemVolumeRecord.ID)
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
startedRelays := true
defer func() {
if startedRelays {
d.stopMachineRelays(record.ID)
}
}()
if err := d.store.CreateMachine(ctx, record); err != nil {
for _, volume := range userVolumes {
volume.AttachedMachineID = nil
@ -192,12 +151,11 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
}
removeSystemVolumeOnFailure = false
startedRelays = false
clearOperation = true
return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil
}
func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string) (firecracker.MachineSpec, error) {
func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string, guestConfig *contracthost.GuestConfig) (firecracker.MachineSpec, error) {
drives := make([]firecracker.DriveSpec, 0, len(userVolumes))
for i, volume := range userVolumes {
drives = append(drives, firecracker.DriveSpec{
@ -207,14 +165,25 @@ func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *mo
})
}
mmds, err := d.guestMetadataSpec(machineID, guestConfig)
if err != nil {
return firecracker.MachineSpec{}, err
}
spec := firecracker.MachineSpec{
ID: firecracker.MachineID(machineID),
VCPUs: defaultGuestVCPUs,
MemoryMiB: defaultGuestMemoryMiB,
KernelImagePath: artifact.KernelImagePath,
RootFSPath: systemVolumePath,
KernelArgs: defaultGuestKernelArgs,
Drives: drives,
RootDrive: firecracker.DriveSpec{
ID: "root_drive",
Path: systemVolumePath,
CacheType: firecracker.DriveCacheTypeUnsafe,
IOEngine: firecracker.DriveIOEngineSync,
},
KernelArgs: defaultGuestKernelArgs,
Drives: drives,
MMDS: mmds,
}
if err := spec.Validate(); err != nil {
return firecracker.MachineSpec{}, err

View file

@ -43,6 +43,7 @@ type Daemon struct {
reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID) error
readGuestSSHPublicKey func(context.Context, string) (string, error)
syncGuestFilesystem func(context.Context, string) error
locksMu sync.Mutex
machineLocks map[contracthost.MachineID]*sync.Mutex
@ -84,6 +85,7 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
}
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
daemon.readGuestSSHPublicKey = readGuestSSHPublicKey
daemon.syncGuestFilesystem = daemon.syncGuestFilesystemOverSSH
if err := daemon.ensureBackendSSHKeyPair(); err != nil {
return nil, err
}

View file

@ -141,7 +141,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
t.Fatalf("create machine: %v", err)
}
if response.Machine.Phase != contracthost.MachinePhaseRunning {
if response.Machine.Phase != contracthost.MachinePhaseStarting {
t.Fatalf("machine phase mismatch: got %q", response.Machine.Phase)
}
if response.Machine.RuntimeHost != "127.0.0.1" {
@ -169,29 +169,25 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
if err != nil {
t.Fatalf("read backend ssh public key: %v", err)
}
authorizedKeys, err := readExt4File(runtime.lastSpec.RootFSPath, "/etc/microagent/authorized_keys")
if err != nil {
t.Fatalf("read injected authorized_keys: %v", err)
if runtime.lastSpec.MMDS == nil {
t.Fatalf("expected MMDS configuration on machine spec")
}
if runtime.lastSpec.MMDS.Version != firecracker.MMDSVersionV2 {
t.Fatalf("mmds version mismatch: got %q", runtime.lastSpec.MMDS.Version)
}
payload, ok := runtime.lastSpec.MMDS.Data.(guestMetadataEnvelope)
if !ok {
t.Fatalf("mmds payload type mismatch: got %T", runtime.lastSpec.MMDS.Data)
}
if payload.Latest.MetaData.Hostname != "vm-1" {
t.Fatalf("mmds hostname mismatch: got %q", payload.Latest.MetaData.Hostname)
}
authorizedKeys := strings.Join(payload.Latest.MetaData.AuthorizedKeys, "\n")
if !strings.Contains(authorizedKeys, strings.TrimSpace(string(hostAuthorizedKeyBytes))) {
t.Fatalf("authorized_keys missing backend ssh key: %q", authorizedKeys)
t.Fatalf("mmds authorized_keys missing backend ssh key: %q", authorizedKeys)
}
if !strings.Contains(authorizedKeys, "daemon-test") {
t.Fatalf("authorized_keys missing request override key: %q", authorizedKeys)
}
machineName, err := readExt4File(runtime.lastSpec.RootFSPath, "/etc/microagent/machine-name")
if err != nil {
t.Fatalf("read injected machine-name: %v", err)
}
if machineName != "vm-1\n" {
t.Fatalf("machine-name mismatch: got %q want %q", machineName, "vm-1\n")
}
hosts, err := readExt4File(runtime.lastSpec.RootFSPath, "/etc/hosts")
if err != nil {
t.Fatalf("read injected hosts: %v", err)
}
if !strings.Contains(hosts, "127.0.1.1 vm-1") {
t.Fatalf("hosts missing machine identity: %q", hosts)
t.Fatalf("mmds authorized_keys missing request override key: %q", authorizedKeys)
}
artifact, err := fileStore.GetArtifact(context.Background(), response.Machine.Artifact)
@ -214,6 +210,12 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
if machine.SystemVolumeID != "vm-1-system" {
t.Fatalf("system volume mismatch: got %q", machine.SystemVolumeID)
}
if machine.Phase != contracthost.MachinePhaseStarting {
t.Fatalf("stored machine phase mismatch: got %q", machine.Phase)
}
if machine.GuestConfig == nil || len(machine.GuestConfig.AuthorizedKeys) == 0 {
t.Fatalf("stored guest config missing authorized keys: %#v", machine.GuestConfig)
}
operations, err := fileStore.ListOperations(context.Background())
if err != nil {
@ -224,6 +226,65 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
}
}
// TestStopMachineSyncsGuestFilesystemBeforeDelete verifies that stopping a
// machine (1) flushes the guest filesystem via the injectable
// syncGuestFilesystem hook against the machine's runtime host, (2) deletes
// the runtime VM exactly once, and (3) persists the machine as stopped with
// its runtime host cleared.
func TestStopMachineSyncsGuestFilesystemBeforeDelete(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}
	runtime := &fakeRuntime{}
	hostDaemon, err := New(cfg, fileStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	// Replace the SSH-based sync with a recorder so the test can assert
	// which host was synced without spawning an ssh process.
	var syncedHost string
	hostDaemon.syncGuestFilesystem = func(_ context.Context, runtimeHost string) error {
		syncedHost = runtimeHost
		return nil
	}
	// Seed a "running" machine record directly in the store; no real VM
	// backs it, the fake runtime absorbs the delete call.
	now := time.Now().UTC()
	if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
		ID:             "vm-stop",
		SystemVolumeID: "vm-stop-system",
		RuntimeHost:    "172.16.0.2",
		TapDevice:      "fctap-stop",
		Phase:          contracthost.MachinePhaseRunning,
		PID:            1234,
		SocketPath:     filepath.Join(root, "runtime", "vm-stop.sock"),
		Ports:          defaultMachinePorts(),
		CreatedAt:      now,
		StartedAt:      &now,
	}); err != nil {
		t.Fatalf("create machine: %v", err)
	}
	if err := hostDaemon.StopMachine(context.Background(), "vm-stop"); err != nil {
		t.Fatalf("stop machine: %v", err)
	}
	if syncedHost != "172.16.0.2" {
		t.Fatalf("sync host mismatch: got %q want %q", syncedHost, "172.16.0.2")
	}
	if len(runtime.deleteCalls) != 1 {
		t.Fatalf("runtime delete call count mismatch: got %d want 1", len(runtime.deleteCalls))
	}
	stopped, err := fileStore.GetMachine(context.Background(), "vm-stop")
	if err != nil {
		t.Fatalf("get stopped machine: %v", err)
	}
	if stopped.Phase != contracthost.MachinePhaseStopped {
		t.Fatalf("machine phase mismatch: got %q want %q", stopped.Phase, contracthost.MachinePhaseStopped)
	}
	if stopped.RuntimeHost != "" {
		t.Fatalf("runtime host should be cleared after stop, got %q", stopped.RuntimeHost)
	}
}
func TestNewEnsuresBackendSSHKeyPair(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
@ -249,7 +310,7 @@ func TestNewEnsuresBackendSSHKeyPair(t *testing.T) {
}
}
func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
@ -257,7 +318,23 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
t.Fatalf("create file store: %v", err)
}
runtime := &fakeRuntime{}
sshListener := listenTestPort(t, int(defaultSSHPort))
defer func() { _ = sshListener.Close() }()
vncListener := listenTestPort(t, int(defaultVNCPort))
defer func() { _ = vncListener.Close() }()
startedAt := time.Unix(1700000099, 0).UTC()
runtime := &fakeRuntime{
bootState: firecracker.MachineState{
ID: "restored",
Phase: firecracker.PhaseRunning,
PID: 1234,
RuntimeHost: "127.0.0.1",
SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "restored", "root", "run", "firecracker.sock"),
TapName: "fctap0",
StartedAt: &startedAt,
},
}
hostDaemon, err := New(cfg, fileStore, runtime)
if err != nil {
t.Fatalf("create daemon: %v", err)
@ -281,32 +358,13 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
t.Fatalf("put artifact: %v", err)
}
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
ID: "source",
Artifact: artifactRef,
SystemVolumeID: "source-system",
RuntimeHost: "172.16.0.2",
TapDevice: "fctap0",
Phase: contracthost.MachinePhaseRunning,
CreatedAt: time.Now().UTC(),
}); err != nil {
t.Fatalf("create source machine: %v", err)
}
snapDisk := filepath.Join(root, "snapshots", "snap1", "system.img")
if err := os.MkdirAll(filepath.Dir(snapDisk), 0o755); err != nil {
t.Fatalf("create snapshot dir: %v", err)
}
if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil {
t.Fatalf("write snapshot disk: %v", err)
}
if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{
ID: "snap1",
MachineID: "source",
Artifact: artifactRef,
MemFilePath: filepath.Join(root, "snapshots", "snap1", "memory.bin"),
StateFilePath: filepath.Join(root, "snapshots", "snap1", "vmstate.bin"),
DiskPaths: []string{snapDisk},
DiskPaths: []string{filepath.Join(root, "snapshots", "snap1", "system.img")},
SourceRuntimeHost: "172.16.0.2",
SourceTapDevice: "fctap0",
CreatedAt: time.Now().UTC(),
@ -314,17 +372,49 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
t.Fatalf("create snapshot: %v", err)
}
_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{
MachineID: "restored",
server := newRestoreArtifactServer(t, map[string][]byte{
"/kernel": []byte("kernel"),
"/rootfs": []byte("rootfs"),
"/memory": []byte("mem"),
"/vmstate": []byte("state"),
"/system": []byte("disk"),
})
if err == nil {
t.Fatal("expected restore rejection while source is running")
defer server.Close()
response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{
MachineID: "restored",
Artifact: contracthost.ArtifactRef{
KernelImageURL: server.URL + "/kernel",
RootFSURL: server.URL + "/rootfs",
},
Snapshot: contracthost.DurableSnapshotSpec{
SnapshotID: "snap1",
MachineID: "source",
ImageID: "image-1",
Artifacts: []contracthost.SnapshotArtifact{
{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"},
{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"},
{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"},
},
},
})
if err != nil {
t.Fatalf("restore snapshot: %v", err)
}
if !strings.Contains(err.Error(), `source machine "source" is running`) {
t.Fatalf("unexpected restore error: %v", err)
if response.Machine.ID != "restored" {
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
}
if runtime.restoreCalls != 0 {
t.Fatalf("restore boot should not run when source machine is still running: got %d", runtime.restoreCalls)
if runtime.restoreCalls != 1 {
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls)
}
if runtime.lastLoadSpec.Network == nil {
t.Fatalf("restore boot should preserve snapshot network")
}
if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" {
t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2")
}
if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" {
t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0")
}
ops, err := fileStore.ListOperations(context.Background())
@ -332,11 +422,11 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
t.Fatalf("list operations: %v", err)
}
if len(ops) != 0 {
t.Fatalf("operation journal should be empty after handled restore rejection: got %d entries", len(ops))
t.Fatalf("operation journal should be empty after successful restore: got %d entries", len(ops))
}
}
func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
@ -378,56 +468,35 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
return nil
}
artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
kernelPath := filepath.Join(root, "artifact-kernel")
rootFSPath := filepath.Join(root, "artifact-rootfs")
if err := os.WriteFile(kernelPath, []byte("kernel"), 0o644); err != nil {
t.Fatalf("write kernel: %v", err)
}
if err := os.WriteFile(rootFSPath, []byte("rootfs"), 0o644); err != nil {
t.Fatalf("write rootfs: %v", err)
}
if err := fileStore.PutArtifact(context.Background(), model.ArtifactRecord{
Ref: artifactRef,
LocalKey: "artifact",
LocalDir: filepath.Join(root, "artifact"),
KernelImagePath: kernelPath,
RootFSPath: rootFSPath,
CreatedAt: time.Now().UTC(),
}); err != nil {
t.Fatalf("put artifact: %v", err)
}
snapDir := filepath.Join(root, "snapshots", "snap1")
if err := os.MkdirAll(snapDir, 0o755); err != nil {
t.Fatalf("create snapshot dir: %v", err)
}
snapDisk := filepath.Join(snapDir, "system.img")
if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil {
t.Fatalf("write snapshot disk: %v", err)
}
if err := os.WriteFile(filepath.Join(snapDir, "memory.bin"), []byte("mem"), 0o644); err != nil {
t.Fatalf("write memory snapshot: %v", err)
}
if err := os.WriteFile(filepath.Join(snapDir, "vmstate.bin"), []byte("state"), 0o644); err != nil {
t.Fatalf("write vmstate snapshot: %v", err)
}
if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{
ID: "snap1",
MachineID: "source",
Artifact: artifactRef,
MemFilePath: filepath.Join(snapDir, "memory.bin"),
StateFilePath: filepath.Join(snapDir, "vmstate.bin"),
DiskPaths: []string{snapDisk},
SourceRuntimeHost: "172.16.0.2",
SourceTapDevice: "fctap0",
CreatedAt: time.Now().UTC(),
}); err != nil {
t.Fatalf("create snapshot: %v", err)
}
server := newRestoreArtifactServer(t, map[string][]byte{
"/kernel": []byte("kernel"),
"/rootfs": []byte("rootfs"),
"/memory": []byte("mem"),
"/vmstate": []byte("state"),
"/system": []byte("disk"),
"/user-0": []byte("user-disk"),
})
defer server.Close()
response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{
MachineID: "restored",
Artifact: contracthost.ArtifactRef{
KernelImageURL: server.URL + "/kernel",
RootFSURL: server.URL + "/rootfs",
},
Snapshot: contracthost.DurableSnapshotSpec{
SnapshotID: "snap1",
MachineID: "source",
ImageID: "image-1",
SourceRuntimeHost: "172.16.0.2",
SourceTapDevice: "fctap0",
Artifacts: []contracthost.SnapshotArtifact{
{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"},
{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"},
{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"},
{ID: "disk-user-0", Kind: contracthost.SnapshotArtifactKindDisk, Name: "user-0.img", DownloadURL: server.URL + "/user-0"},
},
},
})
if err != nil {
t.Fatalf("restore snapshot: %v", err)
@ -439,13 +508,19 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls)
}
if runtime.lastLoadSpec.Network == nil {
t.Fatal("restore boot did not receive snapshot network")
t.Fatalf("restore boot should preserve durable snapshot network")
}
if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" {
t.Fatalf("restored guest network mismatch: got %q want %q", got, "172.16.0.2")
t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2")
}
if runtime.lastLoadSpec.KernelImagePath != kernelPath {
t.Fatalf("restore boot kernel path mismatch: got %q want %q", runtime.lastLoadSpec.KernelImagePath, kernelPath)
if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" {
t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0")
}
if !strings.Contains(runtime.lastLoadSpec.KernelImagePath, filepath.Join("artifacts", artifactKey(contracthost.ArtifactRef{
KernelImageURL: server.URL + "/kernel",
RootFSURL: server.URL + "/rootfs",
}), "kernel")) {
t.Fatalf("restore boot kernel path mismatch: got %q", runtime.lastLoadSpec.KernelImagePath)
}
if reconfiguredHost != "127.0.0.1" || reconfiguredMachine != "restored" {
t.Fatalf("guest identity reconfigure mismatch: host=%q machine=%q", reconfiguredHost, reconfiguredMachine)
@ -458,6 +533,12 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
if machine.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("restored machine phase mismatch: got %q", machine.Phase)
}
if len(machine.UserVolumeIDs) != 1 {
t.Fatalf("restored machine user volumes mismatch: got %#v", machine.UserVolumeIDs)
}
if _, err := os.Stat(filepath.Join(cfg.MachineDisksDir, "restored", "user-0.img")); err != nil {
t.Fatalf("restored user disk missing: %v", err)
}
ops, err := fileStore.ListOperations(context.Background())
if err != nil {
@ -468,6 +549,67 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
}
}
// TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost verifies that a
// restore is refused — before any boot attempt — when the snapshot's source
// runtime host / tap device is still claimed by a running machine on this
// host, since the restored VM would collide with that network allocation.
func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}
	runtime := &fakeRuntime{}
	hostDaemon, err := New(cfg, fileStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	// Seed a running machine that occupies 172.16.0.2/fctap0 — the same
	// network placement the durable snapshot below wants to reuse.
	if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
		ID:             "source",
		Artifact:       contracthost.ArtifactRef{KernelImageURL: "https://example.com/kernel", RootFSURL: "https://example.com/rootfs"},
		SystemVolumeID: "source-system",
		RuntimeHost:    "172.16.0.2",
		TapDevice:      "fctap0",
		Phase:          contracthost.MachinePhaseRunning,
		CreatedAt:      time.Now().UTC(),
	}); err != nil {
		t.Fatalf("create running source machine: %v", err)
	}
	_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{
		MachineID: "restored",
		Artifact: contracthost.ArtifactRef{
			KernelImageURL: "https://example.com/kernel",
			RootFSURL:      "https://example.com/rootfs",
		},
		Snapshot: contracthost.DurableSnapshotSpec{
			SnapshotID:        "snap1",
			MachineID:         "source",
			ImageID:           "image-1",
			SourceRuntimeHost: "172.16.0.2",
			SourceTapDevice:   "fctap0",
		},
	})
	if err == nil || !strings.Contains(err.Error(), "still in use on this host") {
		t.Fatalf("restore snapshot error = %v, want restore network in-use failure", err)
	}
	// The rejection must happen before the runtime is asked to boot.
	if runtime.restoreCalls != 0 {
		t.Fatalf("restore boot should not be attempted, got %d calls", runtime.restoreCalls)
	}
}
// newRestoreArtifactServer serves the given path→payload map over HTTP so
// restore tests can download snapshot artifacts; unknown paths get a 404.
// The caller owns the returned server and must Close it.
func newRestoreArtifactServer(t *testing.T, payloads map[string][]byte) *httptest.Server {
	t.Helper()
	handler := func(w http.ResponseWriter, r *http.Request) {
		if body, ok := payloads[r.URL.Path]; ok {
			_, _ = w.Write(body)
			return
		}
		http.NotFound(w, r)
	}
	return httptest.NewServer(http.HandlerFunc(handler))
}
func TestCreateMachineRejectsNonHTTPArtifactURLs(t *testing.T) {
t.Parallel()

View file

@ -414,6 +414,8 @@ func publishedPortToContract(record model.PublishedPortRecord) contracthost.Publ
func machineToRuntimeState(record model.MachineRecord) firecracker.MachineState {
phase := firecracker.PhaseStopped
switch record.Phase {
case contracthost.MachinePhaseStarting:
phase = firecracker.PhaseRunning
case contracthost.MachinePhaseRunning:
phase = firecracker.PhaseRunning
case contracthost.MachinePhaseFailed:

View file

@ -54,6 +54,31 @@ hostname "$machine_name" >/dev/null 2>&1 || true
return nil
}
// syncGuestFilesystemOverSSH runs `sync` inside the guest over the backend
// SSH channel so dirty pages reach disk (e.g. before stopping the VM).
// runtimeHost is the guest's IP as seen from this host; an empty host is
// rejected up front.
func (d *Daemon) syncGuestFilesystemOverSSH(ctx context.Context, runtimeHost string) error {
	host := strings.TrimSpace(runtimeHost)
	if host == "" {
		return fmt.Errorf("guest runtime host is required")
	}
	// Non-interactive SSH: pinned identity, no known-hosts prompting, and
	// BatchMode so a missing key fails fast instead of hanging on a prompt.
	args := []string{
		"-i", d.backendSSHPrivateKeyPath(),
		"-o", "StrictHostKeyChecking=no",
		"-o", "UserKnownHostsFile=/dev/null",
		"-o", "IdentitiesOnly=yes",
		"-o", "BatchMode=yes",
		"-p", strconv.Itoa(int(defaultSSHPort)),
		"node@" + host,
		"sudo bash -lc " + shellSingleQuote("sync"),
	}
	output, err := exec.CommandContext(ctx, "ssh", args...).CombinedOutput()
	if err != nil {
		return fmt.Errorf("sync guest filesystem over ssh: %w: %s", err, strings.TrimSpace(string(output)))
	}
	return nil
}
// shellSingleQuote wraps value in single quotes for safe interpolation into
// a POSIX shell command line, escaping any embedded single quote with the
// conventional '"'"' sequence.
func shellSingleQuote(value string) string {
	var b strings.Builder
	b.WriteByte('\'')
	for _, r := range value {
		if r == '\'' {
			b.WriteString(`'"'"'`)
		} else {
			b.WriteRune(r)
		}
	}
	b.WriteByte('\'')
	return b.String()
}

View file

@ -0,0 +1,80 @@
package daemon
import (
"fmt"
"strings"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
const (
	// defaultMMDSIPv4Address is the link-local address at which the guest
	// reaches the Firecracker MMDS metadata service.
	defaultMMDSIPv4Address = "169.254.170.2"
	// defaultMMDSPayloadVersion versions the metadata document schema so
	// the guest agent can detect incompatible payloads.
	defaultMMDSPayloadVersion = "v1"
)
// guestMetadataEnvelope is the top-level MMDS document served to the guest;
// the guest reads its identity fields under the "latest/meta-data" path.
type guestMetadataEnvelope struct {
	Latest guestMetadataRoot `json:"latest"`
}

// guestMetadataRoot nests the payload under "meta-data" to mirror the
// EC2-style metadata tree layout MMDS clients expect.
type guestMetadataRoot struct {
	MetaData guestMetadataPayload `json:"meta-data"`
}

// guestMetadataPayload is the identity and access configuration the guest
// agent consumes: machine identity/hostname, SSH authorized keys and
// trusted CA keys, and an optional login webhook.
type guestMetadataPayload struct {
	Version           string                          `json:"version"`
	MachineID         string                          `json:"machine_id"`
	Hostname          string                          `json:"hostname"`
	AuthorizedKeys    []string                        `json:"authorized_keys,omitempty"`
	TrustedUserCAKeys []string                        `json:"trusted_user_ca_keys,omitempty"`
	LoginWebhook      *contracthost.GuestLoginWebhook `json:"login_webhook,omitempty"`
}
// cloneGuestConfig returns a deep copy of config so later mutations of the
// caller's value cannot leak into stored machine records. A nil input
// yields nil.
//
// Fix: the previous version rebuilt the struct field-by-field and silently
// dropped scalar fields such as GuestConfig.Hostname; it also shadowed the
// builtin `copy`. Copying the struct by value first preserves every scalar
// field, then the reference-typed fields are deep-copied explicitly.
// NOTE(review): if GuestConfig ever grows another slice/map/pointer field,
// it must be added to the deep-copy section below.
func cloneGuestConfig(config *contracthost.GuestConfig) *contracthost.GuestConfig {
	if config == nil {
		return nil
	}
	cloned := *config // carries Hostname and any other scalar fields
	cloned.AuthorizedKeys = append([]string(nil), config.AuthorizedKeys...)
	cloned.TrustedUserCAKeys = append([]string(nil), config.TrustedUserCAKeys...)
	if config.LoginWebhook != nil {
		webhook := *config.LoginWebhook
		cloned.LoginWebhook = &webhook
	}
	return &cloned
}
// guestMetadataSpec builds the Firecracker MMDS configuration exposing the
// machine's identity and guest access settings (authorized keys, trusted CA
// keys, optional login webhook) to the guest on interface "net0".
// It fails only when machineID is blank. guestConfig may be nil, in which
// case only identity fields are populated.
func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) (*firecracker.MMDSSpec, error) {
	id := strings.TrimSpace(string(machineID))
	if id == "" {
		return nil, fmt.Errorf("machine id is required")
	}
	meta := guestMetadataPayload{
		Version:   defaultMMDSPayloadVersion,
		MachineID: id,
		Hostname:  id, // hostname mirrors the machine ID
	}
	if guestConfig != nil {
		// Copy the slices so the MMDS payload never aliases caller state.
		meta.AuthorizedKeys = append([]string(nil), guestConfig.AuthorizedKeys...)
		meta.TrustedUserCAKeys = append([]string(nil), guestConfig.TrustedUserCAKeys...)
		if webhook := guestConfig.LoginWebhook; webhook != nil {
			cloned := *webhook
			meta.LoginWebhook = &cloned
		}
	}
	return &firecracker.MMDSSpec{
		NetworkInterfaces: []string{"net0"},
		Version:           firecracker.MMDSVersionV2,
		IPv4Address:       defaultMMDSIPv4Address,
		Data:              guestMetadataEnvelope{Latest: guestMetadataRoot{MetaData: meta}},
	}, nil
}

View file

@ -48,9 +48,13 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
if err != nil {
return nil, err
}
previousRecord := *record
if record.Phase == contracthost.MachinePhaseRunning {
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
}
if record.Phase == contracthost.MachinePhaseStarting {
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
}
if record.Phase != contracthost.MachinePhaseStopped {
return nil, fmt.Errorf("machine %q is not startable from phase %q", id, record.Phase)
}
@ -82,7 +86,7 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
if err != nil {
return nil, err
}
spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path)
spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path, record.GuestConfig)
if err != nil {
return nil, err
}
@ -94,39 +98,18 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
if err != nil {
return nil, err
}
ports := defaultMachinePorts()
if err := waitForGuestReady(ctx, state.RuntimeHost, ports); err != nil {
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
if err != nil {
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
record.RuntimeHost = state.RuntimeHost
record.TapDevice = state.TapName
record.Ports = ports
record.GuestSSHPublicKey = guestSSHPublicKey
record.Phase = contracthost.MachinePhaseRunning
record.Ports = defaultMachinePorts()
record.GuestSSHPublicKey = ""
record.Phase = contracthost.MachinePhaseStarting
record.Error = ""
record.PID = state.PID
record.SocketPath = state.SocketPath
record.StartedAt = state.StartedAt
if err := d.store.UpdateMachine(ctx, *record); err != nil {
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
if err := d.ensureMachineRelays(ctx, record); err != nil {
d.stopMachineRelays(id)
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
d.stopMachineRelays(id)
d.stopPublishedPortsForMachine(id)
_ = d.runtime.Delete(context.Background(), *state)
_ = d.store.UpdateMachine(context.Background(), previousRecord)
return nil, err
}
@ -272,7 +255,10 @@ func (d *Daemon) listRunningNetworks(ctx context.Context, ignore contracthost.Ma
networks := make([]firecracker.NetworkAllocation, 0, len(records))
for _, record := range records {
if record.ID == ignore || record.Phase != contracthost.MachinePhaseRunning {
if record.ID == ignore {
continue
}
if record.Phase != contracthost.MachinePhaseRunning && record.Phase != contracthost.MachinePhaseStarting {
continue
}
if strings.TrimSpace(record.RuntimeHost) == "" || strings.TrimSpace(record.TapDevice) == "" {
@ -337,11 +323,8 @@ func (d *Daemon) reconcileStart(ctx context.Context, machineID contracthost.Mach
if err != nil {
return err
}
if record.Phase == contracthost.MachinePhaseRunning {
if err := d.ensureMachineRelays(ctx, record); err != nil {
return err
}
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
if record.Phase == contracthost.MachinePhaseRunning || record.Phase == contracthost.MachinePhaseStarting {
if _, err := d.reconcileMachine(ctx, machineID); err != nil {
return err
}
return d.store.DeleteOperation(ctx, machineID)
@ -385,7 +368,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
if err != nil {
return nil, err
}
if record.Phase != contracthost.MachinePhaseRunning {
if record.Phase != contracthost.MachinePhaseRunning && record.Phase != contracthost.MachinePhaseStarting {
return record, nil
}
@ -393,6 +376,42 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
if err != nil {
return nil, err
}
if record.Phase == contracthost.MachinePhaseStarting {
if state.Phase != firecracker.PhaseRunning {
return d.failMachineStartup(ctx, record, state.Error)
}
ready, err := guestPortsReady(ctx, state.RuntimeHost, defaultMachinePorts())
if err != nil {
return nil, err
}
if !ready {
return record, nil
}
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
if err != nil {
return d.failMachineStartup(ctx, record, err.Error())
}
record.RuntimeHost = state.RuntimeHost
record.TapDevice = state.TapName
record.Ports = defaultMachinePorts()
record.GuestSSHPublicKey = guestSSHPublicKey
record.Phase = contracthost.MachinePhaseRunning
record.Error = ""
record.PID = state.PID
record.SocketPath = state.SocketPath
record.StartedAt = state.StartedAt
if err := d.store.UpdateMachine(ctx, *record); err != nil {
return nil, err
}
if err := d.ensureMachineRelays(ctx, record); err != nil {
return d.failMachineStartup(ctx, record, err.Error())
}
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
d.stopMachineRelays(record.ID)
return d.failMachineStartup(ctx, record, err.Error())
}
return record, nil
}
if state.Phase == firecracker.PhaseRunning {
if err := d.ensureMachineRelays(ctx, record); err != nil {
return nil, err
@ -418,6 +437,28 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
return record, nil
}
// failMachineStartup tears down a machine that could not finish starting and
// persists it in the Failed phase with the supplied reason. It is best-effort:
// runtime deletion errors are deliberately ignored so the record always ends
// up marked Failed. Returns the mutated record, or an error only if the store
// update itself fails.
func (d *Daemon) failMachineStartup(ctx context.Context, record *model.MachineRecord, failureReason string) (*model.MachineRecord, error) {
	if record == nil {
		return nil, fmt.Errorf("machine record is required")
	}
	// Best-effort teardown: VM process/network, port relays, published ports.
	_ = d.runtime.Delete(ctx, machineToRuntimeState(*record))
	d.stopMachineRelays(record.ID)
	d.stopPublishedPortsForMachine(record.ID)
	record.Phase = contracthost.MachinePhaseFailed
	record.Error = strings.TrimSpace(failureReason)
	// Reset all runtime-derived fields back to their non-running defaults.
	record.Ports = defaultMachinePorts()
	record.GuestSSHPublicKey = ""
	record.PID = 0
	record.SocketPath = ""
	record.RuntimeHost = ""
	record.TapDevice = ""
	record.StartedAt = nil
	if err := d.store.UpdateMachine(ctx, *record); err != nil {
		return nil, err
	}
	return record, nil
}
func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error {
d.stopMachineRelays(record.ID)
d.stopPublishedPortsForMachine(record.ID)
@ -450,6 +491,11 @@ func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineR
}
func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error {
if record.Phase == contracthost.MachinePhaseRunning && strings.TrimSpace(record.RuntimeHost) != "" {
if err := d.syncGuestFilesystem(ctx, record.RuntimeHost); err != nil {
fmt.Fprintf(os.Stderr, "warning: sync guest filesystem for %q failed before stop: %v\n", record.ID, err)
}
}
d.stopMachineRelays(record.ID)
d.stopPublishedPortsForMachine(record.ID)
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {

View file

@ -56,6 +56,9 @@ func (d *Daemon) usedMachineRelayPorts(ctx context.Context, machineID contractho
if record.ID == machineID {
continue
}
if record.Phase != contracthost.MachinePhaseRunning {
continue
}
if port := machineRelayHostPort(record, name); port != 0 {
used[port] = struct{}{}
}

View file

@ -30,15 +30,15 @@ func waitForGuestReady(ctx context.Context, host string, ports []contracthost.Ma
func waitForGuestPort(ctx context.Context, host string, port contracthost.MachinePort) error {
address := net.JoinHostPort(host, strconv.Itoa(int(port.Port)))
dialer := net.Dialer{Timeout: defaultGuestDialTimeout}
ticker := time.NewTicker(defaultGuestReadyPollInterval)
defer ticker.Stop()
var lastErr error
for {
connection, err := dialer.DialContext(ctx, string(port.Protocol), address)
if err == nil {
_ = connection.Close()
probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout)
ready, err := guestPortReady(probeCtx, host, port)
cancel()
if err == nil && ready {
return nil
}
lastErr = err
@ -50,3 +50,38 @@ func waitForGuestPort(ctx context.Context, host string, port contracthost.Machin
}
}
}
// guestPortsReady reports whether every listed guest port accepts a
// connection on host. Each port is probed under its own short timeout; the
// first closed port yields (false, nil), and a probe error is returned as-is.
func guestPortsReady(ctx context.Context, host string, ports []contracthost.MachinePort) (bool, error) {
	trimmedHost := strings.TrimSpace(host)
	if trimmedHost == "" {
		return false, fmt.Errorf("guest runtime host is required")
	}
	for _, machinePort := range ports {
		probeCtx, cancelProbe := context.WithTimeout(ctx, defaultGuestDialTimeout)
		portOpen, probeErr := guestPortReady(probeCtx, trimmedHost, machinePort)
		cancelProbe()
		switch {
		case probeErr != nil:
			return false, probeErr
		case !portOpen:
			return false, nil
		}
	}
	return true, nil
}
// guestPortReady performs one readiness probe against a single guest port.
// It reports true when a dial on the port's protocol succeeds and false
// otherwise. Any dial failure — including context cancellation — is treated
// as "not ready yet" rather than an error, so the error result is always nil;
// the error return is kept for interface stability with callers that check it.
func guestPortReady(ctx context.Context, host string, port contracthost.MachinePort) (bool, error) {
	address := net.JoinHostPort(host, strconv.Itoa(int(port.Port)))
	dialer := net.Dialer{Timeout: defaultGuestDialTimeout}
	connection, err := dialer.DialContext(ctx, string(port.Protocol), address)
	if err != nil {
		// The original distinguished ctx.Err() != nil from plain dial failure,
		// but both branches returned (false, nil); collapse the dead branch.
		return false, nil
	}
	_ = connection.Close()
	return true, nil
}

View file

@ -2,7 +2,10 @@ package daemon
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net"
"os"
"path/filepath"
@ -54,6 +57,19 @@ func (s machineLookupErrorStore) GetMachine(context.Context, contracthost.Machin
return nil, s.err
}
// relayExhaustionStore wraps a host store and injects additional machine
// records into ListMachines results, letting tests simulate a host whose
// relay port range is fully allocated.
type relayExhaustionStore struct {
	hoststore.Store
	// extraMachines are appended to every ListMachines result.
	extraMachines []model.MachineRecord
}
// ListMachines returns the wrapped store's machines followed by the injected
// extra records.
func (s relayExhaustionStore) ListMachines(ctx context.Context) ([]model.MachineRecord, error) {
	stored, err := s.Store.ListMachines(ctx)
	if err != nil {
		return nil, err
	}
	combined := make([]model.MachineRecord, 0, len(stored)+len(s.extraMachines))
	combined = append(combined, stored...)
	combined = append(combined, s.extraMachines...)
	return combined, nil
}
type publishedPortResult struct {
response *contracthost.CreatePublishedPortResponse
err error
@ -283,6 +299,249 @@ func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T)
assertOperationCount(t, baseStore, 1)
}
// TestStartMachineTransitionsToStartingWithoutRelayAllocation verifies that
// StartMachine leaves the machine in the Starting phase with runtime state
// persisted, and does NOT allocate relay ports (or delete the VM) even when
// every relay port is already in use on the host.
func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}
	// Wrap the store so ListMachines reports the relay port range as exhausted.
	exhaustedStore := relayExhaustionStore{
		Store:         baseStore,
		extraMachines: exhaustedMachineRelayRecords(),
	}
	sshListener := listenTestPort(t, int(defaultSSHPort))
	defer func() {
		_ = sshListener.Close()
	}()
	vncListener := listenTestPort(t, int(defaultVNCPort))
	defer func() {
		_ = vncListener.Close()
	}()
	startedAt := time.Unix(1700000200, 0).UTC()
	// Fake runtime reports a running VM immediately on boot.
	runtime := &fakeRuntime{
		bootState: firecracker.MachineState{
			ID:          "vm-start",
			Phase:       firecracker.PhaseRunning,
			PID:         9999,
			RuntimeHost: "127.0.0.1",
			SocketPath:  filepath.Join(cfg.RuntimeDir, "machines", "vm-start", "root", "run", "firecracker.sock"),
			TapName:     "fctap-start",
			StartedAt:   &startedAt,
		},
	}
	hostDaemon, err := New(cfg, exhaustedStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	stubGuestSSHPublicKeyReader(hostDaemon)
	// Seed artifact, system volume, and a Stopped machine record on disk.
	artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
	kernelPath := filepath.Join(root, "artifact-kernel")
	rootFSPath := filepath.Join(root, "artifact-rootfs")
	systemVolumePath := filepath.Join(root, "machine-disks", "vm-start", "rootfs.ext4")
	for _, file := range []string{kernelPath, rootFSPath, systemVolumePath} {
		if err := os.MkdirAll(filepath.Dir(file), 0o755); err != nil {
			t.Fatalf("mkdir for %q: %v", file, err)
		}
		if err := os.WriteFile(file, []byte("payload"), 0o644); err != nil {
			t.Fatalf("write file %q: %v", file, err)
		}
	}
	if err := baseStore.PutArtifact(context.Background(), model.ArtifactRecord{
		Ref:             artifactRef,
		LocalKey:        "artifact",
		LocalDir:        filepath.Join(root, "artifact"),
		KernelImagePath: kernelPath,
		RootFSPath:      rootFSPath,
		CreatedAt:       time.Now().UTC(),
	}); err != nil {
		t.Fatalf("put artifact: %v", err)
	}
	if err := baseStore.CreateVolume(context.Background(), model.VolumeRecord{
		ID:        "vm-start-system",
		Kind:      contracthost.VolumeKindSystem,
		Pool:      model.StoragePoolMachineDisks,
		Path:      systemVolumePath,
		CreatedAt: time.Now().UTC(),
	}); err != nil {
		t.Fatalf("create system volume: %v", err)
	}
	if err := baseStore.CreateMachine(context.Background(), model.MachineRecord{
		ID:             "vm-start",
		Artifact:       artifactRef,
		SystemVolumeID: "vm-start-system",
		Ports:          defaultMachinePorts(),
		Phase:          contracthost.MachinePhaseStopped,
		CreatedAt:      time.Now().UTC(),
	}); err != nil {
		t.Fatalf("create machine: %v", err)
	}
	response, err := hostDaemon.StartMachine(context.Background(), "vm-start")
	if err != nil {
		t.Fatalf("StartMachine error = %v", err)
	}
	// StartMachine must succeed despite relay exhaustion: the machine parks in
	// Starting and relay allocation is deferred to reconciliation.
	if response.Machine.Phase != contracthost.MachinePhaseStarting {
		t.Fatalf("response machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting)
	}
	machine, err := baseStore.GetMachine(context.Background(), "vm-start")
	if err != nil {
		t.Fatalf("get machine: %v", err)
	}
	if machine.Phase != contracthost.MachinePhaseStarting {
		t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseStarting)
	}
	// Runtime network/process details from the boot state must be persisted.
	if machine.RuntimeHost != "127.0.0.1" || machine.TapDevice != "fctap-start" {
		t.Fatalf("machine runtime state mismatch, got runtime_host=%q tap=%q", machine.RuntimeHost, machine.TapDevice)
	}
	if machine.PID != 9999 || machine.SocketPath == "" || machine.StartedAt == nil {
		t.Fatalf("machine process state missing: pid=%d socket=%q started_at=%v", machine.PID, machine.SocketPath, machine.StartedAt)
	}
	// The VM must not have been torn down.
	if len(runtime.deleteCalls) != 0 {
		t.Fatalf("runtime delete calls = %d, want 0", len(runtime.deleteCalls))
	}
}
// TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails checks
// that a restore which fails at relay-port allocation rolls back fully: the
// restored system volume record and disk file are removed, the VM is deleted,
// and no operation record is left behind.
func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}
	// Exhaust the relay port range so allocation for the restored VM fails.
	exhaustedStore := relayExhaustionStore{
		Store:         baseStore,
		extraMachines: exhaustedMachineRelayRecords(),
	}
	sshListener := listenTestPort(t, int(defaultSSHPort))
	defer func() {
		_ = sshListener.Close()
	}()
	vncListener := listenTestPort(t, int(defaultVNCPort))
	defer func() {
		_ = vncListener.Close()
	}()
	startedAt := time.Unix(1700000300, 0).UTC()
	runtime := &fakeRuntime{
		bootState: firecracker.MachineState{
			ID:          "restored-exhausted",
			Phase:       firecracker.PhaseRunning,
			PID:         8888,
			RuntimeHost: "127.0.0.1",
			SocketPath:  filepath.Join(cfg.RuntimeDir, "machines", "restored-exhausted", "root", "run", "firecracker.sock"),
			TapName:     "fctap-restore",
			StartedAt:   &startedAt,
		},
	}
	hostDaemon, err := New(cfg, exhaustedStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	stubGuestSSHPublicKeyReader(hostDaemon)
	hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID) error { return nil }
	// Seed the artifact referenced by the restore request.
	artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
	kernelPath := filepath.Join(root, "artifact-kernel")
	rootFSPath := filepath.Join(root, "artifact-rootfs")
	if err := os.WriteFile(kernelPath, []byte("kernel"), 0o644); err != nil {
		t.Fatalf("write kernel: %v", err)
	}
	if err := os.WriteFile(rootFSPath, []byte("rootfs"), 0o644); err != nil {
		t.Fatalf("write rootfs: %v", err)
	}
	if err := baseStore.PutArtifact(context.Background(), model.ArtifactRecord{
		Ref:             artifactRef,
		LocalKey:        "artifact",
		LocalDir:        filepath.Join(root, "artifact"),
		KernelImagePath: kernelPath,
		RootFSPath:      rootFSPath,
		CreatedAt:       time.Now().UTC(),
	}); err != nil {
		t.Fatalf("put artifact: %v", err)
	}
	// Lay down local snapshot files (disk, memory, vmstate) and the record.
	snapDir := filepath.Join(root, "snapshots", "snap-exhausted")
	if err := os.MkdirAll(snapDir, 0o755); err != nil {
		t.Fatalf("create snapshot dir: %v", err)
	}
	snapDisk := filepath.Join(snapDir, "system.img")
	if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil {
		t.Fatalf("write snapshot disk: %v", err)
	}
	memPath := filepath.Join(snapDir, "memory.bin")
	if err := os.WriteFile(memPath, []byte("mem"), 0o644); err != nil {
		t.Fatalf("write memory snapshot: %v", err)
	}
	statePath := filepath.Join(snapDir, "vmstate.bin")
	if err := os.WriteFile(statePath, []byte("state"), 0o644); err != nil {
		t.Fatalf("write vmstate snapshot: %v", err)
	}
	if err := baseStore.CreateSnapshot(context.Background(), model.SnapshotRecord{
		ID:                "snap-exhausted",
		MachineID:         "source",
		Artifact:          artifactRef,
		MemFilePath:       memPath,
		StateFilePath:     statePath,
		DiskPaths:         []string{snapDisk},
		SourceRuntimeHost: "172.16.0.2",
		SourceTapDevice:   "fctap0",
		CreatedAt:         time.Now().UTC(),
	}); err != nil {
		t.Fatalf("create snapshot: %v", err)
	}
	// HTTP server supplying the durable snapshot artifacts for download.
	server := newRestoreArtifactServer(t, map[string][]byte{
		"/kernel":  []byte("kernel"),
		"/rootfs":  []byte("rootfs"),
		"/memory":  []byte("mem"),
		"/vmstate": []byte("state"),
		"/system":  []byte("disk"),
	})
	defer server.Close()
	_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-exhausted", contracthost.RestoreSnapshotRequest{
		MachineID: "restored-exhausted",
		Artifact: contracthost.ArtifactRef{
			KernelImageURL: server.URL + "/kernel",
			RootFSURL:      server.URL + "/rootfs",
		},
		Snapshot: contracthost.DurableSnapshotSpec{
			SnapshotID: "snap-exhausted",
			MachineID:  "source",
			ImageID:    "image-1",
			Artifacts: []contracthost.SnapshotArtifact{
				{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"},
				{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"},
				{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"},
			},
		},
	})
	// The restore must fail specifically on relay allocation, then roll back.
	if err == nil || !strings.Contains(err.Error(), "allocate relay ports for restored machine") {
		t.Fatalf("RestoreSnapshot error = %v, want relay allocation failure", err)
	}
	if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); !errors.Is(err, hoststore.ErrNotFound) {
		t.Fatalf("restored system volume record should be deleted, get err = %v", err)
	}
	if _, err := os.Stat(hostDaemon.systemVolumePath("restored-exhausted")); !os.IsNotExist(err) {
		t.Fatalf("restored system disk should be removed, stat err = %v", err)
	}
	if len(runtime.deleteCalls) != 1 {
		t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls))
	}
	assertOperationCount(t, baseStore, 0)
}
func TestCreateSnapshotRejectsDuplicateSnapshotIDWithoutTouchingExistingArtifacts(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
@ -339,6 +598,193 @@ func TestCreateSnapshotRejectsDuplicateSnapshotIDWithoutTouchingExistingArtifact
assertOperationCount(t, fileStore, 0)
}
// TestStopMachineContinuesWhenGuestSyncFails verifies that a failing guest
// filesystem sync is treated as best-effort: StopMachine still deletes the VM
// and transitions the record to Stopped.
func TestStopMachineContinuesWhenGuestSyncFails(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}
	runtime := &fakeRuntime{}
	hostDaemon, err := New(cfg, fileStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	stubGuestSSHPublicKeyReader(hostDaemon)
	// Force the pre-stop guest filesystem sync to fail.
	hostDaemon.syncGuestFilesystem = func(context.Context, string) error {
		return errors.New("guest sync failed")
	}
	now := time.Now().UTC()
	// Seed a Running machine with a runtime host so the sync path is taken.
	if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
		ID:             "vm-stop-fail",
		SystemVolumeID: "vm-stop-fail-system",
		RuntimeHost:    "172.16.0.2",
		TapDevice:      "fctap-stop-fail",
		Phase:          contracthost.MachinePhaseRunning,
		PID:            1234,
		SocketPath:     filepath.Join(root, "runtime", "vm-stop-fail.sock"),
		Ports:          defaultMachinePorts(),
		CreatedAt:      now,
		StartedAt:      &now,
	}); err != nil {
		t.Fatalf("create machine: %v", err)
	}
	if err := hostDaemon.StopMachine(context.Background(), "vm-stop-fail"); err != nil {
		t.Fatalf("StopMachine returned error despite sync failure: %v", err)
	}
	// The VM teardown must still have happened exactly once.
	if len(runtime.deleteCalls) != 1 {
		t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls))
	}
	machine, err := fileStore.GetMachine(context.Background(), "vm-stop-fail")
	if err != nil {
		t.Fatalf("get machine: %v", err)
	}
	if machine.Phase != contracthost.MachinePhaseStopped {
		t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseStopped)
	}
}
func TestOrderedRestoredUserDiskArtifactsSortsByDriveIndex(t *testing.T) {
ordered := orderedRestoredUserDiskArtifacts(map[string]restoredSnapshotArtifact{
"user-10.img": {Artifact: contracthost.SnapshotArtifact{Name: "user-10.img"}},
"user-2.img": {Artifact: contracthost.SnapshotArtifact{Name: "user-2.img"}},
"user-1.img": {Artifact: contracthost.SnapshotArtifact{Name: "user-1.img"}},
"system.img": {Artifact: contracthost.SnapshotArtifact{Name: "system.img"}},
})
names := make([]string, 0, len(ordered))
for _, artifact := range ordered {
names = append(names, artifact.Artifact.Name)
}
if got, want := strings.Join(names, ","), "user-1.img,user-2.img,user-10.img"; got != want {
t.Fatalf("ordered restored artifacts = %q, want %q", got, want)
}
}
// TestRestoreSnapshotCleansStagingArtifactsAfterSuccess verifies that the
// per-restore staging directory under SnapshotsDir is removed once a restore
// completes successfully.
func TestRestoreSnapshotCleansStagingArtifactsAfterSuccess(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}
	sshListener := listenTestPort(t, int(defaultSSHPort))
	defer func() { _ = sshListener.Close() }()
	vncListener := listenTestPort(t, int(defaultVNCPort))
	defer func() { _ = vncListener.Close() }()
	startedAt := time.Unix(1700000400, 0).UTC()
	runtime := &fakeRuntime{
		bootState: firecracker.MachineState{
			ID:          "restored-clean",
			Phase:       firecracker.PhaseRunning,
			PID:         7777,
			RuntimeHost: "127.0.0.1",
			SocketPath:  filepath.Join(cfg.RuntimeDir, "machines", "restored-clean", "root", "run", "firecracker.sock"),
			TapName:     "fctap-clean",
			StartedAt:   &startedAt,
		},
	}
	hostDaemon, err := New(cfg, fileStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	stubGuestSSHPublicKeyReader(hostDaemon)
	hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID) error { return nil }
	// Durable artifacts are served over HTTP with matching SHA-256 digests.
	server := newRestoreArtifactServer(t, map[string][]byte{
		"/kernel":  []byte("kernel"),
		"/rootfs":  []byte("rootfs"),
		"/memory":  []byte("mem"),
		"/vmstate": []byte("state"),
		"/system":  []byte("disk"),
	})
	defer server.Close()
	_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-clean", contracthost.RestoreSnapshotRequest{
		MachineID: "restored-clean",
		Artifact: contracthost.ArtifactRef{
			KernelImageURL: server.URL + "/kernel",
			RootFSURL:      server.URL + "/rootfs",
		},
		Snapshot: contracthost.DurableSnapshotSpec{
			SnapshotID:        "snap-clean",
			MachineID:         "source",
			ImageID:           "image-1",
			SourceRuntimeHost: "172.16.0.2",
			SourceTapDevice:   "fctap0",
			Artifacts: []contracthost.SnapshotArtifact{
				{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory", SHA256Hex: mustSHA256Hex(t, []byte("mem"))},
				{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate", SHA256Hex: mustSHA256Hex(t, []byte("state"))},
				{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system", SHA256Hex: mustSHA256Hex(t, []byte("disk"))},
			},
		},
	})
	if err != nil {
		t.Fatalf("RestoreSnapshot returned error: %v", err)
	}
	// The download staging directory must be gone after success.
	stagingDir := filepath.Join(cfg.SnapshotsDir, "snap-clean", "restores", "restored-clean")
	if _, statErr := os.Stat(stagingDir); !os.IsNotExist(statErr) {
		t.Fatalf("restore staging dir should be cleaned up, stat err = %v", statErr)
	}
}
// TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure verifies that
// the staging directory is also removed when a durable artifact download
// fails partway through a restore.
func TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}
	runtime := &fakeRuntime{}
	hostDaemon, err := New(cfg, fileStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	stubGuestSSHPublicKeyReader(hostDaemon)
	// The server intentionally does not serve the vmstate artifact.
	server := newRestoreArtifactServer(t, map[string][]byte{
		"/kernel": []byte("kernel"),
		"/rootfs": []byte("rootfs"),
		"/memory": []byte("mem"),
	})
	defer server.Close()
	_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-fail-clean", contracthost.RestoreSnapshotRequest{
		MachineID: "restored-fail-clean",
		Artifact: contracthost.ArtifactRef{
			KernelImageURL: server.URL + "/kernel",
			RootFSURL:      server.URL + "/rootfs",
		},
		Snapshot: contracthost.DurableSnapshotSpec{
			SnapshotID:        "snap-fail-clean",
			MachineID:         "source",
			ImageID:           "image-1",
			SourceRuntimeHost: "172.16.0.2",
			SourceTapDevice:   "fctap0",
			Artifacts: []contracthost.SnapshotArtifact{
				{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory", SHA256Hex: mustSHA256Hex(t, []byte("mem"))},
				{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/missing"},
			},
		},
	})
	if err == nil || !strings.Contains(err.Error(), "download durable snapshot artifacts") {
		t.Fatalf("RestoreSnapshot error = %v, want durable artifact download failure", err)
	}
	// No staging residue may remain after the failed download.
	stagingDir := filepath.Join(cfg.SnapshotsDir, "snap-fail-clean", "restores", "restored-fail-clean")
	if _, statErr := os.Stat(stagingDir); !os.IsNotExist(statErr) {
		t.Fatalf("restore staging dir should be cleaned up after download failure, stat err = %v", statErr)
	}
}
func TestReconcileUsesReconciledMachineStateForPublishedPorts(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
@ -423,6 +869,26 @@ func waitPublishedPortResult(t *testing.T, ch <-chan publishedPortResult) publis
}
}
// exhaustedMachineRelayRecords builds one Running machine record per port in
// the SSH relay range so that any further relay allocation must fail.
func exhaustedMachineRelayRecords() []model.MachineRecord {
	total := int(maxMachineSSHRelayPort-minMachineSSHRelayPort) + 1
	records := make([]model.MachineRecord, total)
	for offset := range records {
		records[offset] = model.MachineRecord{
			ID:    contracthost.MachineID(fmt.Sprintf("relay-exhausted-%d", offset)),
			Ports: buildMachinePorts(minMachineSSHRelayPort+uint16(offset), minMachineVNCRelayPort+uint16(offset)),
			Phase: contracthost.MachinePhaseRunning,
		}
	}
	return records
}
// mustSHA256Hex returns the lowercase hex encoding of payload's SHA-256
// digest, for use as expected checksum values in test fixtures.
func mustSHA256Hex(t *testing.T, payload []byte) string {
	t.Helper()
	digest := sha256.Sum256(payload)
	return hex.EncodeToString(digest[:])
}
func assertOperationCount(t *testing.T, store hoststore.Store, want int) {
t.Helper()

View file

@ -7,6 +7,8 @@ import (
"os"
"os/exec"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
@ -108,6 +110,22 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
return nil, fmt.Errorf("copy system disk: %w", err)
}
diskPaths = append(diskPaths, systemDiskTarget)
for i, volumeID := range record.UserVolumeIDs {
volume, err := d.store.GetVolume(ctx, volumeID)
if err != nil {
_ = d.runtime.Resume(ctx, runtimeState)
_ = os.RemoveAll(snapshotDir)
return nil, fmt.Errorf("get attached volume %q: %w", volumeID, err)
}
driveID := fmt.Sprintf("user-%d", i)
targetPath := filepath.Join(snapshotDir, driveID+".img")
if err := cowCopyFile(volume.Path, targetPath); err != nil {
_ = d.runtime.Resume(ctx, runtimeState)
_ = os.RemoveAll(snapshotDir)
return nil, fmt.Errorf("copy attached volume %q: %w", volumeID, err)
}
diskPaths = append(diskPaths, targetPath)
}
// Resume the source VM
if err := d.runtime.Resume(ctx, runtimeState); err != nil {
@ -132,6 +150,12 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
return nil, fmt.Errorf("move vmstate file: %w", err)
}
artifacts, err := buildSnapshotArtifacts(dstMemPath, dstStatePath, diskPaths)
if err != nil {
_ = os.RemoveAll(snapshotDir)
return nil, fmt.Errorf("build snapshot artifacts: %w", err)
}
now := time.Now().UTC()
snapshotRecord := model.SnapshotRecord{
ID: snapshotID,
@ -140,6 +164,7 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
MemFilePath: dstMemPath,
StateFilePath: dstStatePath,
DiskPaths: diskPaths,
Artifacts: artifacts,
SourceRuntimeHost: record.RuntimeHost,
SourceTapDevice: record.TapDevice,
CreatedAt: now,
@ -151,25 +176,60 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
clearOperation = true
return &contracthost.CreateSnapshotResponse{
Snapshot: snapshotToContract(snapshotRecord),
Snapshot: snapshotToContract(snapshotRecord),
Artifacts: snapshotArtifactsToContract(snapshotRecord.Artifacts),
}, nil
}
// UploadSnapshot uploads the requested parts of each named snapshot artifact
// and returns, per artifact, the parts that completed. An unknown artifact ID
// or a failed part upload aborts the whole request.
func (d *Daemon) UploadSnapshot(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.UploadSnapshotRequest) (*contracthost.UploadSnapshotResponse, error) {
	snapshot, err := d.store.GetSnapshot(ctx, snapshotID)
	if err != nil {
		return nil, err
	}
	// Index the snapshot's artifacts by ID for constant-time lookup.
	byID := make(map[string]model.SnapshotArtifactRecord, len(snapshot.Artifacts))
	for _, record := range snapshot.Artifacts {
		byID[record.ID] = record
	}
	uploaded := make([]contracthost.UploadedSnapshotArtifact, 0, len(req.Artifacts))
	for _, request := range req.Artifacts {
		record, found := byID[request.ArtifactID]
		if !found {
			return nil, fmt.Errorf("snapshot %q artifact %q not found", snapshotID, request.ArtifactID)
		}
		parts, uploadErr := uploadSnapshotArtifact(ctx, record.LocalPath, request.Parts)
		if uploadErr != nil {
			return nil, fmt.Errorf("upload snapshot artifact %q: %w", request.ArtifactID, uploadErr)
		}
		uploaded = append(uploaded, contracthost.UploadedSnapshotArtifact{
			ArtifactID:     request.ArtifactID,
			CompletedParts: parts,
		})
	}
	return &contracthost.UploadSnapshotResponse{Artifacts: uploaded}, nil
}
func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.RestoreSnapshotRequest) (*contracthost.RestoreSnapshotResponse, error) {
if err := validateMachineID(req.MachineID); err != nil {
return nil, err
}
if req.Snapshot.SnapshotID != "" && req.Snapshot.SnapshotID != snapshotID {
return nil, fmt.Errorf("snapshot id mismatch: path=%q payload=%q", snapshotID, req.Snapshot.SnapshotID)
}
if err := validateArtifactRef(req.Artifact); err != nil {
return nil, err
}
unlock := d.lockMachine(req.MachineID)
defer unlock()
snap, err := d.store.GetSnapshot(ctx, snapshotID)
if err != nil {
return nil, err
}
if _, err := d.store.GetMachine(ctx, req.MachineID); err == nil {
return nil, fmt.Errorf("machine %q already exists", req.MachineID)
} else if err != nil && err != store.ErrNotFound {
return nil, err
}
if err := d.store.UpsertOperation(ctx, model.OperationRecord{
@ -188,55 +248,90 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
}
}()
sourceMachine, err := d.store.GetMachine(ctx, snap.MachineID)
switch {
case err == nil && sourceMachine.Phase == contracthost.MachinePhaseRunning:
clearOperation = true
return nil, fmt.Errorf("restore from snapshot %q while source machine %q is running is not supported yet", snapshotID, snap.MachineID)
case err != nil && err != store.ErrNotFound:
return nil, fmt.Errorf("get source machine for restore: %w", err)
}
usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID)
if err != nil {
return nil, err
}
restoreNetwork, err := restoreNetworkFromSnapshot(snap)
restoreNetwork, err := d.resolveRestoreNetwork(ctx, snapshotID, req.Snapshot)
if err != nil {
clearOperation = true
return nil, err
}
if networkAllocationInUse(restoreNetwork, usedNetworks) {
clearOperation = true
return nil, fmt.Errorf("snapshot %q restore network %q (%s) is already in use", snapshotID, restoreNetwork.TapName, restoreNetwork.GuestIP())
return nil, fmt.Errorf("restore network for snapshot %q is still in use on this host (runtime_host=%s tap_device=%s)", snapshotID, restoreNetwork.GuestIP(), restoreNetwork.TapName)
}
artifact, err := d.store.GetArtifact(ctx, snap.Artifact)
artifact, err := d.ensureArtifact(ctx, req.Artifact)
if err != nil {
return nil, fmt.Errorf("get artifact for restore: %w", err)
return nil, fmt.Errorf("ensure artifact for restore: %w", err)
}
stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID))
restoredArtifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts)
if err != nil {
_ = os.RemoveAll(stagingDir)
clearOperation = true
return nil, fmt.Errorf("download durable snapshot artifacts: %w", err)
}
defer func() { _ = os.RemoveAll(stagingDir) }()
// COW-copy system disk from snapshot to new machine's disk dir.
newSystemDiskPath := d.systemVolumePath(req.MachineID)
if err := os.MkdirAll(filepath.Dir(newSystemDiskPath), 0o755); err != nil {
return nil, fmt.Errorf("create machine disk dir: %w", err)
}
if len(snap.DiskPaths) < 1 {
systemDiskPath, ok := restoredArtifacts["system.img"]
if !ok {
clearOperation = true
return nil, fmt.Errorf("snapshot %q has no disk paths", snapshotID)
return nil, fmt.Errorf("snapshot %q is missing system disk artifact", snapshotID)
}
if err := cowCopyFile(snap.DiskPaths[0], newSystemDiskPath); err != nil {
memoryArtifact, ok := restoredArtifacts["memory.bin"]
if !ok {
clearOperation = true
return nil, fmt.Errorf("snapshot %q is missing memory artifact", snapshotID)
}
vmstateArtifact, ok := restoredArtifacts["vmstate.bin"]
if !ok {
clearOperation = true
return nil, fmt.Errorf("snapshot %q is missing vmstate artifact", snapshotID)
}
if err := cowCopyFile(systemDiskPath.LocalPath, newSystemDiskPath); err != nil {
clearOperation = true
return nil, fmt.Errorf("copy system disk for restore: %w", err)
}
type restoredUserVolume struct {
ID contracthost.VolumeID
Path string
DriveID string
}
restoredUserVolumes := make([]restoredUserVolume, 0)
restoredDrivePaths := make(map[string]string)
for _, restored := range orderedRestoredUserDiskArtifacts(restoredArtifacts) {
name := restored.Artifact.Name
driveID := strings.TrimSuffix(name, filepath.Ext(name))
volumeID := contracthost.VolumeID(fmt.Sprintf("%s-%s", req.MachineID, driveID))
volumePath := filepath.Join(d.config.MachineDisksDir, string(req.MachineID), name)
if err := cowCopyFile(restored.LocalPath, volumePath); err != nil {
clearOperation = true
return nil, fmt.Errorf("copy restored drive %q: %w", driveID, err)
}
restoredUserVolumes = append(restoredUserVolumes, restoredUserVolume{
ID: volumeID,
Path: volumePath,
DriveID: driveID,
})
restoredDrivePaths[driveID] = volumePath
}
loadSpec := firecracker.SnapshotLoadSpec{
ID: firecracker.MachineID(req.MachineID),
SnapshotPath: snap.StateFilePath,
MemFilePath: snap.MemFilePath,
SnapshotPath: vmstateArtifact.LocalPath,
MemFilePath: memoryArtifact.LocalPath,
RootFSPath: newSystemDiskPath,
KernelImagePath: artifact.KernelImagePath,
DiskPaths: map[string]string{},
DiskPaths: restoredDrivePaths,
Network: &restoreNetwork,
}
@ -275,18 +370,44 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
ID: systemVolumeID,
Kind: contracthost.VolumeKindSystem,
AttachedMachineID: machineIDPtr(req.MachineID),
SourceArtifact: &snap.Artifact,
SourceArtifact: &req.Artifact,
Pool: model.StoragePoolMachineDisks,
Path: newSystemDiskPath,
CreatedAt: now,
}); err != nil {
return nil, err
_ = d.runtime.Delete(ctx, *machineState)
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("create system volume record for restore: %w", err)
}
restoredUserVolumeIDs := make([]contracthost.VolumeID, 0, len(restoredUserVolumes))
for _, volume := range restoredUserVolumes {
if err := d.store.CreateVolume(ctx, model.VolumeRecord{
ID: volume.ID,
Kind: contracthost.VolumeKindUser,
AttachedMachineID: machineIDPtr(req.MachineID),
SourceArtifact: &req.Artifact,
Pool: model.StoragePoolMachineDisks,
Path: volume.Path,
CreatedAt: now,
}); err != nil {
for _, restoredVolumeID := range restoredUserVolumeIDs {
_ = d.store.DeleteVolume(context.Background(), restoredVolumeID)
}
_ = d.store.DeleteVolume(context.Background(), systemVolumeID)
_ = d.runtime.Delete(ctx, *machineState)
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("create restored user volume record %q: %w", volume.ID, err)
}
restoredUserVolumeIDs = append(restoredUserVolumeIDs, volume.ID)
}
machineRecord := model.MachineRecord{
ID: req.MachineID,
Artifact: snap.Artifact,
Artifact: req.Artifact,
SystemVolumeID: systemVolumeID,
UserVolumeIDs: restoredUserVolumeIDs,
RuntimeHost: machineState.RuntimeHost,
TapDevice: machineState.TapName,
Ports: defaultMachinePorts(),
@ -306,7 +427,14 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
d.relayAllocMu.Unlock()
if err != nil {
d.stopMachineRelays(machineRecord.ID)
return nil, err
for _, restoredVolumeID := range restoredUserVolumeIDs {
_ = d.store.DeleteVolume(context.Background(), restoredVolumeID)
}
_ = d.store.DeleteVolume(context.Background(), systemVolumeID)
_ = d.runtime.Delete(ctx, *machineState)
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("allocate relay ports for restored machine: %w", err)
}
machineRecord.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
startedRelays := true
@ -316,6 +444,13 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
}
}()
if err := d.store.CreateMachine(ctx, machineRecord); err != nil {
for _, restoredVolumeID := range restoredUserVolumeIDs {
_ = d.store.DeleteVolume(context.Background(), restoredVolumeID)
}
_ = d.store.DeleteVolume(context.Background(), systemVolumeID)
_ = d.runtime.Delete(ctx, *machineState)
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, err
}
@ -360,12 +495,89 @@ func (d *Daemon) DeleteSnapshotByID(ctx context.Context, snapshotID contracthost
// snapshotToContract maps a stored snapshot record onto its wire-contract
// representation, carrying the restore-network metadata (source runtime host
// and tap device) alongside the identity fields.
//
// NOTE: the original span contained both the pre- and post-change field sets
// from an unresolved diff (duplicate ID/MachineID/CreatedAt keys), which does
// not compile; this is the post-change version.
func snapshotToContract(record model.SnapshotRecord) contracthost.Snapshot {
	return contracthost.Snapshot{
		ID:                record.ID,
		MachineID:         record.MachineID,
		SourceRuntimeHost: record.SourceRuntimeHost,
		SourceTapDevice:   record.SourceTapDevice,
		CreatedAt:         record.CreatedAt,
	}
}
// snapshotArtifactsToContract maps stored snapshot artifact records onto
// their wire-contract form. Host-local paths are intentionally not exposed;
// ObjectKey/DownloadURL are left empty for the caller to fill in.
func snapshotArtifactsToContract(artifacts []model.SnapshotArtifactRecord) []contracthost.SnapshotArtifact {
	out := make([]contracthost.SnapshotArtifact, len(artifacts))
	for i, record := range artifacts {
		out[i] = contracthost.SnapshotArtifact{
			ID:        record.ID,
			Kind:      record.Kind,
			Name:      record.Name,
			SizeBytes: record.SizeBytes,
			SHA256Hex: record.SHA256Hex,
		}
	}
	return out
}
// orderedRestoredUserDiskArtifacts selects the restored "user-*.img" disk
// artifacts from the download map and returns them in a deterministic order:
// numerically by the user-<N> index when parseable, with unparseable names
// sorted after indexed ones, alphabetically among themselves.
func orderedRestoredUserDiskArtifacts(artifacts map[string]restoredSnapshotArtifact) []restoredSnapshotArtifact {
	disks := make([]restoredSnapshotArtifact, 0, len(artifacts))
	for name, artifact := range artifacts {
		// Only user disk images participate; system/memory/vmstate are handled elsewhere.
		if strings.HasPrefix(name, "user-") && filepath.Ext(name) == ".img" {
			disks = append(disks, artifact)
		}
	}
	sort.Slice(disks, func(i, j int) bool {
		left, leftOK := restoredUserDiskIndex(disks[i].Artifact.Name)
		right, rightOK := restoredUserDiskIndex(disks[j].Artifact.Name)
		if leftOK && rightOK && left != right {
			return left < right
		}
		if leftOK != rightOK {
			// Indexed names sort before unindexed ones.
			return leftOK
		}
		return disks[i].Artifact.Name < disks[j].Artifact.Name
	})
	return disks
}
// restoredUserDiskIndex extracts the numeric suffix N from a restored user
// disk artifact name of the form "user-<N>.img". The boolean result is false
// when the name does not match that pattern or N is not an integer.
func restoredUserDiskIndex(name string) (int, bool) {
	ext := filepath.Ext(name)
	if ext != ".img" || !strings.HasPrefix(name, "user-") {
		return 0, false
	}
	middle := strings.TrimSuffix(strings.TrimPrefix(name, "user-"), ext)
	index, err := strconv.Atoi(middle)
	if err != nil {
		return 0, false
	}
	return index, true
}
// resolveRestoreNetwork determines the firecracker network allocation to use
// when restoring a snapshot. It prefers the metadata embedded in the durable
// snapshot spec and falls back to the locally stored snapshot record; only a
// genuine store failure (anything other than "not found") is propagated.
func (d *Daemon) resolveRestoreNetwork(ctx context.Context, snapshotID contracthost.SnapshotID, spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) {
// First choice: network metadata carried inside the durable spec itself.
if network, err := restoreNetworkFromDurableSpec(spec); err == nil {
return network, nil
}
// Fall back to the local snapshot record, if this host still has one.
snapshot, err := d.store.GetSnapshot(ctx, snapshotID)
if err == nil {
return restoreNetworkFromSnapshot(snapshot)
}
// NOTE(review): direct sentinel comparison — consider errors.Is(err,
// store.ErrNotFound) so wrapped errors are matched too; confirm store's
// error contract before changing.
if err != store.ErrNotFound {
return firecracker.NetworkAllocation{}, err
}
// Neither source had usable metadata: surface a descriptive error.
return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot %q is missing restore network metadata", snapshotID)
}
// restoreNetworkFromDurableSpec rebuilds the firecracker network allocation
// recorded in a durable snapshot spec. It fails when either the source guest
// IP (runtime host) or the source tap device is blank.
func restoreNetworkFromDurableSpec(spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) {
	host := strings.TrimSpace(spec.SourceRuntimeHost)
	tap := strings.TrimSpace(spec.SourceTapDevice)
	if host == "" || tap == "" {
		return firecracker.NetworkAllocation{}, fmt.Errorf("durable snapshot spec is missing restore network metadata")
	}
	// Note: the original (untrimmed) values are forwarded, matching prior behavior.
	network, err := firecracker.AllocationFromGuestIP(spec.SourceRuntimeHost, spec.SourceTapDevice)
	if err != nil {
		return firecracker.NetworkAllocation{}, fmt.Errorf("reconstruct durable snapshot %q network: %w", spec.SnapshotID, err)
	}
	return network, nil
}
func restoreNetworkFromSnapshot(snap *model.SnapshotRecord) (firecracker.NetworkAllocation, error) {
if snap == nil {
return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot is required")

View file

@ -0,0 +1,194 @@
package daemon
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"github.com/getcompanion-ai/computer-host/internal/model"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
// restoredSnapshotArtifact pairs a snapshot artifact's contract metadata
// with the local staging path it was downloaded to during restore.
type restoredSnapshotArtifact struct {
Artifact contracthost.SnapshotArtifact // metadata as provided by the control plane
LocalPath string // path of the downloaded copy on this host
}
// buildSnapshotArtifacts assembles the artifact records for a freshly taken
// snapshot: one memory record, one vmstate record, and one disk record per
// disk image, each with its file size and SHA-256 digest. The result is
// sorted by artifact ID for deterministic output.
func buildSnapshotArtifacts(memoryPath, vmstatePath string, diskPaths []string) ([]model.SnapshotArtifactRecord, error) {
	records := make([]model.SnapshotArtifactRecord, 0, len(diskPaths)+2)

	memory, err := snapshotArtifactRecord("memory", contracthost.SnapshotArtifactKindMemory, filepath.Base(memoryPath), memoryPath)
	if err != nil {
		return nil, err
	}
	records = append(records, memory)

	vmstate, err := snapshotArtifactRecord("vmstate", contracthost.SnapshotArtifactKindVMState, filepath.Base(vmstatePath), vmstatePath)
	if err != nil {
		return nil, err
	}
	records = append(records, vmstate)

	for _, diskPath := range diskPaths {
		base := filepath.Base(diskPath)
		// Disk artifact IDs embed the file stem, e.g. "disk-user-0".
		id := "disk-" + strings.TrimSuffix(base, filepath.Ext(base))
		disk, err := snapshotArtifactRecord(id, contracthost.SnapshotArtifactKindDisk, base, diskPath)
		if err != nil {
			return nil, err
		}
		records = append(records, disk)
	}

	sort.Slice(records, func(i, j int) bool { return records[i].ID < records[j].ID })
	return records, nil
}
// snapshotArtifactRecord builds one artifact record for the file at path,
// capturing its byte size and SHA-256 digest alongside the supplied identity
// (id, kind, name).
func snapshotArtifactRecord(id string, kind contracthost.SnapshotArtifactKind, name, path string) (model.SnapshotArtifactRecord, error) {
	var empty model.SnapshotArtifactRecord
	size, err := fileSize(path)
	if err != nil {
		return empty, err
	}
	digest, err := sha256File(path)
	if err != nil {
		return empty, err
	}
	return model.SnapshotArtifactRecord{
		ID:        id,
		Kind:      kind,
		Name:      name,
		LocalPath: path,
		SizeBytes: size,
		SHA256Hex: digest,
	}, nil
}
func sha256File(path string) (string, error) {
file, err := os.Open(path)
if err != nil {
return "", fmt.Errorf("open %q for sha256: %w", path, err)
}
defer func() { _ = file.Close() }()
hash := sha256.New()
if _, err := io.Copy(hash, file); err != nil {
return "", fmt.Errorf("hash %q: %w", path, err)
}
return hex.EncodeToString(hash.Sum(nil)), nil
}
// uploadSnapshotArtifact streams the artifact at localPath to the presigned
// part URLs, one HTTP PUT per part, and returns the completed parts sorted by
// part number. Every part must succeed with a 2xx status and a non-empty
// ETag header; the first failure aborts the whole upload.
func uploadSnapshotArtifact(ctx context.Context, localPath string, parts []contracthost.SnapshotUploadPart) ([]contracthost.UploadedSnapshotPart, error) {
	if len(parts) == 0 {
		return nil, fmt.Errorf("upload session has no parts")
	}
	source, err := os.Open(localPath)
	if err != nil {
		return nil, fmt.Errorf("open artifact %q: %w", localPath, err)
	}
	defer func() { _ = source.Close() }()

	client := &http.Client{}
	uploaded := make([]contracthost.UploadedSnapshotPart, 0, len(parts))
	for _, part := range parts {
		// Each part reads its own byte window of the shared file handle.
		section := io.NewSectionReader(source, part.OffsetBytes, part.SizeBytes)
		request, err := http.NewRequestWithContext(ctx, http.MethodPut, part.UploadURL, io.NopCloser(section))
		if err != nil {
			return nil, fmt.Errorf("build upload part %d: %w", part.PartNumber, err)
		}
		// ContentLength is set explicitly because the body type is not one
		// net/http can infer a length from.
		request.ContentLength = part.SizeBytes

		response, err := client.Do(request)
		if err != nil {
			return nil, fmt.Errorf("upload part %d: %w", part.PartNumber, err)
		}
		_ = response.Body.Close()
		if response.StatusCode < 200 || response.StatusCode >= 300 {
			return nil, fmt.Errorf("upload part %d returned %d", part.PartNumber, response.StatusCode)
		}
		etag := strings.TrimSpace(response.Header.Get("ETag"))
		if etag == "" {
			return nil, fmt.Errorf("upload part %d returned empty etag", part.PartNumber)
		}
		uploaded = append(uploaded, contracthost.UploadedSnapshotPart{
			PartNumber: part.PartNumber,
			ETag:       etag,
		})
	}
	sort.Slice(uploaded, func(i, j int) bool {
		return uploaded[i].PartNumber < uploaded[j].PartNumber
	})
	return uploaded, nil
}
// downloadDurableSnapshotArtifacts fetches every snapshot artifact into the
// staging directory root, verifying the SHA-256 digest when one is provided,
// and returns the downloads keyed by artifact name.
//
// Artifact names come from the control plane and are joined into local paths,
// so they are validated to be plain file names — a name containing a path
// separator or ".." would otherwise let a malicious response write outside
// the staging directory.
func downloadDurableSnapshotArtifacts(ctx context.Context, root string, artifacts []contracthost.SnapshotArtifact) (map[string]restoredSnapshotArtifact, error) {
	if len(artifacts) == 0 {
		return nil, fmt.Errorf("restore snapshot is missing artifacts")
	}
	if err := os.MkdirAll(root, 0o755); err != nil {
		return nil, fmt.Errorf("create restore staging dir %q: %w", root, err)
	}
	client := &http.Client{}
	restored := make(map[string]restoredSnapshotArtifact, len(artifacts))
	for _, artifact := range artifacts {
		if strings.TrimSpace(artifact.DownloadURL) == "" {
			return nil, fmt.Errorf("artifact %q is missing download url", artifact.ID)
		}
		// Reject names that could escape the staging root (path traversal).
		if artifact.Name == "" || artifact.Name == "." || artifact.Name == ".." || artifact.Name != filepath.Base(artifact.Name) {
			return nil, fmt.Errorf("artifact %q has unsafe name %q", artifact.ID, artifact.Name)
		}
		localPath := filepath.Join(root, artifact.Name)
		if err := downloadSnapshotArtifact(ctx, client, artifact.DownloadURL, localPath); err != nil {
			return nil, err
		}
		// Verify integrity when the control plane supplied a digest;
		// remove the corrupt download so it is never used.
		if expectedSHA := strings.TrimSpace(artifact.SHA256Hex); expectedSHA != "" {
			actualSHA, err := sha256File(localPath)
			if err != nil {
				return nil, err
			}
			if !strings.EqualFold(actualSHA, expectedSHA) {
				_ = os.Remove(localPath)
				return nil, fmt.Errorf("restore artifact %q sha256 mismatch: got %s want %s", artifact.Name, actualSHA, expectedSHA)
			}
		}
		restored[artifact.Name] = restoredSnapshotArtifact{
			Artifact:  artifact,
			LocalPath: localPath,
		}
	}
	return restored, nil
}
func downloadSnapshotArtifact(ctx context.Context, client *http.Client, sourceURL, targetPath string) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
if err != nil {
return fmt.Errorf("build restore download request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("download durable snapshot artifact: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("download durable snapshot artifact returned %d", resp.StatusCode)
}
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
return fmt.Errorf("create restore artifact dir %q: %w", filepath.Dir(targetPath), err)
}
out, err := os.Create(targetPath)
if err != nil {
return fmt.Errorf("create restore artifact %q: %w", targetPath, err)
}
defer func() { _ = out.Close() }()
if _, err := io.Copy(out, resp.Body); err != nil {
return fmt.Errorf("write restore artifact %q: %w", targetPath, err)
}
return nil
}

View file

@ -0,0 +1,60 @@
package daemon
import (
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
// TestUploadSnapshotArtifactRejectsEmptyETag verifies that a 2xx upload
// response without an ETag header is treated as a failure.
func TestUploadSnapshotArtifactRejectsEmptyETag(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPut {
			// Handlers run on the server's goroutine; t.Fatalf (FailNow) is
			// only legal on the test goroutine, so report with Errorf and
			// fail the request instead.
			t.Errorf("unexpected method %q", r.Method)
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		// Respond 200 OK without setting an ETag to trigger the failure path.
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()
	artifactPath := filepath.Join(t.TempDir(), "artifact.bin")
	if err := os.WriteFile(artifactPath, []byte("payload"), 0o644); err != nil {
		t.Fatalf("write artifact: %v", err)
	}
	_, err := uploadSnapshotArtifact(context.Background(), artifactPath, []contracthost.SnapshotUploadPart{{
		PartNumber:  1,
		OffsetBytes: 0,
		SizeBytes:   int64(len("payload")),
		UploadURL:   server.URL,
	}})
	if err == nil || !strings.Contains(err.Error(), "empty etag") {
		t.Fatalf("uploadSnapshotArtifact error = %v, want empty etag failure", err)
	}
}
// TestDownloadDurableSnapshotArtifactsRejectsSHA256Mismatch verifies that a
// downloaded artifact whose digest does not match the declared SHA256Hex is
// rejected and the corrupt local copy is deleted.
func TestDownloadDurableSnapshotArtifactsRejectsSHA256Mismatch(t *testing.T) {
// Serve a fixed payload whose real digest cannot match the all-zero hash below.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("payload"))
}))
defer server.Close()
root := t.TempDir()
_, err := downloadDurableSnapshotArtifacts(context.Background(), root, []contracthost.SnapshotArtifact{{
ID: "memory",
Kind: contracthost.SnapshotArtifactKindMemory,
Name: "memory.bin",
DownloadURL: server.URL,
SHA256Hex: strings.Repeat("0", 64), // deliberately wrong digest
}})
if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") {
t.Fatalf("downloadDurableSnapshotArtifacts error = %v, want sha256 mismatch", err)
}
// The mismatching download must not be left on disk.
if _, statErr := os.Stat(filepath.Join(root, "memory.bin")); !os.IsNotExist(statErr) {
t.Fatalf("corrupt artifact should be removed, stat err = %v", statErr)
}
}

View file

@ -52,19 +52,27 @@ func (d *Daemon) GetStorageReport(ctx context.Context) (*contracthost.GetStorage
}
}
machineUsage := make([]contracthost.MachineStorageUsage, 0, len(volumes))
machineUsageByID := make(map[contracthost.MachineID]contracthost.MachineStorageUsage)
for _, volume := range volumes {
if volume.AttachedMachineID == nil || volume.Kind != contracthost.VolumeKindSystem {
if volume.AttachedMachineID == nil {
continue
}
bytes, err := fileSize(volume.Path)
if err != nil {
return nil, err
}
machineUsage = append(machineUsage, contracthost.MachineStorageUsage{
MachineID: *volume.AttachedMachineID,
SystemBytes: bytes,
})
usage := machineUsageByID[*volume.AttachedMachineID]
usage.MachineID = *volume.AttachedMachineID
if volume.Kind == contracthost.VolumeKindSystem {
usage.SystemBytes += bytes
} else {
usage.UserBytes += bytes
}
machineUsageByID[*volume.AttachedMachineID] = usage
}
machineUsage := make([]contracthost.MachineStorageUsage, 0, len(machineUsageByID))
for _, usage := range machineUsageByID {
machineUsage = append(machineUsage, usage)
}
snapshotUsage := make([]contracthost.SnapshotStorageUsage, 0, len(snapshots))

View file

@ -27,10 +27,12 @@ type bootSourceRequest struct {
}
type driveRequest struct {
DriveID string `json:"drive_id"`
IsReadOnly bool `json:"is_read_only"`
IsRootDevice bool `json:"is_root_device"`
PathOnHost string `json:"path_on_host"`
DriveID string `json:"drive_id"`
IsReadOnly bool `json:"is_read_only"`
IsRootDevice bool `json:"is_root_device"`
PathOnHost string `json:"path_on_host"`
CacheType DriveCacheType `json:"cache_type,omitempty"`
IOEngine DriveIOEngine `json:"io_engine,omitempty"`
}
type entropyRequest struct{}
@ -58,6 +60,13 @@ type networkInterfaceRequest struct {
IfaceID string `json:"iface_id"`
}
// mmdsConfigRequest is the JSON body for Firecracker's PUT /mmds/config
// endpoint, which enables the microVM metadata service.
type mmdsConfigRequest struct {
IPv4Address string `json:"ipv4_address,omitempty"` // MMDS endpoint address inside the guest
NetworkInterfaces []string `json:"network_interfaces"` // iface IDs allowed to reach MMDS
Version MMDSVersion `json:"version,omitempty"` // MMDS protocol version (V1/V2)
IMDSCompat bool `json:"imds_compat,omitempty"` // EC2 IMDS compatibility mode, per Firecracker API
}
// serialRequest is the JSON body for Firecracker's PUT /serial endpoint.
type serialRequest struct {
SerialOutPath string `json:"serial_out_path"` // path that receives serial console output
}
@ -127,6 +136,24 @@ func (c *apiClient) PutNetworkInterface(ctx context.Context, network NetworkAllo
return c.do(ctx, http.MethodPut, endpoint, body, nil, http.StatusNoContent)
}
// PutMMDSConfig configures the microVM metadata service endpoint
// (PUT /mmds/config). The interface list is copied so the caller's slice is
// never shared with the request.
func (c *apiClient) PutMMDSConfig(ctx context.Context, spec MMDSSpec) error {
	payload := mmdsConfigRequest{
		IPv4Address:       strings.TrimSpace(spec.IPv4Address),
		NetworkInterfaces: append([]string(nil), spec.NetworkInterfaces...),
		Version:           spec.Version,
		IMDSCompat:        spec.IMDSCompat,
	}
	return c.do(ctx, http.MethodPut, "/mmds/config", payload, nil, http.StatusNoContent)
}
// PutMMDS replaces the entire MMDS data store with data (PUT /mmds).
func (c *apiClient) PutMMDS(ctx context.Context, data any) error {
return c.do(ctx, http.MethodPut, "/mmds", data, nil, http.StatusNoContent)
}
// PatchMMDS merges data into the existing MMDS data store (PATCH /mmds).
func (c *apiClient) PatchMMDS(ctx context.Context, data any) error {
return c.do(ctx, http.MethodPatch, "/mmds", data, nil, http.StatusNoContent)
}
func (c *apiClient) PutSerial(ctx context.Context, serialOutPath string) error {
return c.do(
ctx,

View file

@ -83,6 +83,91 @@ func TestConfigureMachineEnablesEntropyAndSerialBeforeStart(t *testing.T) {
}
}
// TestConfigureMachineConfiguresMMDSBeforeStart verifies that when a machine
// spec carries an MMDS section, configureMachine issues PUT /mmds/config and
// PUT /mmds after the network interface and before entropy/serial/start, and
// that the root drive and MMDS config bodies are serialized as expected.
func TestConfigureMachineConfiguresMMDSBeforeStart(t *testing.T) {
	var requests []capturedRequest
	socketPath, shutdown := startUnixSocketServer(t, func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err != nil {
			// The handler runs on the server's goroutine; t.Fatalf (FailNow)
			// is only legal on the test goroutine, so record the failure with
			// Errorf and fail the request instead.
			t.Errorf("read request body: %v", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		requests = append(requests, capturedRequest{
			Method: r.Method,
			Path:   r.URL.Path,
			Body:   string(body),
		})
		w.WriteHeader(http.StatusNoContent)
	})
	defer shutdown()
	client := newAPIClient(socketPath)
	spec := MachineSpec{
		ID:              "vm-2",
		VCPUs:           1,
		MemoryMiB:       512,
		KernelImagePath: "/kernel",
		RootFSPath:      "/rootfs",
		RootDrive: DriveSpec{
			ID:        "root_drive",
			Path:      "/rootfs",
			CacheType: DriveCacheTypeUnsafe,
			IOEngine:  DriveIOEngineSync,
		},
		MMDS: &MMDSSpec{
			NetworkInterfaces: []string{"net0"},
			Version:           MMDSVersionV2,
			IPv4Address:       "169.254.169.254",
			Data: map[string]any{
				"latest": map[string]any{
					"meta-data": map[string]any{
						"microagent": map[string]any{"hostname": "vm-2"},
					},
				},
			},
		},
	}
	paths := machinePaths{JailedSerialLogPath: "/logs/serial.log"}
	network := NetworkAllocation{
		InterfaceID: defaultInterfaceID,
		TapName:     "fctap0",
		GuestMAC:    "06:00:ac:10:00:02",
	}
	if err := configureMachine(context.Background(), client, paths, spec, network); err != nil {
		t.Fatalf("configure machine: %v", err)
	}
	gotPaths := make([]string, 0, len(requests))
	for _, request := range requests {
		gotPaths = append(gotPaths, request.Path)
	}
	// MMDS endpoints must come after the network interface and before boot.
	wantPaths := []string{
		"/machine-config",
		"/boot-source",
		"/drives/root_drive",
		"/network-interfaces/net0",
		"/mmds/config",
		"/mmds",
		"/entropy",
		"/serial",
		"/actions",
	}
	if len(gotPaths) != len(wantPaths) {
		t.Fatalf("request count mismatch: got %d want %d (%v)", len(gotPaths), len(wantPaths), gotPaths)
	}
	for i := range wantPaths {
		if gotPaths[i] != wantPaths[i] {
			t.Fatalf("request %d mismatch: got %q want %q", i, gotPaths[i], wantPaths[i])
		}
	}
	if requests[2].Body != "{\"drive_id\":\"root_drive\",\"is_read_only\":false,\"is_root_device\":true,\"path_on_host\":\"/rootfs\",\"cache_type\":\"Unsafe\",\"io_engine\":\"Sync\"}" {
		t.Fatalf("root drive body mismatch: got %q", requests[2].Body)
	}
	if requests[4].Body != "{\"ipv4_address\":\"169.254.169.254\",\"network_interfaces\":[\"net0\"],\"version\":\"V2\"}" {
		t.Fatalf("mmds config body mismatch: got %q", requests[4].Body)
	}
}
func startUnixSocketServer(t *testing.T, handler http.HandlerFunc) (string, func()) {
t.Helper()

View file

@ -39,6 +39,16 @@ func configureMachine(ctx context.Context, client *apiClient, paths machinePaths
if err := client.PutNetworkInterface(ctx, network); err != nil {
return fmt.Errorf("put network interface: %w", err)
}
if spec.MMDS != nil {
if err := client.PutMMDSConfig(ctx, *spec.MMDS); err != nil {
return fmt.Errorf("put mmds config: %w", err)
}
if spec.MMDS.Data != nil {
if err := client.PutMMDS(ctx, spec.MMDS.Data); err != nil {
return fmt.Errorf("put mmds payload: %w", err)
}
}
}
if err := client.PutEntropy(ctx); err != nil {
return fmt.Errorf("put entropy device: %w", err)
}
@ -97,12 +107,14 @@ func stageMachineFiles(spec MachineSpec, paths machinePaths) (MachineSpec, error
rootFSPath, err := stagedFileName(spec.RootFSPath)
if err != nil {
return MachineSpec{}, fmt.Errorf("rootfs path: %w", err)
return MachineSpec{}, fmt.Errorf("root drive path: %w", err)
}
if err := linkMachineFile(spec.RootFSPath, filepath.Join(paths.ChrootRootDir, rootFSPath)); err != nil {
return MachineSpec{}, fmt.Errorf("link rootfs into jail: %w", err)
return MachineSpec{}, fmt.Errorf("link root drive into jail: %w", err)
}
staged.RootFSPath = rootFSPath
staged.RootDrive = spec.rootDrive()
staged.RootDrive.Path = rootFSPath
staged.Drives = make([]DriveSpec, len(spec.Drives))
for i, drive := range spec.Drives {
@ -174,6 +186,8 @@ func additionalDriveRequests(spec MachineSpec) []driveRequest {
IsReadOnly: drive.ReadOnly,
IsRootDevice: false,
PathOnHost: drive.Path,
CacheType: drive.CacheType,
IOEngine: drive.IOEngine,
})
}
return requests
@ -249,11 +263,14 @@ func linkMachineFile(source string, target string) error {
}
// rootDriveRequest builds the PUT /drives body for the machine's root block
// device, honoring the effective root drive spec (explicit RootDrive fields
// with defaults filled in by rootDrive()).
//
// NOTE: the original span contained both the pre- and post-change field
// assignments from an unresolved diff (duplicate DriveID/PathOnHost keys),
// which does not compile; this is the post-change version.
func rootDriveRequest(spec MachineSpec) driveRequest {
	root := spec.rootDrive()
	return driveRequest{
		DriveID:      root.ID,
		IsReadOnly:   root.ReadOnly,
		IsRootDevice: true,
		PathOnHost:   root.Path,
		CacheType:    root.CacheType,
		IOEngine:     root.IOEngine,
	}
}

View file

@ -16,16 +16,50 @@ type MachineSpec struct {
MemoryMiB int64
KernelImagePath string
RootFSPath string
RootDrive DriveSpec
KernelArgs string
Drives []DriveSpec
MMDS *MMDSSpec
Vsock *VsockSpec
}
// DriveSpec describes an additional guest block device.
type DriveSpec struct {
ID string
Path string
ReadOnly bool
ID string
Path string
ReadOnly bool
CacheType DriveCacheType
IOEngine DriveIOEngine
}
// DriveCacheType selects the block device cache strategy passed to
// Firecracker's drive configuration ("cache_type").
type DriveCacheType string
const (
DriveCacheTypeUnsafe DriveCacheType = "Unsafe" // Firecracker "Unsafe" cache mode
DriveCacheTypeWriteback DriveCacheType = "Writeback" // Firecracker "Writeback" cache mode
)
// DriveIOEngine selects the block device I/O engine passed to Firecracker's
// drive configuration ("io_engine").
type DriveIOEngine string
const (
DriveIOEngineSync DriveIOEngine = "Sync" // synchronous I/O engine
DriveIOEngineAsync DriveIOEngine = "Async" // asynchronous I/O engine
)
// MMDSVersion selects the microVM metadata service protocol version
// ("version" in PUT /mmds/config).
type MMDSVersion string
const (
MMDSVersionV1 MMDSVersion = "V1" // token-less MMDS protocol
MMDSVersionV2 MMDSVersion = "V2" // session-token MMDS protocol
)
// MMDSSpec describes the MMDS network configuration and initial payload.
type MMDSSpec struct {
NetworkInterfaces []string // iface IDs allowed to reach MMDS; at least one required
Version MMDSVersion // protocol version; empty lets Firecracker default
IPv4Address string // MMDS endpoint address inside the guest; empty lets Firecracker default
IMDSCompat bool // EC2 IMDS compatibility mode
Data any // optional initial data store contents, sent via PUT /mmds
}
// VsockSpec describes a single host-guest vsock device.
@ -49,17 +83,22 @@ func (s MachineSpec) Validate() error {
if strings.TrimSpace(s.KernelImagePath) == "" {
return fmt.Errorf("machine kernel image path is required")
}
if strings.TrimSpace(s.RootFSPath) == "" {
return fmt.Errorf("machine rootfs path is required")
}
if filepath.Base(strings.TrimSpace(string(s.ID))) != strings.TrimSpace(string(s.ID)) {
return fmt.Errorf("machine id %q must not contain path separators", s.ID)
}
if err := s.rootDrive().Validate(); err != nil {
return fmt.Errorf("root drive: %w", err)
}
for i, drive := range s.Drives {
if err := drive.Validate(); err != nil {
return fmt.Errorf("drive %d: %w", i, err)
}
}
if s.MMDS != nil {
if err := s.MMDS.Validate(); err != nil {
return fmt.Errorf("mmds: %w", err)
}
}
if s.Vsock != nil {
if err := s.Vsock.Validate(); err != nil {
return fmt.Errorf("vsock: %w", err)
@ -70,11 +109,39 @@ func (s MachineSpec) Validate() error {
// Validate reports whether the drive specification is usable.
func (d DriveSpec) Validate() error {
if strings.TrimSpace(d.Path) == "" {
return fmt.Errorf("drive path is required")
}
if strings.TrimSpace(d.ID) == "" {
return fmt.Errorf("drive id is required")
}
if strings.TrimSpace(d.Path) == "" {
return fmt.Errorf("drive path is required")
switch d.CacheType {
case "", DriveCacheTypeUnsafe, DriveCacheTypeWriteback:
default:
return fmt.Errorf("unsupported drive cache type %q", d.CacheType)
}
switch d.IOEngine {
case "", DriveIOEngineSync, DriveIOEngineAsync:
default:
return fmt.Errorf("unsupported drive io engine %q", d.IOEngine)
}
return nil
}
// Validate reports whether the MMDS configuration is usable: at least one
// non-blank network interface ID is required, and the version must be empty
// (default) or one of the supported constants.
func (m MMDSSpec) Validate() error {
	if len(m.NetworkInterfaces) == 0 {
		return fmt.Errorf("mmds network interfaces are required")
	}
	versionOK := m.Version == "" || m.Version == MMDSVersionV1 || m.Version == MMDSVersionV2
	if !versionOK {
		return fmt.Errorf("unsupported mmds version %q", m.Version)
	}
	for idx, name := range m.NetworkInterfaces {
		if strings.TrimSpace(name) == "" {
			return fmt.Errorf("mmds network_interfaces[%d] is required", idx)
		}
	}
	return nil
}
@ -92,3 +159,14 @@ func (v VsockSpec) Validate() error {
}
return nil
}
// rootDrive returns the effective root drive spec, filling in the default
// root drive ID and the legacy RootFSPath when the explicit fields are blank.
func (s MachineSpec) rootDrive() DriveSpec {
	drive := s.RootDrive
	if strings.TrimSpace(drive.ID) == "" {
		drive.ID = defaultRootDriveID
	}
	if strings.TrimSpace(drive.Path) == "" {
		drive.Path = s.RootFSPath
	}
	return drive
}

View file

@ -20,6 +20,7 @@ type Service interface {
Health(context.Context) (*contracthost.HealthResponse, error)
GetStorageReport(context.Context) (*contracthost.GetStorageReportResponse, error)
CreateSnapshot(context.Context, contracthost.MachineID, contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error)
UploadSnapshot(context.Context, contracthost.SnapshotID, contracthost.UploadSnapshotRequest) (*contracthost.UploadSnapshotResponse, error)
ListSnapshots(context.Context, contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error)
GetSnapshot(context.Context, contracthost.SnapshotID) (*contracthost.GetSnapshotResponse, error)
DeleteSnapshotByID(context.Context, contracthost.SnapshotID) error
@ -278,6 +279,25 @@ func (h *Handler) handleSnapshot(w http.ResponseWriter, r *http.Request) {
return
}
if len(parts) == 2 && parts[1] == "upload" {
if r.Method != http.MethodPost {
writeMethodNotAllowed(w)
return
}
var req contracthost.UploadSnapshotRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, err)
return
}
response, err := h.service.UploadSnapshot(r.Context(), snapshotID, req)
if err != nil {
writeError(w, statusForError(err), err)
return
}
writeJSON(w, http.StatusOK, response)
return
}
writeError(w, http.StatusNotFound, fmt.Errorf("route not found"))
}

View file

@ -29,6 +29,7 @@ type ArtifactRecord struct {
type MachineRecord struct {
ID contracthost.MachineID
Artifact contracthost.ArtifactRef
GuestConfig *contracthost.GuestConfig
SystemVolumeID contracthost.VolumeID
UserVolumeIDs []contracthost.VolumeID
RuntimeHost string
@ -71,11 +72,21 @@ type SnapshotRecord struct {
MemFilePath string
StateFilePath string
DiskPaths []string
Artifacts []SnapshotArtifactRecord
SourceRuntimeHost string
SourceTapDevice string
CreatedAt time.Time
}
// SnapshotArtifactRecord describes one file produced by a snapshot (memory,
// vmstate, or a disk image) together with its local path and integrity data.
type SnapshotArtifactRecord struct {
ID string // stable artifact identifier, e.g. "memory", "vmstate", "disk-<stem>"
Kind contracthost.SnapshotArtifactKind // artifact category
Name string // file name of the artifact
LocalPath string // host-local path of the backing file
SizeBytes int64 // file size at record-creation time
SHA256Hex string // lowercase hex SHA-256 of the file contents
}
type PublishedPortRecord struct {
ID contracthost.PublishedPortID
MachineID contracthost.MachineID