mirror of
https://github.com/getcompanion-ai/computer-host.git
synced 2026-04-15 01:00:27 +00:00
feat: firecracker mmds identity
This commit is contained in:
parent
500354cd9b
commit
3eb610b703
23 changed files with 1813 additions and 263 deletions
|
|
@ -17,6 +17,7 @@ type Machine struct {
|
|||
}
|
||||
|
||||
type GuestConfig struct {
|
||||
Hostname string `json:"hostname,omitempty"`
|
||||
AuthorizedKeys []string `json:"authorized_keys,omitempty"`
|
||||
TrustedUserCAKeys []string `json:"trusted_user_ca_keys,omitempty"`
|
||||
LoginWebhook *GuestLoginWebhook `json:"login_webhook,omitempty"`
|
||||
|
|
|
|||
|
|
@ -4,10 +4,31 @@ import "time"
|
|||
|
||||
type SnapshotID string
|
||||
|
||||
type SnapshotArtifactKind string
|
||||
|
||||
const (
|
||||
SnapshotArtifactKindMemory SnapshotArtifactKind = "memory"
|
||||
SnapshotArtifactKindVMState SnapshotArtifactKind = "vmstate"
|
||||
SnapshotArtifactKindDisk SnapshotArtifactKind = "disk"
|
||||
SnapshotArtifactKindManifest SnapshotArtifactKind = "manifest"
|
||||
)
|
||||
|
||||
type Snapshot struct {
|
||||
ID SnapshotID `json:"id"`
|
||||
MachineID MachineID `json:"machine_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID SnapshotID `json:"id"`
|
||||
MachineID MachineID `json:"machine_id"`
|
||||
SourceRuntimeHost string `json:"source_runtime_host,omitempty"`
|
||||
SourceTapDevice string `json:"source_tap_device,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type SnapshotArtifact struct {
|
||||
ID string `json:"id"`
|
||||
Kind SnapshotArtifactKind `json:"kind"`
|
||||
Name string `json:"name"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
SHA256Hex string `json:"sha256_hex,omitempty"`
|
||||
ObjectKey string `json:"object_key,omitempty"`
|
||||
DownloadURL string `json:"download_url,omitempty"`
|
||||
}
|
||||
|
||||
type CreateSnapshotRequest struct {
|
||||
|
|
@ -15,7 +36,8 @@ type CreateSnapshotRequest struct {
|
|||
}
|
||||
|
||||
type CreateSnapshotResponse struct {
|
||||
Snapshot Snapshot `json:"snapshot"`
|
||||
Snapshot Snapshot `json:"snapshot"`
|
||||
Artifacts []SnapshotArtifact `json:"artifacts,omitempty"`
|
||||
}
|
||||
|
||||
type GetSnapshotResponse struct {
|
||||
|
|
@ -26,8 +48,52 @@ type ListSnapshotsResponse struct {
|
|||
Snapshots []Snapshot `json:"snapshots"`
|
||||
}
|
||||
|
||||
type SnapshotUploadPart struct {
|
||||
PartNumber int32 `json:"part_number"`
|
||||
OffsetBytes int64 `json:"offset_bytes"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
UploadURL string `json:"upload_url"`
|
||||
}
|
||||
|
||||
type SnapshotArtifactUploadSession struct {
|
||||
ArtifactID string `json:"artifact_id"`
|
||||
ObjectKey string `json:"object_key"`
|
||||
UploadID string `json:"upload_id"`
|
||||
Parts []SnapshotUploadPart `json:"parts"`
|
||||
}
|
||||
|
||||
type UploadSnapshotRequest struct {
|
||||
Artifacts []SnapshotArtifactUploadSession `json:"artifacts"`
|
||||
}
|
||||
|
||||
type UploadedSnapshotPart struct {
|
||||
PartNumber int32 `json:"part_number"`
|
||||
ETag string `json:"etag"`
|
||||
}
|
||||
|
||||
type UploadedSnapshotArtifact struct {
|
||||
ArtifactID string `json:"artifact_id"`
|
||||
CompletedParts []UploadedSnapshotPart `json:"completed_parts"`
|
||||
}
|
||||
|
||||
type UploadSnapshotResponse struct {
|
||||
Artifacts []UploadedSnapshotArtifact `json:"artifacts"`
|
||||
}
|
||||
|
||||
type RestoreSnapshotRequest struct {
|
||||
MachineID MachineID `json:"machine_id"`
|
||||
MachineID MachineID `json:"machine_id"`
|
||||
Artifact ArtifactRef `json:"artifact"`
|
||||
Snapshot DurableSnapshotSpec `json:"snapshot"`
|
||||
GuestConfig *GuestConfig `json:"guest_config,omitempty"`
|
||||
}
|
||||
|
||||
type DurableSnapshotSpec struct {
|
||||
SnapshotID SnapshotID `json:"snapshot_id"`
|
||||
MachineID MachineID `json:"machine_id"`
|
||||
ImageID string `json:"image_id"`
|
||||
SourceRuntimeHost string `json:"source_runtime_host,omitempty"`
|
||||
SourceTapDevice string `json:"source_tap_device,omitempty"`
|
||||
Artifacts []SnapshotArtifact `json:"artifacts"`
|
||||
}
|
||||
|
||||
type RestoreSnapshotResponse struct {
|
||||
|
|
|
|||
|
|
@ -9,9 +9,10 @@ type VolumeID string
|
|||
type VolumeKind string
|
||||
|
||||
const (
|
||||
MachinePhaseRunning MachinePhase = "running"
|
||||
MachinePhaseStopped MachinePhase = "stopped"
|
||||
MachinePhaseFailed MachinePhase = "failed"
|
||||
MachinePhaseStarting MachinePhase = "starting"
|
||||
MachinePhaseRunning MachinePhase = "running"
|
||||
MachinePhaseStopped MachinePhase = "stopped"
|
||||
MachinePhaseFailed MachinePhase = "failed"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
if err := os.MkdirAll(filepath.Dir(systemVolumePath), 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create system volume dir for %q: %w", req.MachineID, err)
|
||||
}
|
||||
if err := cloneFile(artifact.RootFSPath, systemVolumePath); err != nil {
|
||||
if err := cowCopyFile(artifact.RootFSPath, systemVolumePath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
removeSystemVolumeOnFailure := true
|
||||
|
|
@ -77,14 +77,8 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
_ = os.Remove(systemVolumePath)
|
||||
_ = os.RemoveAll(filepath.Dir(systemVolumePath))
|
||||
}()
|
||||
if err := injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := injectMachineIdentity(ctx, systemVolumePath, req.MachineID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath)
|
||||
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -98,17 +92,6 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
return nil, err
|
||||
}
|
||||
|
||||
ports := defaultMachinePorts()
|
||||
if err := waitForGuestReady(ctx, state.RuntimeHost, ports); err != nil {
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
|
||||
if err != nil {
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
systemVolumeRecord := model.VolumeRecord{
|
||||
ID: d.systemVolumeID(req.MachineID),
|
||||
|
|
@ -143,44 +126,20 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
}
|
||||
|
||||
record := model.MachineRecord{
|
||||
ID: req.MachineID,
|
||||
Artifact: req.Artifact,
|
||||
SystemVolumeID: systemVolumeRecord.ID,
|
||||
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
|
||||
RuntimeHost: state.RuntimeHost,
|
||||
TapDevice: state.TapName,
|
||||
Ports: ports,
|
||||
GuestSSHPublicKey: guestSSHPublicKey,
|
||||
Phase: contracthost.MachinePhaseRunning,
|
||||
PID: state.PID,
|
||||
SocketPath: state.SocketPath,
|
||||
CreatedAt: now,
|
||||
StartedAt: state.StartedAt,
|
||||
ID: req.MachineID,
|
||||
Artifact: req.Artifact,
|
||||
GuestConfig: cloneGuestConfig(guestConfig),
|
||||
SystemVolumeID: systemVolumeRecord.ID,
|
||||
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
|
||||
RuntimeHost: state.RuntimeHost,
|
||||
TapDevice: state.TapName,
|
||||
Ports: defaultMachinePorts(),
|
||||
Phase: contracthost.MachinePhaseStarting,
|
||||
PID: state.PID,
|
||||
SocketPath: state.SocketPath,
|
||||
CreatedAt: now,
|
||||
StartedAt: state.StartedAt,
|
||||
}
|
||||
d.relayAllocMu.Lock()
|
||||
sshRelayPort, err := d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameSSH, record.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort)
|
||||
var vncRelayPort uint16
|
||||
if err == nil {
|
||||
vncRelayPort, err = d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameVNC, record.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort)
|
||||
}
|
||||
d.relayAllocMu.Unlock()
|
||||
if err != nil {
|
||||
d.stopMachineRelays(record.ID)
|
||||
for _, volume := range userVolumes {
|
||||
volume.AttachedMachineID = nil
|
||||
_ = d.store.UpdateVolume(context.Background(), volume)
|
||||
}
|
||||
_ = d.store.DeleteVolume(context.Background(), systemVolumeRecord.ID)
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
|
||||
startedRelays := true
|
||||
defer func() {
|
||||
if startedRelays {
|
||||
d.stopMachineRelays(record.ID)
|
||||
}
|
||||
}()
|
||||
if err := d.store.CreateMachine(ctx, record); err != nil {
|
||||
for _, volume := range userVolumes {
|
||||
volume.AttachedMachineID = nil
|
||||
|
|
@ -192,12 +151,11 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
}
|
||||
|
||||
removeSystemVolumeOnFailure = false
|
||||
startedRelays = false
|
||||
clearOperation = true
|
||||
return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string) (firecracker.MachineSpec, error) {
|
||||
func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string, guestConfig *contracthost.GuestConfig) (firecracker.MachineSpec, error) {
|
||||
drives := make([]firecracker.DriveSpec, 0, len(userVolumes))
|
||||
for i, volume := range userVolumes {
|
||||
drives = append(drives, firecracker.DriveSpec{
|
||||
|
|
@ -207,14 +165,25 @@ func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *mo
|
|||
})
|
||||
}
|
||||
|
||||
mmds, err := d.guestMetadataSpec(machineID, guestConfig)
|
||||
if err != nil {
|
||||
return firecracker.MachineSpec{}, err
|
||||
}
|
||||
spec := firecracker.MachineSpec{
|
||||
ID: firecracker.MachineID(machineID),
|
||||
VCPUs: defaultGuestVCPUs,
|
||||
MemoryMiB: defaultGuestMemoryMiB,
|
||||
KernelImagePath: artifact.KernelImagePath,
|
||||
RootFSPath: systemVolumePath,
|
||||
KernelArgs: defaultGuestKernelArgs,
|
||||
Drives: drives,
|
||||
RootDrive: firecracker.DriveSpec{
|
||||
ID: "root_drive",
|
||||
Path: systemVolumePath,
|
||||
CacheType: firecracker.DriveCacheTypeUnsafe,
|
||||
IOEngine: firecracker.DriveIOEngineSync,
|
||||
},
|
||||
KernelArgs: defaultGuestKernelArgs,
|
||||
Drives: drives,
|
||||
MMDS: mmds,
|
||||
}
|
||||
if err := spec.Validate(); err != nil {
|
||||
return firecracker.MachineSpec{}, err
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ type Daemon struct {
|
|||
|
||||
reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID) error
|
||||
readGuestSSHPublicKey func(context.Context, string) (string, error)
|
||||
syncGuestFilesystem func(context.Context, string) error
|
||||
|
||||
locksMu sync.Mutex
|
||||
machineLocks map[contracthost.MachineID]*sync.Mutex
|
||||
|
|
@ -84,6 +85,7 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
|
|||
}
|
||||
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
|
||||
daemon.readGuestSSHPublicKey = readGuestSSHPublicKey
|
||||
daemon.syncGuestFilesystem = daemon.syncGuestFilesystemOverSSH
|
||||
if err := daemon.ensureBackendSSHKeyPair(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -141,7 +141,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
|||
t.Fatalf("create machine: %v", err)
|
||||
}
|
||||
|
||||
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
||||
t.Fatalf("machine phase mismatch: got %q", response.Machine.Phase)
|
||||
}
|
||||
if response.Machine.RuntimeHost != "127.0.0.1" {
|
||||
|
|
@ -169,29 +169,25 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("read backend ssh public key: %v", err)
|
||||
}
|
||||
authorizedKeys, err := readExt4File(runtime.lastSpec.RootFSPath, "/etc/microagent/authorized_keys")
|
||||
if err != nil {
|
||||
t.Fatalf("read injected authorized_keys: %v", err)
|
||||
if runtime.lastSpec.MMDS == nil {
|
||||
t.Fatalf("expected MMDS configuration on machine spec")
|
||||
}
|
||||
if runtime.lastSpec.MMDS.Version != firecracker.MMDSVersionV2 {
|
||||
t.Fatalf("mmds version mismatch: got %q", runtime.lastSpec.MMDS.Version)
|
||||
}
|
||||
payload, ok := runtime.lastSpec.MMDS.Data.(guestMetadataEnvelope)
|
||||
if !ok {
|
||||
t.Fatalf("mmds payload type mismatch: got %T", runtime.lastSpec.MMDS.Data)
|
||||
}
|
||||
if payload.Latest.MetaData.Hostname != "vm-1" {
|
||||
t.Fatalf("mmds hostname mismatch: got %q", payload.Latest.MetaData.Hostname)
|
||||
}
|
||||
authorizedKeys := strings.Join(payload.Latest.MetaData.AuthorizedKeys, "\n")
|
||||
if !strings.Contains(authorizedKeys, strings.TrimSpace(string(hostAuthorizedKeyBytes))) {
|
||||
t.Fatalf("authorized_keys missing backend ssh key: %q", authorizedKeys)
|
||||
t.Fatalf("mmds authorized_keys missing backend ssh key: %q", authorizedKeys)
|
||||
}
|
||||
if !strings.Contains(authorizedKeys, "daemon-test") {
|
||||
t.Fatalf("authorized_keys missing request override key: %q", authorizedKeys)
|
||||
}
|
||||
machineName, err := readExt4File(runtime.lastSpec.RootFSPath, "/etc/microagent/machine-name")
|
||||
if err != nil {
|
||||
t.Fatalf("read injected machine-name: %v", err)
|
||||
}
|
||||
if machineName != "vm-1\n" {
|
||||
t.Fatalf("machine-name mismatch: got %q want %q", machineName, "vm-1\n")
|
||||
}
|
||||
hosts, err := readExt4File(runtime.lastSpec.RootFSPath, "/etc/hosts")
|
||||
if err != nil {
|
||||
t.Fatalf("read injected hosts: %v", err)
|
||||
}
|
||||
if !strings.Contains(hosts, "127.0.1.1 vm-1") {
|
||||
t.Fatalf("hosts missing machine identity: %q", hosts)
|
||||
t.Fatalf("mmds authorized_keys missing request override key: %q", authorizedKeys)
|
||||
}
|
||||
|
||||
artifact, err := fileStore.GetArtifact(context.Background(), response.Machine.Artifact)
|
||||
|
|
@ -214,6 +210,12 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
|||
if machine.SystemVolumeID != "vm-1-system" {
|
||||
t.Fatalf("system volume mismatch: got %q", machine.SystemVolumeID)
|
||||
}
|
||||
if machine.Phase != contracthost.MachinePhaseStarting {
|
||||
t.Fatalf("stored machine phase mismatch: got %q", machine.Phase)
|
||||
}
|
||||
if machine.GuestConfig == nil || len(machine.GuestConfig.AuthorizedKeys) == 0 {
|
||||
t.Fatalf("stored guest config missing authorized keys: %#v", machine.GuestConfig)
|
||||
}
|
||||
|
||||
operations, err := fileStore.ListOperations(context.Background())
|
||||
if err != nil {
|
||||
|
|
@ -224,6 +226,65 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestStopMachineSyncsGuestFilesystemBeforeDelete(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("create file store: %v", err)
|
||||
}
|
||||
|
||||
runtime := &fakeRuntime{}
|
||||
hostDaemon, err := New(cfg, fileStore, runtime)
|
||||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
|
||||
var syncedHost string
|
||||
hostDaemon.syncGuestFilesystem = func(_ context.Context, runtimeHost string) error {
|
||||
syncedHost = runtimeHost
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
|
||||
ID: "vm-stop",
|
||||
SystemVolumeID: "vm-stop-system",
|
||||
RuntimeHost: "172.16.0.2",
|
||||
TapDevice: "fctap-stop",
|
||||
Phase: contracthost.MachinePhaseRunning,
|
||||
PID: 1234,
|
||||
SocketPath: filepath.Join(root, "runtime", "vm-stop.sock"),
|
||||
Ports: defaultMachinePorts(),
|
||||
CreatedAt: now,
|
||||
StartedAt: &now,
|
||||
}); err != nil {
|
||||
t.Fatalf("create machine: %v", err)
|
||||
}
|
||||
|
||||
if err := hostDaemon.StopMachine(context.Background(), "vm-stop"); err != nil {
|
||||
t.Fatalf("stop machine: %v", err)
|
||||
}
|
||||
|
||||
if syncedHost != "172.16.0.2" {
|
||||
t.Fatalf("sync host mismatch: got %q want %q", syncedHost, "172.16.0.2")
|
||||
}
|
||||
if len(runtime.deleteCalls) != 1 {
|
||||
t.Fatalf("runtime delete call count mismatch: got %d want 1", len(runtime.deleteCalls))
|
||||
}
|
||||
|
||||
stopped, err := fileStore.GetMachine(context.Background(), "vm-stop")
|
||||
if err != nil {
|
||||
t.Fatalf("get stopped machine: %v", err)
|
||||
}
|
||||
if stopped.Phase != contracthost.MachinePhaseStopped {
|
||||
t.Fatalf("machine phase mismatch: got %q want %q", stopped.Phase, contracthost.MachinePhaseStopped)
|
||||
}
|
||||
if stopped.RuntimeHost != "" {
|
||||
t.Fatalf("runtime host should be cleared after stop, got %q", stopped.RuntimeHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEnsuresBackendSSHKeyPair(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
|
|
@ -249,7 +310,7 @@ func TestNewEnsuresBackendSSHKeyPair(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
|
||||
func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||
|
|
@ -257,7 +318,23 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
|
|||
t.Fatalf("create file store: %v", err)
|
||||
}
|
||||
|
||||
runtime := &fakeRuntime{}
|
||||
sshListener := listenTestPort(t, int(defaultSSHPort))
|
||||
defer func() { _ = sshListener.Close() }()
|
||||
vncListener := listenTestPort(t, int(defaultVNCPort))
|
||||
defer func() { _ = vncListener.Close() }()
|
||||
|
||||
startedAt := time.Unix(1700000099, 0).UTC()
|
||||
runtime := &fakeRuntime{
|
||||
bootState: firecracker.MachineState{
|
||||
ID: "restored",
|
||||
Phase: firecracker.PhaseRunning,
|
||||
PID: 1234,
|
||||
RuntimeHost: "127.0.0.1",
|
||||
SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "restored", "root", "run", "firecracker.sock"),
|
||||
TapName: "fctap0",
|
||||
StartedAt: &startedAt,
|
||||
},
|
||||
}
|
||||
hostDaemon, err := New(cfg, fileStore, runtime)
|
||||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
|
|
@ -281,32 +358,13 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
|
|||
t.Fatalf("put artifact: %v", err)
|
||||
}
|
||||
|
||||
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
|
||||
ID: "source",
|
||||
Artifact: artifactRef,
|
||||
SystemVolumeID: "source-system",
|
||||
RuntimeHost: "172.16.0.2",
|
||||
TapDevice: "fctap0",
|
||||
Phase: contracthost.MachinePhaseRunning,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}); err != nil {
|
||||
t.Fatalf("create source machine: %v", err)
|
||||
}
|
||||
|
||||
snapDisk := filepath.Join(root, "snapshots", "snap1", "system.img")
|
||||
if err := os.MkdirAll(filepath.Dir(snapDisk), 0o755); err != nil {
|
||||
t.Fatalf("create snapshot dir: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil {
|
||||
t.Fatalf("write snapshot disk: %v", err)
|
||||
}
|
||||
if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{
|
||||
ID: "snap1",
|
||||
MachineID: "source",
|
||||
Artifact: artifactRef,
|
||||
MemFilePath: filepath.Join(root, "snapshots", "snap1", "memory.bin"),
|
||||
StateFilePath: filepath.Join(root, "snapshots", "snap1", "vmstate.bin"),
|
||||
DiskPaths: []string{snapDisk},
|
||||
DiskPaths: []string{filepath.Join(root, "snapshots", "snap1", "system.img")},
|
||||
SourceRuntimeHost: "172.16.0.2",
|
||||
SourceTapDevice: "fctap0",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
|
|
@ -314,17 +372,49 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
|
|||
t.Fatalf("create snapshot: %v", err)
|
||||
}
|
||||
|
||||
_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{
|
||||
MachineID: "restored",
|
||||
server := newRestoreArtifactServer(t, map[string][]byte{
|
||||
"/kernel": []byte("kernel"),
|
||||
"/rootfs": []byte("rootfs"),
|
||||
"/memory": []byte("mem"),
|
||||
"/vmstate": []byte("state"),
|
||||
"/system": []byte("disk"),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected restore rejection while source is running")
|
||||
defer server.Close()
|
||||
|
||||
response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{
|
||||
MachineID: "restored",
|
||||
Artifact: contracthost.ArtifactRef{
|
||||
KernelImageURL: server.URL + "/kernel",
|
||||
RootFSURL: server.URL + "/rootfs",
|
||||
},
|
||||
Snapshot: contracthost.DurableSnapshotSpec{
|
||||
SnapshotID: "snap1",
|
||||
MachineID: "source",
|
||||
ImageID: "image-1",
|
||||
Artifacts: []contracthost.SnapshotArtifact{
|
||||
{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"},
|
||||
{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"},
|
||||
{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("restore snapshot: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), `source machine "source" is running`) {
|
||||
t.Fatalf("unexpected restore error: %v", err)
|
||||
if response.Machine.ID != "restored" {
|
||||
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
|
||||
}
|
||||
if runtime.restoreCalls != 0 {
|
||||
t.Fatalf("restore boot should not run when source machine is still running: got %d", runtime.restoreCalls)
|
||||
if runtime.restoreCalls != 1 {
|
||||
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls)
|
||||
}
|
||||
if runtime.lastLoadSpec.Network == nil {
|
||||
t.Fatalf("restore boot should preserve snapshot network")
|
||||
}
|
||||
if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" {
|
||||
t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2")
|
||||
}
|
||||
if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" {
|
||||
t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0")
|
||||
}
|
||||
|
||||
ops, err := fileStore.ListOperations(context.Background())
|
||||
|
|
@ -332,11 +422,11 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
|
|||
t.Fatalf("list operations: %v", err)
|
||||
}
|
||||
if len(ops) != 0 {
|
||||
t.Fatalf("operation journal should be empty after handled restore rejection: got %d entries", len(ops))
|
||||
t.Fatalf("operation journal should be empty after successful restore: got %d entries", len(ops))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
|
||||
func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||
|
|
@ -378,56 +468,35 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
|
|||
return nil
|
||||
}
|
||||
|
||||
artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
|
||||
kernelPath := filepath.Join(root, "artifact-kernel")
|
||||
rootFSPath := filepath.Join(root, "artifact-rootfs")
|
||||
if err := os.WriteFile(kernelPath, []byte("kernel"), 0o644); err != nil {
|
||||
t.Fatalf("write kernel: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(rootFSPath, []byte("rootfs"), 0o644); err != nil {
|
||||
t.Fatalf("write rootfs: %v", err)
|
||||
}
|
||||
if err := fileStore.PutArtifact(context.Background(), model.ArtifactRecord{
|
||||
Ref: artifactRef,
|
||||
LocalKey: "artifact",
|
||||
LocalDir: filepath.Join(root, "artifact"),
|
||||
KernelImagePath: kernelPath,
|
||||
RootFSPath: rootFSPath,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}); err != nil {
|
||||
t.Fatalf("put artifact: %v", err)
|
||||
}
|
||||
|
||||
snapDir := filepath.Join(root, "snapshots", "snap1")
|
||||
if err := os.MkdirAll(snapDir, 0o755); err != nil {
|
||||
t.Fatalf("create snapshot dir: %v", err)
|
||||
}
|
||||
snapDisk := filepath.Join(snapDir, "system.img")
|
||||
if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil {
|
||||
t.Fatalf("write snapshot disk: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(snapDir, "memory.bin"), []byte("mem"), 0o644); err != nil {
|
||||
t.Fatalf("write memory snapshot: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(snapDir, "vmstate.bin"), []byte("state"), 0o644); err != nil {
|
||||
t.Fatalf("write vmstate snapshot: %v", err)
|
||||
}
|
||||
if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{
|
||||
ID: "snap1",
|
||||
MachineID: "source",
|
||||
Artifact: artifactRef,
|
||||
MemFilePath: filepath.Join(snapDir, "memory.bin"),
|
||||
StateFilePath: filepath.Join(snapDir, "vmstate.bin"),
|
||||
DiskPaths: []string{snapDisk},
|
||||
SourceRuntimeHost: "172.16.0.2",
|
||||
SourceTapDevice: "fctap0",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}); err != nil {
|
||||
t.Fatalf("create snapshot: %v", err)
|
||||
}
|
||||
server := newRestoreArtifactServer(t, map[string][]byte{
|
||||
"/kernel": []byte("kernel"),
|
||||
"/rootfs": []byte("rootfs"),
|
||||
"/memory": []byte("mem"),
|
||||
"/vmstate": []byte("state"),
|
||||
"/system": []byte("disk"),
|
||||
"/user-0": []byte("user-disk"),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{
|
||||
MachineID: "restored",
|
||||
Artifact: contracthost.ArtifactRef{
|
||||
KernelImageURL: server.URL + "/kernel",
|
||||
RootFSURL: server.URL + "/rootfs",
|
||||
},
|
||||
Snapshot: contracthost.DurableSnapshotSpec{
|
||||
SnapshotID: "snap1",
|
||||
MachineID: "source",
|
||||
ImageID: "image-1",
|
||||
SourceRuntimeHost: "172.16.0.2",
|
||||
SourceTapDevice: "fctap0",
|
||||
Artifacts: []contracthost.SnapshotArtifact{
|
||||
{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"},
|
||||
{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"},
|
||||
{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"},
|
||||
{ID: "disk-user-0", Kind: contracthost.SnapshotArtifactKindDisk, Name: "user-0.img", DownloadURL: server.URL + "/user-0"},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("restore snapshot: %v", err)
|
||||
|
|
@ -439,13 +508,19 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
|
|||
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls)
|
||||
}
|
||||
if runtime.lastLoadSpec.Network == nil {
|
||||
t.Fatal("restore boot did not receive snapshot network")
|
||||
t.Fatalf("restore boot should preserve durable snapshot network")
|
||||
}
|
||||
if got := runtime.lastLoadSpec.Network.GuestIP().String(); got != "172.16.0.2" {
|
||||
t.Fatalf("restored guest network mismatch: got %q want %q", got, "172.16.0.2")
|
||||
t.Fatalf("restore guest ip mismatch: got %q want %q", got, "172.16.0.2")
|
||||
}
|
||||
if runtime.lastLoadSpec.KernelImagePath != kernelPath {
|
||||
t.Fatalf("restore boot kernel path mismatch: got %q want %q", runtime.lastLoadSpec.KernelImagePath, kernelPath)
|
||||
if got := runtime.lastLoadSpec.Network.TapName; got != "fctap0" {
|
||||
t.Fatalf("restore tap mismatch: got %q want %q", got, "fctap0")
|
||||
}
|
||||
if !strings.Contains(runtime.lastLoadSpec.KernelImagePath, filepath.Join("artifacts", artifactKey(contracthost.ArtifactRef{
|
||||
KernelImageURL: server.URL + "/kernel",
|
||||
RootFSURL: server.URL + "/rootfs",
|
||||
}), "kernel")) {
|
||||
t.Fatalf("restore boot kernel path mismatch: got %q", runtime.lastLoadSpec.KernelImagePath)
|
||||
}
|
||||
if reconfiguredHost != "127.0.0.1" || reconfiguredMachine != "restored" {
|
||||
t.Fatalf("guest identity reconfigure mismatch: host=%q machine=%q", reconfiguredHost, reconfiguredMachine)
|
||||
|
|
@ -458,6 +533,12 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
|
|||
if machine.Phase != contracthost.MachinePhaseRunning {
|
||||
t.Fatalf("restored machine phase mismatch: got %q", machine.Phase)
|
||||
}
|
||||
if len(machine.UserVolumeIDs) != 1 {
|
||||
t.Fatalf("restored machine user volumes mismatch: got %#v", machine.UserVolumeIDs)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(cfg.MachineDisksDir, "restored", "user-0.img")); err != nil {
|
||||
t.Fatalf("restored user disk missing: %v", err)
|
||||
}
|
||||
|
||||
ops, err := fileStore.ListOperations(context.Background())
|
||||
if err != nil {
|
||||
|
|
@ -468,6 +549,67 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRestoreSnapshotRejectsWhenRestoreNetworkInUseOnHost(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("create file store: %v", err)
|
||||
}
|
||||
|
||||
runtime := &fakeRuntime{}
|
||||
hostDaemon, err := New(cfg, fileStore, runtime)
|
||||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
|
||||
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
|
||||
ID: "source",
|
||||
Artifact: contracthost.ArtifactRef{KernelImageURL: "https://example.com/kernel", RootFSURL: "https://example.com/rootfs"},
|
||||
SystemVolumeID: "source-system",
|
||||
RuntimeHost: "172.16.0.2",
|
||||
TapDevice: "fctap0",
|
||||
Phase: contracthost.MachinePhaseRunning,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}); err != nil {
|
||||
t.Fatalf("create running source machine: %v", err)
|
||||
}
|
||||
|
||||
_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap1", contracthost.RestoreSnapshotRequest{
|
||||
MachineID: "restored",
|
||||
Artifact: contracthost.ArtifactRef{
|
||||
KernelImageURL: "https://example.com/kernel",
|
||||
RootFSURL: "https://example.com/rootfs",
|
||||
},
|
||||
Snapshot: contracthost.DurableSnapshotSpec{
|
||||
SnapshotID: "snap1",
|
||||
MachineID: "source",
|
||||
ImageID: "image-1",
|
||||
SourceRuntimeHost: "172.16.0.2",
|
||||
SourceTapDevice: "fctap0",
|
||||
},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "still in use on this host") {
|
||||
t.Fatalf("restore snapshot error = %v, want restore network in-use failure", err)
|
||||
}
|
||||
if runtime.restoreCalls != 0 {
|
||||
t.Fatalf("restore boot should not be attempted, got %d calls", runtime.restoreCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func newRestoreArtifactServer(t *testing.T, payloads map[string][]byte) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
payload, ok := payloads[r.URL.Path]
|
||||
if !ok {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(payload)
|
||||
}))
|
||||
}
|
||||
|
||||
func TestCreateMachineRejectsNonHTTPArtifactURLs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
|||
|
|
@ -414,6 +414,8 @@ func publishedPortToContract(record model.PublishedPortRecord) contracthost.Publ
|
|||
func machineToRuntimeState(record model.MachineRecord) firecracker.MachineState {
|
||||
phase := firecracker.PhaseStopped
|
||||
switch record.Phase {
|
||||
case contracthost.MachinePhaseStarting:
|
||||
phase = firecracker.PhaseRunning
|
||||
case contracthost.MachinePhaseRunning:
|
||||
phase = firecracker.PhaseRunning
|
||||
case contracthost.MachinePhaseFailed:
|
||||
|
|
|
|||
|
|
@ -54,6 +54,31 @@ hostname "$machine_name" >/dev/null 2>&1 || true
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *Daemon) syncGuestFilesystemOverSSH(ctx context.Context, runtimeHost string) error {
|
||||
runtimeHost = strings.TrimSpace(runtimeHost)
|
||||
if runtimeHost == "" {
|
||||
return fmt.Errorf("guest runtime host is required")
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(
|
||||
ctx,
|
||||
"ssh",
|
||||
"-i", d.backendSSHPrivateKeyPath(),
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "IdentitiesOnly=yes",
|
||||
"-o", "BatchMode=yes",
|
||||
"-p", strconv.Itoa(int(defaultSSHPort)),
|
||||
"node@"+runtimeHost,
|
||||
"sudo bash -lc "+shellSingleQuote("sync"),
|
||||
)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("sync guest filesystem over ssh: %w: %s", err, strings.TrimSpace(string(output)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shellSingleQuote wraps value in single quotes for safe interpolation into
// a POSIX shell command line, escaping embedded single quotes with the
// standard '"'"' idiom.
func shellSingleQuote(value string) string {
	const escapedQuote = `'"'"'`
	var b strings.Builder
	b.WriteByte('\'')
	b.WriteString(strings.ReplaceAll(value, "'", escapedQuote))
	b.WriteByte('\'')
	return b.String()
}
|
||||
|
|
|
|||
80
internal/daemon/guest_metadata.go
Normal file
80
internal/daemon/guest_metadata.go
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||
)
|
||||
|
||||
const (
	// defaultMMDSIPv4Address is the link-local address at which the MMDS
	// endpoint is exposed to the guest (see guestMetadataSpec).
	defaultMMDSIPv4Address = "169.254.170.2"
	// defaultMMDSPayloadVersion tags the metadata payload schema so guests
	// can detect format changes.
	defaultMMDSPayloadVersion = "v1"
)
|
||||
|
||||
// guestMetadataEnvelope is the top-level JSON document stored in the
// Firecracker MMDS; clients fetch the payload under the "latest" tree.
type guestMetadataEnvelope struct {
	Latest guestMetadataRoot `json:"latest"`
}

// guestMetadataRoot nests the payload under "meta-data", mirroring the
// conventional cloud metadata-service layout.
type guestMetadataRoot struct {
	MetaData guestMetadataPayload `json:"meta-data"`
}

// guestMetadataPayload carries the per-machine identity the guest reads at
// boot: schema version, machine id and hostname, plus optional SSH trust
// material and a login webhook.
type guestMetadataPayload struct {
	Version           string                          `json:"version"`
	MachineID         string                          `json:"machine_id"`
	Hostname          string                          `json:"hostname"`
	AuthorizedKeys    []string                        `json:"authorized_keys,omitempty"`
	TrustedUserCAKeys []string                        `json:"trusted_user_ca_keys,omitempty"`
	LoginWebhook      *contracthost.GuestLoginWebhook `json:"login_webhook,omitempty"`
}
|
||||
|
||||
func cloneGuestConfig(config *contracthost.GuestConfig) *contracthost.GuestConfig {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := &contracthost.GuestConfig{
|
||||
AuthorizedKeys: append([]string(nil), config.AuthorizedKeys...),
|
||||
TrustedUserCAKeys: append([]string(nil), config.TrustedUserCAKeys...),
|
||||
}
|
||||
if config.LoginWebhook != nil {
|
||||
copy := *config.LoginWebhook
|
||||
cloned.LoginWebhook = ©
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) (*firecracker.MMDSSpec, error) {
|
||||
name := strings.TrimSpace(string(machineID))
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("machine id is required")
|
||||
}
|
||||
|
||||
payload := guestMetadataEnvelope{
|
||||
Latest: guestMetadataRoot{
|
||||
MetaData: guestMetadataPayload{
|
||||
Version: defaultMMDSPayloadVersion,
|
||||
MachineID: name,
|
||||
Hostname: name,
|
||||
AuthorizedKeys: nil,
|
||||
TrustedUserCAKeys: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
if guestConfig != nil {
|
||||
payload.Latest.MetaData.AuthorizedKeys = append([]string(nil), guestConfig.AuthorizedKeys...)
|
||||
payload.Latest.MetaData.TrustedUserCAKeys = append([]string(nil), guestConfig.TrustedUserCAKeys...)
|
||||
if guestConfig.LoginWebhook != nil {
|
||||
loginWebhook := *guestConfig.LoginWebhook
|
||||
payload.Latest.MetaData.LoginWebhook = &loginWebhook
|
||||
}
|
||||
}
|
||||
|
||||
return &firecracker.MMDSSpec{
|
||||
NetworkInterfaces: []string{"net0"},
|
||||
Version: firecracker.MMDSVersionV2,
|
||||
IPv4Address: defaultMMDSIPv4Address,
|
||||
Data: payload,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -48,9 +48,13 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
previousRecord := *record
|
||||
if record.Phase == contracthost.MachinePhaseRunning {
|
||||
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
|
||||
}
|
||||
if record.Phase == contracthost.MachinePhaseStarting {
|
||||
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
|
||||
}
|
||||
if record.Phase != contracthost.MachinePhaseStopped {
|
||||
return nil, fmt.Errorf("machine %q is not startable from phase %q", id, record.Phase)
|
||||
}
|
||||
|
|
@ -82,7 +86,7 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path)
|
||||
spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path, record.GuestConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -94,39 +98,18 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ports := defaultMachinePorts()
|
||||
if err := waitForGuestReady(ctx, state.RuntimeHost, ports); err != nil {
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
|
||||
if err != nil {
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
record.RuntimeHost = state.RuntimeHost
|
||||
record.TapDevice = state.TapName
|
||||
record.Ports = ports
|
||||
record.GuestSSHPublicKey = guestSSHPublicKey
|
||||
record.Phase = contracthost.MachinePhaseRunning
|
||||
record.Ports = defaultMachinePorts()
|
||||
record.GuestSSHPublicKey = ""
|
||||
record.Phase = contracthost.MachinePhaseStarting
|
||||
record.Error = ""
|
||||
record.PID = state.PID
|
||||
record.SocketPath = state.SocketPath
|
||||
record.StartedAt = state.StartedAt
|
||||
if err := d.store.UpdateMachine(ctx, *record); err != nil {
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||
d.stopMachineRelays(id)
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
||||
d.stopMachineRelays(id)
|
||||
d.stopPublishedPortsForMachine(id)
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
_ = d.store.UpdateMachine(context.Background(), previousRecord)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -272,7 +255,10 @@ func (d *Daemon) listRunningNetworks(ctx context.Context, ignore contracthost.Ma
|
|||
|
||||
networks := make([]firecracker.NetworkAllocation, 0, len(records))
|
||||
for _, record := range records {
|
||||
if record.ID == ignore || record.Phase != contracthost.MachinePhaseRunning {
|
||||
if record.ID == ignore {
|
||||
continue
|
||||
}
|
||||
if record.Phase != contracthost.MachinePhaseRunning && record.Phase != contracthost.MachinePhaseStarting {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(record.RuntimeHost) == "" || strings.TrimSpace(record.TapDevice) == "" {
|
||||
|
|
@ -337,11 +323,8 @@ func (d *Daemon) reconcileStart(ctx context.Context, machineID contracthost.Mach
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if record.Phase == contracthost.MachinePhaseRunning {
|
||||
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
||||
if record.Phase == contracthost.MachinePhaseRunning || record.Phase == contracthost.MachinePhaseStarting {
|
||||
if _, err := d.reconcileMachine(ctx, machineID); err != nil {
|
||||
return err
|
||||
}
|
||||
return d.store.DeleteOperation(ctx, machineID)
|
||||
|
|
@ -385,7 +368,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if record.Phase != contracthost.MachinePhaseRunning {
|
||||
if record.Phase != contracthost.MachinePhaseRunning && record.Phase != contracthost.MachinePhaseStarting {
|
||||
return record, nil
|
||||
}
|
||||
|
||||
|
|
@ -393,6 +376,42 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if record.Phase == contracthost.MachinePhaseStarting {
|
||||
if state.Phase != firecracker.PhaseRunning {
|
||||
return d.failMachineStartup(ctx, record, state.Error)
|
||||
}
|
||||
ready, err := guestPortsReady(ctx, state.RuntimeHost, defaultMachinePorts())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ready {
|
||||
return record, nil
|
||||
}
|
||||
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
|
||||
if err != nil {
|
||||
return d.failMachineStartup(ctx, record, err.Error())
|
||||
}
|
||||
record.RuntimeHost = state.RuntimeHost
|
||||
record.TapDevice = state.TapName
|
||||
record.Ports = defaultMachinePorts()
|
||||
record.GuestSSHPublicKey = guestSSHPublicKey
|
||||
record.Phase = contracthost.MachinePhaseRunning
|
||||
record.Error = ""
|
||||
record.PID = state.PID
|
||||
record.SocketPath = state.SocketPath
|
||||
record.StartedAt = state.StartedAt
|
||||
if err := d.store.UpdateMachine(ctx, *record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||
return d.failMachineStartup(ctx, record, err.Error())
|
||||
}
|
||||
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
||||
d.stopMachineRelays(record.ID)
|
||||
return d.failMachineStartup(ctx, record, err.Error())
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
if state.Phase == firecracker.PhaseRunning {
|
||||
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -418,6 +437,28 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
|
|||
return record, nil
|
||||
}
|
||||
|
||||
// failMachineStartup tears down a machine whose boot did not complete and
// persists it in the Failed phase with failureReason (trimmed) as its error.
// It best-effort deletes the runtime VM and stops relays/published ports,
// then clears all runtime fields on the record before saving it. Returns the
// updated record, or an error if the store update fails.
func (d *Daemon) failMachineStartup(ctx context.Context, record *model.MachineRecord, failureReason string) (*model.MachineRecord, error) {
	if record == nil {
		return nil, fmt.Errorf("machine record is required")
	}
	// Runtime teardown is best-effort; the persisted record is authoritative.
	_ = d.runtime.Delete(ctx, machineToRuntimeState(*record))
	d.stopMachineRelays(record.ID)
	d.stopPublishedPortsForMachine(record.ID)
	// Reset every runtime-derived field so the failed record carries no
	// stale process/network state.
	record.Phase = contracthost.MachinePhaseFailed
	record.Error = strings.TrimSpace(failureReason)
	record.Ports = defaultMachinePorts()
	record.GuestSSHPublicKey = ""
	record.PID = 0
	record.SocketPath = ""
	record.RuntimeHost = ""
	record.TapDevice = ""
	record.StartedAt = nil
	if err := d.store.UpdateMachine(ctx, *record); err != nil {
		return nil, err
	}
	return record, nil
}
|
||||
|
||||
func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error {
|
||||
d.stopMachineRelays(record.ID)
|
||||
d.stopPublishedPortsForMachine(record.ID)
|
||||
|
|
@ -450,6 +491,11 @@ func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineR
|
|||
}
|
||||
|
||||
func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error {
|
||||
if record.Phase == contracthost.MachinePhaseRunning && strings.TrimSpace(record.RuntimeHost) != "" {
|
||||
if err := d.syncGuestFilesystem(ctx, record.RuntimeHost); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: sync guest filesystem for %q failed before stop: %v\n", record.ID, err)
|
||||
}
|
||||
}
|
||||
d.stopMachineRelays(record.ID)
|
||||
d.stopPublishedPortsForMachine(record.ID)
|
||||
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
|
||||
|
|
|
|||
|
|
@ -56,6 +56,9 @@ func (d *Daemon) usedMachineRelayPorts(ctx context.Context, machineID contractho
|
|||
if record.ID == machineID {
|
||||
continue
|
||||
}
|
||||
if record.Phase != contracthost.MachinePhaseRunning {
|
||||
continue
|
||||
}
|
||||
if port := machineRelayHostPort(record, name); port != 0 {
|
||||
used[port] = struct{}{}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,15 +30,15 @@ func waitForGuestReady(ctx context.Context, host string, ports []contracthost.Ma
|
|||
|
||||
func waitForGuestPort(ctx context.Context, host string, port contracthost.MachinePort) error {
|
||||
address := net.JoinHostPort(host, strconv.Itoa(int(port.Port)))
|
||||
dialer := net.Dialer{Timeout: defaultGuestDialTimeout}
|
||||
ticker := time.NewTicker(defaultGuestReadyPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastErr error
|
||||
for {
|
||||
connection, err := dialer.DialContext(ctx, string(port.Protocol), address)
|
||||
if err == nil {
|
||||
_ = connection.Close()
|
||||
probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout)
|
||||
ready, err := guestPortReady(probeCtx, host, port)
|
||||
cancel()
|
||||
if err == nil && ready {
|
||||
return nil
|
||||
}
|
||||
lastErr = err
|
||||
|
|
@ -50,3 +50,38 @@ func waitForGuestPort(ctx context.Context, host string, port contracthost.Machin
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func guestPortsReady(ctx context.Context, host string, ports []contracthost.MachinePort) (bool, error) {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return false, fmt.Errorf("guest runtime host is required")
|
||||
}
|
||||
|
||||
for _, port := range ports {
|
||||
probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout)
|
||||
ready, err := guestPortReady(probeCtx, host, port)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !ready {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func guestPortReady(ctx context.Context, host string, port contracthost.MachinePort) (bool, error) {
|
||||
address := net.JoinHostPort(host, strconv.Itoa(int(port.Port)))
|
||||
dialer := net.Dialer{Timeout: defaultGuestDialTimeout}
|
||||
|
||||
connection, err := dialer.DialContext(ctx, string(port.Protocol), address)
|
||||
if err == nil {
|
||||
_ = connection.Close()
|
||||
return true, nil
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return false, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,10 @@ package daemon
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
|
@ -54,6 +57,19 @@ func (s machineLookupErrorStore) GetMachine(context.Context, contracthost.Machin
|
|||
return nil, s.err
|
||||
}
|
||||
|
||||
// relayExhaustionStore wraps a host store and injects extra machine records
// into ListMachines results, letting tests simulate a host whose relay
// capacity is already consumed by other machines (see
// exhaustedMachineRelayRecords in the tests that use it).
type relayExhaustionStore struct {
	hoststore.Store
	extraMachines []model.MachineRecord
}

// ListMachines returns the wrapped store's machines plus the injected
// records. All other Store methods are delegated unchanged via embedding.
func (s relayExhaustionStore) ListMachines(ctx context.Context) ([]model.MachineRecord, error) {
	machines, err := s.Store.ListMachines(ctx)
	if err != nil {
		return nil, err
	}
	return append(machines, s.extraMachines...), nil
}
|
||||
|
||||
type publishedPortResult struct {
|
||||
response *contracthost.CreatePublishedPortResponse
|
||||
err error
|
||||
|
|
@ -283,6 +299,249 @@ func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T)
|
|||
assertOperationCount(t, baseStore, 1)
|
||||
}
|
||||
|
||||
// TestStartMachineTransitionsToStartingWithoutRelayAllocation verifies that
// StartMachine moves a stopped machine into the Starting phase — recording
// the runtime host, tap device, PID, socket, and start time — without
// requiring relay ports up front: even with the relay range exhausted (via
// relayExhaustionStore) the start succeeds and the VM is not torn down.
func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}

	// Simulate a host whose relay ports are fully consumed by other machines.
	exhaustedStore := relayExhaustionStore{
		Store:         baseStore,
		extraMachines: exhaustedMachineRelayRecords(),
	}

	// Occupy the guest-facing ports locally so readiness probes can connect.
	sshListener := listenTestPort(t, int(defaultSSHPort))
	defer func() {
		_ = sshListener.Close()
	}()
	vncListener := listenTestPort(t, int(defaultVNCPort))
	defer func() {
		_ = vncListener.Close()
	}()

	startedAt := time.Unix(1700000200, 0).UTC()
	runtime := &fakeRuntime{
		bootState: firecracker.MachineState{
			ID:          "vm-start",
			Phase:       firecracker.PhaseRunning,
			PID:         9999,
			RuntimeHost: "127.0.0.1",
			SocketPath:  filepath.Join(cfg.RuntimeDir, "machines", "vm-start", "root", "run", "firecracker.sock"),
			TapName:     "fctap-start",
			StartedAt:   &startedAt,
		},
	}

	hostDaemon, err := New(cfg, exhaustedStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	stubGuestSSHPublicKeyReader(hostDaemon)

	// Seed the artifact, system volume, and stopped machine the start needs.
	artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
	kernelPath := filepath.Join(root, "artifact-kernel")
	rootFSPath := filepath.Join(root, "artifact-rootfs")
	systemVolumePath := filepath.Join(root, "machine-disks", "vm-start", "rootfs.ext4")
	for _, file := range []string{kernelPath, rootFSPath, systemVolumePath} {
		if err := os.MkdirAll(filepath.Dir(file), 0o755); err != nil {
			t.Fatalf("mkdir for %q: %v", file, err)
		}
		if err := os.WriteFile(file, []byte("payload"), 0o644); err != nil {
			t.Fatalf("write file %q: %v", file, err)
		}
	}
	if err := baseStore.PutArtifact(context.Background(), model.ArtifactRecord{
		Ref:             artifactRef,
		LocalKey:        "artifact",
		LocalDir:        filepath.Join(root, "artifact"),
		KernelImagePath: kernelPath,
		RootFSPath:      rootFSPath,
		CreatedAt:       time.Now().UTC(),
	}); err != nil {
		t.Fatalf("put artifact: %v", err)
	}
	if err := baseStore.CreateVolume(context.Background(), model.VolumeRecord{
		ID:        "vm-start-system",
		Kind:      contracthost.VolumeKindSystem,
		Pool:      model.StoragePoolMachineDisks,
		Path:      systemVolumePath,
		CreatedAt: time.Now().UTC(),
	}); err != nil {
		t.Fatalf("create system volume: %v", err)
	}
	if err := baseStore.CreateMachine(context.Background(), model.MachineRecord{
		ID:             "vm-start",
		Artifact:       artifactRef,
		SystemVolumeID: "vm-start-system",
		Ports:          defaultMachinePorts(),
		Phase:          contracthost.MachinePhaseStopped,
		CreatedAt:      time.Now().UTC(),
	}); err != nil {
		t.Fatalf("create machine: %v", err)
	}

	response, err := hostDaemon.StartMachine(context.Background(), "vm-start")
	if err != nil {
		t.Fatalf("StartMachine error = %v", err)
	}
	if response.Machine.Phase != contracthost.MachinePhaseStarting {
		t.Fatalf("response machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting)
	}

	// The persisted record must carry the boot state reported by the runtime.
	machine, err := baseStore.GetMachine(context.Background(), "vm-start")
	if err != nil {
		t.Fatalf("get machine: %v", err)
	}
	if machine.Phase != contracthost.MachinePhaseStarting {
		t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseStarting)
	}
	if machine.RuntimeHost != "127.0.0.1" || machine.TapDevice != "fctap-start" {
		t.Fatalf("machine runtime state mismatch, got runtime_host=%q tap=%q", machine.RuntimeHost, machine.TapDevice)
	}
	if machine.PID != 9999 || machine.SocketPath == "" || machine.StartedAt == nil {
		t.Fatalf("machine process state missing: pid=%d socket=%q started_at=%v", machine.PID, machine.SocketPath, machine.StartedAt)
	}
	// Relay exhaustion must not cause the freshly-booted VM to be deleted.
	if len(runtime.deleteCalls) != 0 {
		t.Fatalf("runtime delete calls = %d, want 0", len(runtime.deleteCalls))
	}
}
|
||||
|
||||
// TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails
// verifies the cleanup path of RestoreSnapshot: when relay port allocation
// for the restored machine fails (relay range exhausted), the restored
// system volume record and its disk are removed, the booted VM is deleted
// exactly once, and no pending operation is left behind.
func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}

	// Simulate a host whose relay ports are fully consumed by other machines.
	exhaustedStore := relayExhaustionStore{
		Store:         baseStore,
		extraMachines: exhaustedMachineRelayRecords(),
	}

	sshListener := listenTestPort(t, int(defaultSSHPort))
	defer func() {
		_ = sshListener.Close()
	}()
	vncListener := listenTestPort(t, int(defaultVNCPort))
	defer func() {
		_ = vncListener.Close()
	}()

	startedAt := time.Unix(1700000300, 0).UTC()
	runtime := &fakeRuntime{
		bootState: firecracker.MachineState{
			ID:          "restored-exhausted",
			Phase:       firecracker.PhaseRunning,
			PID:         8888,
			RuntimeHost: "127.0.0.1",
			SocketPath:  filepath.Join(cfg.RuntimeDir, "machines", "restored-exhausted", "root", "run", "firecracker.sock"),
			TapName:     "fctap-restore",
			StartedAt:   &startedAt,
		},
	}

	hostDaemon, err := New(cfg, exhaustedStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	stubGuestSSHPublicKeyReader(hostDaemon)
	// Guest identity reconfiguration would reach into a real guest; stub it out.
	hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID) error { return nil }

	// Seed the boot artifact the restore will download/validate against.
	artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
	kernelPath := filepath.Join(root, "artifact-kernel")
	rootFSPath := filepath.Join(root, "artifact-rootfs")
	if err := os.WriteFile(kernelPath, []byte("kernel"), 0o644); err != nil {
		t.Fatalf("write kernel: %v", err)
	}
	if err := os.WriteFile(rootFSPath, []byte("rootfs"), 0o644); err != nil {
		t.Fatalf("write rootfs: %v", err)
	}
	if err := baseStore.PutArtifact(context.Background(), model.ArtifactRecord{
		Ref:             artifactRef,
		LocalKey:        "artifact",
		LocalDir:        filepath.Join(root, "artifact"),
		KernelImagePath: kernelPath,
		RootFSPath:      rootFSPath,
		CreatedAt:       time.Now().UTC(),
	}); err != nil {
		t.Fatalf("put artifact: %v", err)
	}

	// Seed a local snapshot (disk, memory, vmstate) to restore from.
	snapDir := filepath.Join(root, "snapshots", "snap-exhausted")
	if err := os.MkdirAll(snapDir, 0o755); err != nil {
		t.Fatalf("create snapshot dir: %v", err)
	}
	snapDisk := filepath.Join(snapDir, "system.img")
	if err := os.WriteFile(snapDisk, []byte("disk"), 0o644); err != nil {
		t.Fatalf("write snapshot disk: %v", err)
	}
	memPath := filepath.Join(snapDir, "memory.bin")
	if err := os.WriteFile(memPath, []byte("mem"), 0o644); err != nil {
		t.Fatalf("write memory snapshot: %v", err)
	}
	statePath := filepath.Join(snapDir, "vmstate.bin")
	if err := os.WriteFile(statePath, []byte("state"), 0o644); err != nil {
		t.Fatalf("write vmstate snapshot: %v", err)
	}
	if err := baseStore.CreateSnapshot(context.Background(), model.SnapshotRecord{
		ID:                "snap-exhausted",
		MachineID:         "source",
		Artifact:          artifactRef,
		MemFilePath:       memPath,
		StateFilePath:     statePath,
		DiskPaths:         []string{snapDisk},
		SourceRuntimeHost: "172.16.0.2",
		SourceTapDevice:   "fctap0",
		CreatedAt:         time.Now().UTC(),
	}); err != nil {
		t.Fatalf("create snapshot: %v", err)
	}

	// Durable artifacts are fetched over HTTP from this stub server.
	server := newRestoreArtifactServer(t, map[string][]byte{
		"/kernel":  []byte("kernel"),
		"/rootfs":  []byte("rootfs"),
		"/memory":  []byte("mem"),
		"/vmstate": []byte("state"),
		"/system":  []byte("disk"),
	})
	defer server.Close()

	_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-exhausted", contracthost.RestoreSnapshotRequest{
		MachineID: "restored-exhausted",
		Artifact: contracthost.ArtifactRef{
			KernelImageURL: server.URL + "/kernel",
			RootFSURL:      server.URL + "/rootfs",
		},
		Snapshot: contracthost.DurableSnapshotSpec{
			SnapshotID: "snap-exhausted",
			MachineID:  "source",
			ImageID:    "image-1",
			Artifacts: []contracthost.SnapshotArtifact{
				{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory"},
				{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate"},
				{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"},
			},
		},
	})
	if err == nil || !strings.Contains(err.Error(), "allocate relay ports for restored machine") {
		t.Fatalf("RestoreSnapshot error = %v, want relay allocation failure", err)
	}

	// The failed restore must leave no volume record, disk, or operation behind.
	if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); !errors.Is(err, hoststore.ErrNotFound) {
		t.Fatalf("restored system volume record should be deleted, get err = %v", err)
	}
	if _, err := os.Stat(hostDaemon.systemVolumePath("restored-exhausted")); !os.IsNotExist(err) {
		t.Fatalf("restored system disk should be removed, stat err = %v", err)
	}
	if len(runtime.deleteCalls) != 1 {
		t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls))
	}
	assertOperationCount(t, baseStore, 0)
}
|
||||
|
||||
func TestCreateSnapshotRejectsDuplicateSnapshotIDWithoutTouchingExistingArtifacts(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
|
|
@ -339,6 +598,193 @@ func TestCreateSnapshotRejectsDuplicateSnapshotIDWithoutTouchingExistingArtifact
|
|||
assertOperationCount(t, fileStore, 0)
|
||||
}
|
||||
|
||||
// TestStopMachineContinuesWhenGuestSyncFails verifies that StopMachine treats
// the pre-stop guest filesystem sync as best-effort: when the sync hook
// errors, the machine is still torn down (one runtime delete) and its record
// ends in the Stopped phase.
func TestStopMachineContinuesWhenGuestSyncFails(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}

	runtime := &fakeRuntime{}
	hostDaemon, err := New(cfg, fileStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	stubGuestSSHPublicKeyReader(hostDaemon)
	// Force the sync hook to fail so the best-effort path is exercised.
	hostDaemon.syncGuestFilesystem = func(context.Context, string) error {
		return errors.New("guest sync failed")
	}

	// Seed a running machine with full runtime state so stop has work to do.
	now := time.Now().UTC()
	if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
		ID:             "vm-stop-fail",
		SystemVolumeID: "vm-stop-fail-system",
		RuntimeHost:    "172.16.0.2",
		TapDevice:      "fctap-stop-fail",
		Phase:          contracthost.MachinePhaseRunning,
		PID:            1234,
		SocketPath:     filepath.Join(root, "runtime", "vm-stop-fail.sock"),
		Ports:          defaultMachinePorts(),
		CreatedAt:      now,
		StartedAt:      &now,
	}); err != nil {
		t.Fatalf("create machine: %v", err)
	}

	if err := hostDaemon.StopMachine(context.Background(), "vm-stop-fail"); err != nil {
		t.Fatalf("StopMachine returned error despite sync failure: %v", err)
	}
	if len(runtime.deleteCalls) != 1 {
		t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls))
	}

	machine, err := fileStore.GetMachine(context.Background(), "vm-stop-fail")
	if err != nil {
		t.Fatalf("get machine: %v", err)
	}
	if machine.Phase != contracthost.MachinePhaseStopped {
		t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseStopped)
	}
}
|
||||
|
||||
func TestOrderedRestoredUserDiskArtifactsSortsByDriveIndex(t *testing.T) {
|
||||
ordered := orderedRestoredUserDiskArtifacts(map[string]restoredSnapshotArtifact{
|
||||
"user-10.img": {Artifact: contracthost.SnapshotArtifact{Name: "user-10.img"}},
|
||||
"user-2.img": {Artifact: contracthost.SnapshotArtifact{Name: "user-2.img"}},
|
||||
"user-1.img": {Artifact: contracthost.SnapshotArtifact{Name: "user-1.img"}},
|
||||
"system.img": {Artifact: contracthost.SnapshotArtifact{Name: "system.img"}},
|
||||
})
|
||||
|
||||
names := make([]string, 0, len(ordered))
|
||||
for _, artifact := range ordered {
|
||||
names = append(names, artifact.Artifact.Name)
|
||||
}
|
||||
if got, want := strings.Join(names, ","), "user-1.img,user-2.img,user-10.img"; got != want {
|
||||
t.Fatalf("ordered restored artifacts = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRestoreSnapshotCleansStagingArtifactsAfterSuccess verifies that a
// successful RestoreSnapshot removes its per-restore staging directory
// (<snapshots>/<snapshot-id>/restores/<machine-id>) once the downloaded
// durable artifacts have been consumed. Artifact checksums are supplied so
// the download verification path is exercised too.
func TestRestoreSnapshotCleansStagingArtifactsAfterSuccess(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}

	// Occupy the guest-facing ports locally so readiness probes can connect.
	sshListener := listenTestPort(t, int(defaultSSHPort))
	defer func() { _ = sshListener.Close() }()
	vncListener := listenTestPort(t, int(defaultVNCPort))
	defer func() { _ = vncListener.Close() }()

	startedAt := time.Unix(1700000400, 0).UTC()
	runtime := &fakeRuntime{
		bootState: firecracker.MachineState{
			ID:          "restored-clean",
			Phase:       firecracker.PhaseRunning,
			PID:         7777,
			RuntimeHost: "127.0.0.1",
			SocketPath:  filepath.Join(cfg.RuntimeDir, "machines", "restored-clean", "root", "run", "firecracker.sock"),
			TapName:     "fctap-clean",
			StartedAt:   &startedAt,
		},
	}
	hostDaemon, err := New(cfg, fileStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	stubGuestSSHPublicKeyReader(hostDaemon)
	// Guest identity reconfiguration would reach into a real guest; stub it out.
	hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID) error { return nil }

	// Durable artifacts are fetched over HTTP from this stub server.
	server := newRestoreArtifactServer(t, map[string][]byte{
		"/kernel":  []byte("kernel"),
		"/rootfs":  []byte("rootfs"),
		"/memory":  []byte("mem"),
		"/vmstate": []byte("state"),
		"/system":  []byte("disk"),
	})
	defer server.Close()

	_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-clean", contracthost.RestoreSnapshotRequest{
		MachineID: "restored-clean",
		Artifact: contracthost.ArtifactRef{
			KernelImageURL: server.URL + "/kernel",
			RootFSURL:      server.URL + "/rootfs",
		},
		Snapshot: contracthost.DurableSnapshotSpec{
			SnapshotID:        "snap-clean",
			MachineID:         "source",
			ImageID:           "image-1",
			SourceRuntimeHost: "172.16.0.2",
			SourceTapDevice:   "fctap0",
			Artifacts: []contracthost.SnapshotArtifact{
				{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory", SHA256Hex: mustSHA256Hex(t, []byte("mem"))},
				{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/vmstate", SHA256Hex: mustSHA256Hex(t, []byte("state"))},
				{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system", SHA256Hex: mustSHA256Hex(t, []byte("disk"))},
			},
		},
	})
	if err != nil {
		t.Fatalf("RestoreSnapshot returned error: %v", err)
	}

	// The staging directory must be gone after a successful restore.
	stagingDir := filepath.Join(cfg.SnapshotsDir, "snap-clean", "restores", "restored-clean")
	if _, statErr := os.Stat(stagingDir); !os.IsNotExist(statErr) {
		t.Fatalf("restore staging dir should be cleaned up, stat err = %v", statErr)
	}
}
|
||||
|
||||
// TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure exercises the
// failure path of RestoreSnapshot: the vmstate artifact points at a URL the
// test server does not serve, so downloading the durable artifacts must fail,
// and the per-restore staging directory under SnapshotsDir must be removed.
func TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure(t *testing.T) {
	root := t.TempDir()
	cfg := testConfig(root)
	fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}

	runtime := &fakeRuntime{}
	hostDaemon, err := New(cfg, fileStore, runtime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	stubGuestSSHPublicKeyReader(hostDaemon)

	// Serve kernel/rootfs/memory only — "/vmstate" is deliberately absent so
	// the durable artifact download fails partway through.
	server := newRestoreArtifactServer(t, map[string][]byte{
		"/kernel": []byte("kernel"),
		"/rootfs": []byte("rootfs"),
		"/memory": []byte("mem"),
	})
	defer server.Close()

	_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-fail-clean", contracthost.RestoreSnapshotRequest{
		MachineID: "restored-fail-clean",
		Artifact: contracthost.ArtifactRef{
			KernelImageURL: server.URL + "/kernel",
			RootFSURL:      server.URL + "/rootfs",
		},
		Snapshot: contracthost.DurableSnapshotSpec{
			SnapshotID:        "snap-fail-clean",
			MachineID:         "source",
			ImageID:           "image-1",
			SourceRuntimeHost: "172.16.0.2",
			SourceTapDevice:   "fctap0",
			Artifacts: []contracthost.SnapshotArtifact{
				{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", DownloadURL: server.URL + "/memory", SHA256Hex: mustSHA256Hex(t, []byte("mem"))},
				// Points at a route the server does not register → download failure.
				{ID: "vmstate", Kind: contracthost.SnapshotArtifactKindVMState, Name: "vmstate.bin", DownloadURL: server.URL + "/missing"},
			},
		},
	})
	if err == nil || !strings.Contains(err.Error(), "download durable snapshot artifacts") {
		t.Fatalf("RestoreSnapshot error = %v, want durable artifact download failure", err)
	}

	// The staging dir must be removed again after the failed download.
	stagingDir := filepath.Join(cfg.SnapshotsDir, "snap-fail-clean", "restores", "restored-fail-clean")
	if _, statErr := os.Stat(stagingDir); !os.IsNotExist(statErr) {
		t.Fatalf("restore staging dir should be cleaned up after download failure, stat err = %v", statErr)
	}
}
|
||||
|
||||
func TestReconcileUsesReconciledMachineStateForPublishedPorts(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
|
|
@ -423,6 +869,26 @@ func waitPublishedPortResult(t *testing.T, ch <-chan publishedPortResult) publis
|
|||
}
|
||||
}
|
||||
|
||||
// exhaustedMachineRelayRecords builds one running machine record per port in
// the [minMachineSSHRelayPort, maxMachineSSHRelayPort] range so that every SSH
// relay port is already claimed — fixture for relay-port-exhaustion tests.
func exhaustedMachineRelayRecords() []model.MachineRecord {
	// Inclusive range, hence the +1.
	count := int(maxMachineSSHRelayPort-minMachineSSHRelayPort) + 1
	machines := make([]model.MachineRecord, 0, count)
	for i := 0; i < count; i++ {
		machines = append(machines, model.MachineRecord{
			ID:    contracthost.MachineID(fmt.Sprintf("relay-exhausted-%d", i)),
			Ports: buildMachinePorts(minMachineSSHRelayPort+uint16(i), minMachineVNCRelayPort+uint16(i)),
			Phase: contracthost.MachinePhaseRunning,
		})
	}
	return machines
}
|
||||
|
||||
// mustSHA256Hex returns the lowercase hex-encoded SHA-256 digest of payload.
// The *testing.T parameter keeps the call-site pattern consistent with other
// test must-helpers; the computation itself cannot fail.
func mustSHA256Hex(t *testing.T, payload []byte) string {
	t.Helper()

	digest := sha256.Sum256(payload)
	return hex.EncodeToString(digest[:])
}
|
||||
|
||||
func assertOperationCount(t *testing.T, store hoststore.Store, want int) {
|
||||
t.Helper()
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ import (
|
|||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -108,6 +110,22 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
|
|||
return nil, fmt.Errorf("copy system disk: %w", err)
|
||||
}
|
||||
diskPaths = append(diskPaths, systemDiskTarget)
|
||||
for i, volumeID := range record.UserVolumeIDs {
|
||||
volume, err := d.store.GetVolume(ctx, volumeID)
|
||||
if err != nil {
|
||||
_ = d.runtime.Resume(ctx, runtimeState)
|
||||
_ = os.RemoveAll(snapshotDir)
|
||||
return nil, fmt.Errorf("get attached volume %q: %w", volumeID, err)
|
||||
}
|
||||
driveID := fmt.Sprintf("user-%d", i)
|
||||
targetPath := filepath.Join(snapshotDir, driveID+".img")
|
||||
if err := cowCopyFile(volume.Path, targetPath); err != nil {
|
||||
_ = d.runtime.Resume(ctx, runtimeState)
|
||||
_ = os.RemoveAll(snapshotDir)
|
||||
return nil, fmt.Errorf("copy attached volume %q: %w", volumeID, err)
|
||||
}
|
||||
diskPaths = append(diskPaths, targetPath)
|
||||
}
|
||||
|
||||
// Resume the source VM
|
||||
if err := d.runtime.Resume(ctx, runtimeState); err != nil {
|
||||
|
|
@ -132,6 +150,12 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
|
|||
return nil, fmt.Errorf("move vmstate file: %w", err)
|
||||
}
|
||||
|
||||
artifacts, err := buildSnapshotArtifacts(dstMemPath, dstStatePath, diskPaths)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(snapshotDir)
|
||||
return nil, fmt.Errorf("build snapshot artifacts: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
snapshotRecord := model.SnapshotRecord{
|
||||
ID: snapshotID,
|
||||
|
|
@ -140,6 +164,7 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
|
|||
MemFilePath: dstMemPath,
|
||||
StateFilePath: dstStatePath,
|
||||
DiskPaths: diskPaths,
|
||||
Artifacts: artifacts,
|
||||
SourceRuntimeHost: record.RuntimeHost,
|
||||
SourceTapDevice: record.TapDevice,
|
||||
CreatedAt: now,
|
||||
|
|
@ -151,25 +176,60 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
|
|||
|
||||
clearOperation = true
|
||||
return &contracthost.CreateSnapshotResponse{
|
||||
Snapshot: snapshotToContract(snapshotRecord),
|
||||
Snapshot: snapshotToContract(snapshotRecord),
|
||||
Artifacts: snapshotArtifactsToContract(snapshotRecord.Artifacts),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UploadSnapshot pushes a snapshot's locally stored artifacts to the
// caller-supplied part upload URLs. For each requested artifact it resolves
// the on-disk record by artifact ID and uploads the listed byte ranges,
// collecting the per-part ETags needed to complete a multipart upload.
// It fails fast on the first unknown artifact or failed part.
func (d *Daemon) UploadSnapshot(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.UploadSnapshotRequest) (*contracthost.UploadSnapshotResponse, error) {
	snapshot, err := d.store.GetSnapshot(ctx, snapshotID)
	if err != nil {
		return nil, err
	}
	// Index stored artifacts by ID so each requested upload resolves in O(1).
	artifactIndex := make(map[string]model.SnapshotArtifactRecord, len(snapshot.Artifacts))
	for _, artifact := range snapshot.Artifacts {
		artifactIndex[artifact.ID] = artifact
	}

	response := &contracthost.UploadSnapshotResponse{
		Artifacts: make([]contracthost.UploadedSnapshotArtifact, 0, len(req.Artifacts)),
	}
	for _, upload := range req.Artifacts {
		artifact, ok := artifactIndex[upload.ArtifactID]
		if !ok {
			return nil, fmt.Errorf("snapshot %q artifact %q not found", snapshotID, upload.ArtifactID)
		}
		completedParts, err := uploadSnapshotArtifact(ctx, artifact.LocalPath, upload.Parts)
		if err != nil {
			return nil, fmt.Errorf("upload snapshot artifact %q: %w", upload.ArtifactID, err)
		}
		response.Artifacts = append(response.Artifacts, contracthost.UploadedSnapshotArtifact{
			ArtifactID:     upload.ArtifactID,
			CompletedParts: completedParts,
		})
	}

	return response, nil
}
|
||||
|
||||
func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.RestoreSnapshotRequest) (*contracthost.RestoreSnapshotResponse, error) {
|
||||
if err := validateMachineID(req.MachineID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.Snapshot.SnapshotID != "" && req.Snapshot.SnapshotID != snapshotID {
|
||||
return nil, fmt.Errorf("snapshot id mismatch: path=%q payload=%q", snapshotID, req.Snapshot.SnapshotID)
|
||||
}
|
||||
if err := validateArtifactRef(req.Artifact); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
unlock := d.lockMachine(req.MachineID)
|
||||
defer unlock()
|
||||
|
||||
snap, err := d.store.GetSnapshot(ctx, snapshotID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := d.store.GetMachine(ctx, req.MachineID); err == nil {
|
||||
return nil, fmt.Errorf("machine %q already exists", req.MachineID)
|
||||
} else if err != nil && err != store.ErrNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := d.store.UpsertOperation(ctx, model.OperationRecord{
|
||||
|
|
@ -188,55 +248,90 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
|||
}
|
||||
}()
|
||||
|
||||
sourceMachine, err := d.store.GetMachine(ctx, snap.MachineID)
|
||||
switch {
|
||||
case err == nil && sourceMachine.Phase == contracthost.MachinePhaseRunning:
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("restore from snapshot %q while source machine %q is running is not supported yet", snapshotID, snap.MachineID)
|
||||
case err != nil && err != store.ErrNotFound:
|
||||
return nil, fmt.Errorf("get source machine for restore: %w", err)
|
||||
}
|
||||
|
||||
usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
restoreNetwork, err := restoreNetworkFromSnapshot(snap)
|
||||
restoreNetwork, err := d.resolveRestoreNetwork(ctx, snapshotID, req.Snapshot)
|
||||
if err != nil {
|
||||
clearOperation = true
|
||||
return nil, err
|
||||
}
|
||||
if networkAllocationInUse(restoreNetwork, usedNetworks) {
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("snapshot %q restore network %q (%s) is already in use", snapshotID, restoreNetwork.TapName, restoreNetwork.GuestIP())
|
||||
return nil, fmt.Errorf("restore network for snapshot %q is still in use on this host (runtime_host=%s tap_device=%s)", snapshotID, restoreNetwork.GuestIP(), restoreNetwork.TapName)
|
||||
}
|
||||
|
||||
artifact, err := d.store.GetArtifact(ctx, snap.Artifact)
|
||||
artifact, err := d.ensureArtifact(ctx, req.Artifact)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get artifact for restore: %w", err)
|
||||
return nil, fmt.Errorf("ensure artifact for restore: %w", err)
|
||||
}
|
||||
|
||||
stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID))
|
||||
restoredArtifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(stagingDir)
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("download durable snapshot artifacts: %w", err)
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(stagingDir) }()
|
||||
|
||||
// COW-copy system disk from snapshot to new machine's disk dir.
|
||||
newSystemDiskPath := d.systemVolumePath(req.MachineID)
|
||||
if err := os.MkdirAll(filepath.Dir(newSystemDiskPath), 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create machine disk dir: %w", err)
|
||||
}
|
||||
if len(snap.DiskPaths) < 1 {
|
||||
systemDiskPath, ok := restoredArtifacts["system.img"]
|
||||
if !ok {
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("snapshot %q has no disk paths", snapshotID)
|
||||
return nil, fmt.Errorf("snapshot %q is missing system disk artifact", snapshotID)
|
||||
}
|
||||
if err := cowCopyFile(snap.DiskPaths[0], newSystemDiskPath); err != nil {
|
||||
memoryArtifact, ok := restoredArtifacts["memory.bin"]
|
||||
if !ok {
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("snapshot %q is missing memory artifact", snapshotID)
|
||||
}
|
||||
vmstateArtifact, ok := restoredArtifacts["vmstate.bin"]
|
||||
if !ok {
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("snapshot %q is missing vmstate artifact", snapshotID)
|
||||
}
|
||||
if err := cowCopyFile(systemDiskPath.LocalPath, newSystemDiskPath); err != nil {
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("copy system disk for restore: %w", err)
|
||||
}
|
||||
|
||||
type restoredUserVolume struct {
|
||||
ID contracthost.VolumeID
|
||||
Path string
|
||||
DriveID string
|
||||
}
|
||||
restoredUserVolumes := make([]restoredUserVolume, 0)
|
||||
restoredDrivePaths := make(map[string]string)
|
||||
for _, restored := range orderedRestoredUserDiskArtifacts(restoredArtifacts) {
|
||||
name := restored.Artifact.Name
|
||||
driveID := strings.TrimSuffix(name, filepath.Ext(name))
|
||||
volumeID := contracthost.VolumeID(fmt.Sprintf("%s-%s", req.MachineID, driveID))
|
||||
volumePath := filepath.Join(d.config.MachineDisksDir, string(req.MachineID), name)
|
||||
if err := cowCopyFile(restored.LocalPath, volumePath); err != nil {
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("copy restored drive %q: %w", driveID, err)
|
||||
}
|
||||
restoredUserVolumes = append(restoredUserVolumes, restoredUserVolume{
|
||||
ID: volumeID,
|
||||
Path: volumePath,
|
||||
DriveID: driveID,
|
||||
})
|
||||
restoredDrivePaths[driveID] = volumePath
|
||||
}
|
||||
|
||||
loadSpec := firecracker.SnapshotLoadSpec{
|
||||
ID: firecracker.MachineID(req.MachineID),
|
||||
SnapshotPath: snap.StateFilePath,
|
||||
MemFilePath: snap.MemFilePath,
|
||||
SnapshotPath: vmstateArtifact.LocalPath,
|
||||
MemFilePath: memoryArtifact.LocalPath,
|
||||
RootFSPath: newSystemDiskPath,
|
||||
KernelImagePath: artifact.KernelImagePath,
|
||||
DiskPaths: map[string]string{},
|
||||
DiskPaths: restoredDrivePaths,
|
||||
Network: &restoreNetwork,
|
||||
}
|
||||
|
||||
|
|
@ -275,18 +370,44 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
|||
ID: systemVolumeID,
|
||||
Kind: contracthost.VolumeKindSystem,
|
||||
AttachedMachineID: machineIDPtr(req.MachineID),
|
||||
SourceArtifact: &snap.Artifact,
|
||||
SourceArtifact: &req.Artifact,
|
||||
Pool: model.StoragePoolMachineDisks,
|
||||
Path: newSystemDiskPath,
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
_ = d.runtime.Delete(ctx, *machineState)
|
||||
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("create system volume record for restore: %w", err)
|
||||
}
|
||||
restoredUserVolumeIDs := make([]contracthost.VolumeID, 0, len(restoredUserVolumes))
|
||||
for _, volume := range restoredUserVolumes {
|
||||
if err := d.store.CreateVolume(ctx, model.VolumeRecord{
|
||||
ID: volume.ID,
|
||||
Kind: contracthost.VolumeKindUser,
|
||||
AttachedMachineID: machineIDPtr(req.MachineID),
|
||||
SourceArtifact: &req.Artifact,
|
||||
Pool: model.StoragePoolMachineDisks,
|
||||
Path: volume.Path,
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
for _, restoredVolumeID := range restoredUserVolumeIDs {
|
||||
_ = d.store.DeleteVolume(context.Background(), restoredVolumeID)
|
||||
}
|
||||
_ = d.store.DeleteVolume(context.Background(), systemVolumeID)
|
||||
_ = d.runtime.Delete(ctx, *machineState)
|
||||
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("create restored user volume record %q: %w", volume.ID, err)
|
||||
}
|
||||
restoredUserVolumeIDs = append(restoredUserVolumeIDs, volume.ID)
|
||||
}
|
||||
|
||||
machineRecord := model.MachineRecord{
|
||||
ID: req.MachineID,
|
||||
Artifact: snap.Artifact,
|
||||
Artifact: req.Artifact,
|
||||
SystemVolumeID: systemVolumeID,
|
||||
UserVolumeIDs: restoredUserVolumeIDs,
|
||||
RuntimeHost: machineState.RuntimeHost,
|
||||
TapDevice: machineState.TapName,
|
||||
Ports: defaultMachinePorts(),
|
||||
|
|
@ -306,7 +427,14 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
|||
d.relayAllocMu.Unlock()
|
||||
if err != nil {
|
||||
d.stopMachineRelays(machineRecord.ID)
|
||||
return nil, err
|
||||
for _, restoredVolumeID := range restoredUserVolumeIDs {
|
||||
_ = d.store.DeleteVolume(context.Background(), restoredVolumeID)
|
||||
}
|
||||
_ = d.store.DeleteVolume(context.Background(), systemVolumeID)
|
||||
_ = d.runtime.Delete(ctx, *machineState)
|
||||
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("allocate relay ports for restored machine: %w", err)
|
||||
}
|
||||
machineRecord.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
|
||||
startedRelays := true
|
||||
|
|
@ -316,6 +444,13 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
|||
}
|
||||
}()
|
||||
if err := d.store.CreateMachine(ctx, machineRecord); err != nil {
|
||||
for _, restoredVolumeID := range restoredUserVolumeIDs {
|
||||
_ = d.store.DeleteVolume(context.Background(), restoredVolumeID)
|
||||
}
|
||||
_ = d.store.DeleteVolume(context.Background(), systemVolumeID)
|
||||
_ = d.runtime.Delete(ctx, *machineState)
|
||||
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
|
||||
clearOperation = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -360,12 +495,89 @@ func (d *Daemon) DeleteSnapshotByID(ctx context.Context, snapshotID contracthost
|
|||
|
||||
func snapshotToContract(record model.SnapshotRecord) contracthost.Snapshot {
|
||||
return contracthost.Snapshot{
|
||||
ID: record.ID,
|
||||
MachineID: record.MachineID,
|
||||
CreatedAt: record.CreatedAt,
|
||||
ID: record.ID,
|
||||
MachineID: record.MachineID,
|
||||
SourceRuntimeHost: record.SourceRuntimeHost,
|
||||
SourceTapDevice: record.SourceTapDevice,
|
||||
CreatedAt: record.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func snapshotArtifactsToContract(artifacts []model.SnapshotArtifactRecord) []contracthost.SnapshotArtifact {
|
||||
converted := make([]contracthost.SnapshotArtifact, 0, len(artifacts))
|
||||
for _, artifact := range artifacts {
|
||||
converted = append(converted, contracthost.SnapshotArtifact{
|
||||
ID: artifact.ID,
|
||||
Kind: artifact.Kind,
|
||||
Name: artifact.Name,
|
||||
SizeBytes: artifact.SizeBytes,
|
||||
SHA256Hex: artifact.SHA256Hex,
|
||||
})
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func orderedRestoredUserDiskArtifacts(artifacts map[string]restoredSnapshotArtifact) []restoredSnapshotArtifact {
|
||||
ordered := make([]restoredSnapshotArtifact, 0, len(artifacts))
|
||||
for name, artifact := range artifacts {
|
||||
if !strings.HasPrefix(name, "user-") || filepath.Ext(name) != ".img" {
|
||||
continue
|
||||
}
|
||||
ordered = append(ordered, artifact)
|
||||
}
|
||||
sort.Slice(ordered, func(i, j int) bool {
|
||||
iIdx, iOK := restoredUserDiskIndex(ordered[i].Artifact.Name)
|
||||
jIdx, jOK := restoredUserDiskIndex(ordered[j].Artifact.Name)
|
||||
switch {
|
||||
case iOK && jOK && iIdx != jIdx:
|
||||
return iIdx < jIdx
|
||||
case iOK != jOK:
|
||||
return iOK
|
||||
default:
|
||||
return ordered[i].Artifact.Name < ordered[j].Artifact.Name
|
||||
}
|
||||
})
|
||||
return ordered
|
||||
}
|
||||
|
||||
// restoredUserDiskIndex parses the numeric ordinal out of a restored user-disk
// artifact name of the form "user-<n>.img". The boolean reports whether the
// name matched that form.
func restoredUserDiskIndex(name string) (int, bool) {
	if filepath.Ext(name) != ".img" || !strings.HasPrefix(name, "user-") {
		return 0, false
	}
	trimmed := strings.TrimPrefix(name, "user-")
	trimmed = strings.TrimSuffix(trimmed, ".img")
	idx, err := strconv.Atoi(trimmed)
	if err != nil {
		return 0, false
	}
	return idx, true
}
|
||||
|
||||
// resolveRestoreNetwork picks the network allocation to reuse when restoring
// snapshotID. The durable spec's source metadata takes precedence; when the
// spec cannot be turned into an allocation, the locally stored snapshot record
// is consulted instead. Only when the spec is unusable AND no local record
// exists does it fail with a "missing restore network metadata" error.
func (d *Daemon) resolveRestoreNetwork(ctx context.Context, snapshotID contracthost.SnapshotID, spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) {
	// NOTE(review): the error from the durable-spec path is discarded, so a
	// spec with a malformed host IP silently falls through to the local
	// record as well — confirm that is intended and not just for missing
	// metadata.
	if network, err := restoreNetworkFromDurableSpec(spec); err == nil {
		return network, nil
	}

	snapshot, err := d.store.GetSnapshot(ctx, snapshotID)
	if err == nil {
		return restoreNetworkFromSnapshot(snapshot)
	}
	if err != store.ErrNotFound {
		// Store failures other than "not found" are real errors.
		return firecracker.NetworkAllocation{}, err
	}
	return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot %q is missing restore network metadata", snapshotID)
}
|
||||
|
||||
func restoreNetworkFromDurableSpec(spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) {
|
||||
if strings.TrimSpace(spec.SourceRuntimeHost) == "" || strings.TrimSpace(spec.SourceTapDevice) == "" {
|
||||
return firecracker.NetworkAllocation{}, fmt.Errorf("durable snapshot spec is missing restore network metadata")
|
||||
}
|
||||
network, err := firecracker.AllocationFromGuestIP(spec.SourceRuntimeHost, spec.SourceTapDevice)
|
||||
if err != nil {
|
||||
return firecracker.NetworkAllocation{}, fmt.Errorf("reconstruct durable snapshot %q network: %w", spec.SnapshotID, err)
|
||||
}
|
||||
return network, nil
|
||||
}
|
||||
|
||||
func restoreNetworkFromSnapshot(snap *model.SnapshotRecord) (firecracker.NetworkAllocation, error) {
|
||||
if snap == nil {
|
||||
return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot is required")
|
||||
|
|
|
|||
194
internal/daemon/snapshot_transfer.go
Normal file
194
internal/daemon/snapshot_transfer.go
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||
)
|
||||
|
||||
// restoredSnapshotArtifact pairs a durable snapshot artifact from the restore
// spec with the host-local path it was downloaded to in the staging dir.
type restoredSnapshotArtifact struct {
	Artifact  contracthost.SnapshotArtifact // spec entry (ID, kind, name, checksum, download URL)
	LocalPath string                        // path of the downloaded copy on this host
}
|
||||
|
||||
func buildSnapshotArtifacts(memoryPath, vmstatePath string, diskPaths []string) ([]model.SnapshotArtifactRecord, error) {
|
||||
artifacts := make([]model.SnapshotArtifactRecord, 0, len(diskPaths)+2)
|
||||
|
||||
memoryArtifact, err := snapshotArtifactRecord("memory", contracthost.SnapshotArtifactKindMemory, filepath.Base(memoryPath), memoryPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
artifacts = append(artifacts, memoryArtifact)
|
||||
|
||||
vmstateArtifact, err := snapshotArtifactRecord("vmstate", contracthost.SnapshotArtifactKindVMState, filepath.Base(vmstatePath), vmstatePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
artifacts = append(artifacts, vmstateArtifact)
|
||||
|
||||
for _, diskPath := range diskPaths {
|
||||
base := filepath.Base(diskPath)
|
||||
diskArtifact, err := snapshotArtifactRecord("disk-"+strings.TrimSuffix(base, filepath.Ext(base)), contracthost.SnapshotArtifactKindDisk, base, diskPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
artifacts = append(artifacts, diskArtifact)
|
||||
}
|
||||
|
||||
sort.Slice(artifacts, func(i, j int) bool {
|
||||
return artifacts[i].ID < artifacts[j].ID
|
||||
})
|
||||
return artifacts, nil
|
||||
}
|
||||
|
||||
// snapshotArtifactRecord builds the stored metadata for one snapshot file:
// id, kind and name are supplied by the caller, while the size and SHA-256
// digest are computed from the file at path.
func snapshotArtifactRecord(id string, kind contracthost.SnapshotArtifactKind, name, path string) (model.SnapshotArtifactRecord, error) {
	size, err := fileSize(path)
	if err != nil {
		return model.SnapshotArtifactRecord{}, err
	}
	sum, err := sha256File(path)
	if err != nil {
		return model.SnapshotArtifactRecord{}, err
	}
	return model.SnapshotArtifactRecord{
		ID:        id,
		Kind:      kind,
		Name:      name,
		LocalPath: path,
		SizeBytes: size,
		SHA256Hex: sum,
	}, nil
}
|
||||
|
||||
// sha256File streams the file at path through SHA-256 and returns the digest
// as a lowercase hex string.
func sha256File(path string) (string, error) {
	f, openErr := os.Open(path)
	if openErr != nil {
		return "", fmt.Errorf("open %q for sha256: %w", path, openErr)
	}
	defer func() { _ = f.Close() }()

	digest := sha256.New()
	if _, copyErr := io.Copy(digest, f); copyErr != nil {
		return "", fmt.Errorf("hash %q: %w", path, copyErr)
	}
	return hex.EncodeToString(digest.Sum(nil)), nil
}
|
||||
|
||||
// uploadSnapshotArtifact uploads the file at localPath in the byte ranges
// described by parts, issuing one HTTP PUT per part against that part's
// UploadURL. It returns the (PartNumber, ETag) pairs needed to complete a
// multipart upload, sorted by part number. Any failing part — transport
// error, non-2xx status, or a missing ETag header — aborts the whole upload.
func uploadSnapshotArtifact(ctx context.Context, localPath string, parts []contracthost.SnapshotUploadPart) ([]contracthost.UploadedSnapshotPart, error) {
	if len(parts) == 0 {
		return nil, fmt.Errorf("upload session has no parts")
	}

	file, err := os.Open(localPath)
	if err != nil {
		return nil, fmt.Errorf("open artifact %q: %w", localPath, err)
	}
	defer func() { _ = file.Close() }()

	client := &http.Client{}
	completed := make([]contracthost.UploadedSnapshotPart, 0, len(parts))
	for _, part := range parts {
		// Each part reads its own [offset, offset+size) window of the file.
		reader := io.NewSectionReader(file, part.OffsetBytes, part.SizeBytes)
		req, err := http.NewRequestWithContext(ctx, http.MethodPut, part.UploadURL, io.NopCloser(reader))
		if err != nil {
			return nil, fmt.Errorf("build upload part %d: %w", part.PartNumber, err)
		}
		// The transport cannot infer a length from a SectionReader, so set it
		// explicitly; object stores typically reject chunked part uploads.
		req.ContentLength = part.SizeBytes

		resp, err := client.Do(req)
		if err != nil {
			return nil, fmt.Errorf("upload part %d: %w", part.PartNumber, err)
		}
		// Only the status code and ETag header matter; the body is discarded.
		_ = resp.Body.Close()
		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			return nil, fmt.Errorf("upload part %d returned %d", part.PartNumber, resp.StatusCode)
		}
		// Multipart completion needs every part's ETag, so an empty header
		// makes the part unusable.
		etag := strings.TrimSpace(resp.Header.Get("ETag"))
		if etag == "" {
			return nil, fmt.Errorf("upload part %d returned empty etag", part.PartNumber)
		}
		completed = append(completed, contracthost.UploadedSnapshotPart{
			PartNumber: part.PartNumber,
			ETag:       etag,
		})
	}
	sort.Slice(completed, func(i, j int) bool {
		return completed[i].PartNumber < completed[j].PartNumber
	})
	return completed, nil
}
|
||||
|
||||
func downloadDurableSnapshotArtifacts(ctx context.Context, root string, artifacts []contracthost.SnapshotArtifact) (map[string]restoredSnapshotArtifact, error) {
|
||||
if len(artifacts) == 0 {
|
||||
return nil, fmt.Errorf("restore snapshot is missing artifacts")
|
||||
}
|
||||
if err := os.MkdirAll(root, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create restore staging dir %q: %w", root, err)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
restored := make(map[string]restoredSnapshotArtifact, len(artifacts))
|
||||
for _, artifact := range artifacts {
|
||||
if strings.TrimSpace(artifact.DownloadURL) == "" {
|
||||
return nil, fmt.Errorf("artifact %q is missing download url", artifact.ID)
|
||||
}
|
||||
localPath := filepath.Join(root, artifact.Name)
|
||||
if err := downloadSnapshotArtifact(ctx, client, artifact.DownloadURL, localPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if expectedSHA := strings.TrimSpace(artifact.SHA256Hex); expectedSHA != "" {
|
||||
actualSHA, err := sha256File(localPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !strings.EqualFold(actualSHA, expectedSHA) {
|
||||
_ = os.Remove(localPath)
|
||||
return nil, fmt.Errorf("restore artifact %q sha256 mismatch: got %s want %s", artifact.Name, actualSHA, expectedSHA)
|
||||
}
|
||||
}
|
||||
restored[artifact.Name] = restoredSnapshotArtifact{
|
||||
Artifact: artifact,
|
||||
LocalPath: localPath,
|
||||
}
|
||||
}
|
||||
return restored, nil
|
||||
}
|
||||
|
||||
// downloadSnapshotArtifact streams the artifact at sourceURL into targetPath,
// creating parent directories as needed. Any failure after the target file
// has been created removes the partial download, so callers never observe a
// truncated artifact at targetPath; close errors are surfaced because they
// can indicate lost buffered writes.
func downloadSnapshotArtifact(ctx context.Context, client *http.Client, sourceURL, targetPath string) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
	if err != nil {
		return fmt.Errorf("build restore download request: %w", err)
	}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("download durable snapshot artifact: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("download durable snapshot artifact returned %d", resp.StatusCode)
	}
	if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
		return fmt.Errorf("create restore artifact dir %q: %w", filepath.Dir(targetPath), err)
	}
	out, err := os.Create(targetPath)
	if err != nil {
		return fmt.Errorf("create restore artifact %q: %w", targetPath, err)
	}

	if _, err := io.Copy(out, resp.Body); err != nil {
		_ = out.Close()
		_ = os.Remove(targetPath) // drop the partial download
		return fmt.Errorf("write restore artifact %q: %w", targetPath, err)
	}
	if err := out.Close(); err != nil {
		_ = os.Remove(targetPath) // close failure may mean lost writes
		return fmt.Errorf("close restore artifact %q: %w", targetPath, err)
	}
	return nil
}
|
||||
60
internal/daemon/snapshot_transfer_test.go
Normal file
60
internal/daemon/snapshot_transfer_test.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||
)
|
||||
|
||||
func TestUploadSnapshotArtifactRejectsEmptyETag(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPut {
|
||||
t.Fatalf("unexpected method %q", r.Method)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
artifactPath := filepath.Join(t.TempDir(), "artifact.bin")
|
||||
if err := os.WriteFile(artifactPath, []byte("payload"), 0o644); err != nil {
|
||||
t.Fatalf("write artifact: %v", err)
|
||||
}
|
||||
|
||||
_, err := uploadSnapshotArtifact(context.Background(), artifactPath, []contracthost.SnapshotUploadPart{{
|
||||
PartNumber: 1,
|
||||
OffsetBytes: 0,
|
||||
SizeBytes: int64(len("payload")),
|
||||
UploadURL: server.URL,
|
||||
}})
|
||||
if err == nil || !strings.Contains(err.Error(), "empty etag") {
|
||||
t.Fatalf("uploadSnapshotArtifact error = %v, want empty etag failure", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDownloadDurableSnapshotArtifactsRejectsSHA256Mismatch serves a payload
// whose digest cannot match the all-zero SHA-256 declared in the spec, and
// asserts both that the mismatch is reported and that the corrupt file is
// deleted from the staging dir.
func TestDownloadDurableSnapshotArtifactsRejectsSHA256Mismatch(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte("payload"))
	}))
	defer server.Close()

	root := t.TempDir()
	_, err := downloadDurableSnapshotArtifacts(context.Background(), root, []contracthost.SnapshotArtifact{{
		ID:          "memory",
		Kind:        contracthost.SnapshotArtifactKindMemory,
		Name:        "memory.bin",
		DownloadURL: server.URL,
		// 64 zeros: a syntactically valid digest that "payload" cannot hash to.
		SHA256Hex: strings.Repeat("0", 64),
	}})
	if err == nil || !strings.Contains(err.Error(), "sha256 mismatch") {
		t.Fatalf("downloadDurableSnapshotArtifacts error = %v, want sha256 mismatch", err)
	}
	if _, statErr := os.Stat(filepath.Join(root, "memory.bin")); !os.IsNotExist(statErr) {
		t.Fatalf("corrupt artifact should be removed, stat err = %v", statErr)
	}
}
|
||||
|
|
@ -52,19 +52,27 @@ func (d *Daemon) GetStorageReport(ctx context.Context) (*contracthost.GetStorage
|
|||
}
|
||||
}
|
||||
|
||||
machineUsage := make([]contracthost.MachineStorageUsage, 0, len(volumes))
|
||||
machineUsageByID := make(map[contracthost.MachineID]contracthost.MachineStorageUsage)
|
||||
for _, volume := range volumes {
|
||||
if volume.AttachedMachineID == nil || volume.Kind != contracthost.VolumeKindSystem {
|
||||
if volume.AttachedMachineID == nil {
|
||||
continue
|
||||
}
|
||||
bytes, err := fileSize(volume.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
machineUsage = append(machineUsage, contracthost.MachineStorageUsage{
|
||||
MachineID: *volume.AttachedMachineID,
|
||||
SystemBytes: bytes,
|
||||
})
|
||||
usage := machineUsageByID[*volume.AttachedMachineID]
|
||||
usage.MachineID = *volume.AttachedMachineID
|
||||
if volume.Kind == contracthost.VolumeKindSystem {
|
||||
usage.SystemBytes += bytes
|
||||
} else {
|
||||
usage.UserBytes += bytes
|
||||
}
|
||||
machineUsageByID[*volume.AttachedMachineID] = usage
|
||||
}
|
||||
machineUsage := make([]contracthost.MachineStorageUsage, 0, len(machineUsageByID))
|
||||
for _, usage := range machineUsageByID {
|
||||
machineUsage = append(machineUsage, usage)
|
||||
}
|
||||
|
||||
snapshotUsage := make([]contracthost.SnapshotStorageUsage, 0, len(snapshots))
|
||||
|
|
|
|||
|
|
@ -27,10 +27,12 @@ type bootSourceRequest struct {
|
|||
}
|
||||
|
||||
type driveRequest struct {
|
||||
DriveID string `json:"drive_id"`
|
||||
IsReadOnly bool `json:"is_read_only"`
|
||||
IsRootDevice bool `json:"is_root_device"`
|
||||
PathOnHost string `json:"path_on_host"`
|
||||
DriveID string `json:"drive_id"`
|
||||
IsReadOnly bool `json:"is_read_only"`
|
||||
IsRootDevice bool `json:"is_root_device"`
|
||||
PathOnHost string `json:"path_on_host"`
|
||||
CacheType DriveCacheType `json:"cache_type,omitempty"`
|
||||
IOEngine DriveIOEngine `json:"io_engine,omitempty"`
|
||||
}
|
||||
|
||||
type entropyRequest struct{}
|
||||
|
|
@ -58,6 +60,13 @@ type networkInterfaceRequest struct {
|
|||
IfaceID string `json:"iface_id"`
|
||||
}
|
||||
|
||||
type mmdsConfigRequest struct {
|
||||
IPv4Address string `json:"ipv4_address,omitempty"`
|
||||
NetworkInterfaces []string `json:"network_interfaces"`
|
||||
Version MMDSVersion `json:"version,omitempty"`
|
||||
IMDSCompat bool `json:"imds_compat,omitempty"`
|
||||
}
|
||||
|
||||
type serialRequest struct {
|
||||
SerialOutPath string `json:"serial_out_path"`
|
||||
}
|
||||
|
|
@ -127,6 +136,24 @@ func (c *apiClient) PutNetworkInterface(ctx context.Context, network NetworkAllo
|
|||
return c.do(ctx, http.MethodPut, endpoint, body, nil, http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (c *apiClient) PutMMDSConfig(ctx context.Context, spec MMDSSpec) error {
|
||||
body := mmdsConfigRequest{
|
||||
IPv4Address: strings.TrimSpace(spec.IPv4Address),
|
||||
NetworkInterfaces: append([]string(nil), spec.NetworkInterfaces...),
|
||||
Version: spec.Version,
|
||||
IMDSCompat: spec.IMDSCompat,
|
||||
}
|
||||
return c.do(ctx, http.MethodPut, "/mmds/config", body, nil, http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (c *apiClient) PutMMDS(ctx context.Context, data any) error {
|
||||
return c.do(ctx, http.MethodPut, "/mmds", data, nil, http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (c *apiClient) PatchMMDS(ctx context.Context, data any) error {
|
||||
return c.do(ctx, http.MethodPatch, "/mmds", data, nil, http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (c *apiClient) PutSerial(ctx context.Context, serialOutPath string) error {
|
||||
return c.do(
|
||||
ctx,
|
||||
|
|
|
|||
|
|
@ -83,6 +83,91 @@ func TestConfigureMachineEnablesEntropyAndSerialBeforeStart(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConfigureMachineConfiguresMMDSBeforeStart(t *testing.T) {
|
||||
var requests []capturedRequest
|
||||
|
||||
socketPath, shutdown := startUnixSocketServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body: %v", err)
|
||||
}
|
||||
requests = append(requests, capturedRequest{
|
||||
Method: r.Method,
|
||||
Path: r.URL.Path,
|
||||
Body: string(body),
|
||||
})
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
defer shutdown()
|
||||
|
||||
client := newAPIClient(socketPath)
|
||||
spec := MachineSpec{
|
||||
ID: "vm-2",
|
||||
VCPUs: 1,
|
||||
MemoryMiB: 512,
|
||||
KernelImagePath: "/kernel",
|
||||
RootFSPath: "/rootfs",
|
||||
RootDrive: DriveSpec{
|
||||
ID: "root_drive",
|
||||
Path: "/rootfs",
|
||||
CacheType: DriveCacheTypeUnsafe,
|
||||
IOEngine: DriveIOEngineSync,
|
||||
},
|
||||
MMDS: &MMDSSpec{
|
||||
NetworkInterfaces: []string{"net0"},
|
||||
Version: MMDSVersionV2,
|
||||
IPv4Address: "169.254.169.254",
|
||||
Data: map[string]any{
|
||||
"latest": map[string]any{
|
||||
"meta-data": map[string]any{
|
||||
"microagent": map[string]any{"hostname": "vm-2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths := machinePaths{JailedSerialLogPath: "/logs/serial.log"}
|
||||
network := NetworkAllocation{
|
||||
InterfaceID: defaultInterfaceID,
|
||||
TapName: "fctap0",
|
||||
GuestMAC: "06:00:ac:10:00:02",
|
||||
}
|
||||
|
||||
if err := configureMachine(context.Background(), client, paths, spec, network); err != nil {
|
||||
t.Fatalf("configure machine: %v", err)
|
||||
}
|
||||
|
||||
gotPaths := make([]string, 0, len(requests))
|
||||
for _, request := range requests {
|
||||
gotPaths = append(gotPaths, request.Path)
|
||||
}
|
||||
wantPaths := []string{
|
||||
"/machine-config",
|
||||
"/boot-source",
|
||||
"/drives/root_drive",
|
||||
"/network-interfaces/net0",
|
||||
"/mmds/config",
|
||||
"/mmds",
|
||||
"/entropy",
|
||||
"/serial",
|
||||
"/actions",
|
||||
}
|
||||
if len(gotPaths) != len(wantPaths) {
|
||||
t.Fatalf("request count mismatch: got %d want %d (%v)", len(gotPaths), len(wantPaths), gotPaths)
|
||||
}
|
||||
for i := range wantPaths {
|
||||
if gotPaths[i] != wantPaths[i] {
|
||||
t.Fatalf("request %d mismatch: got %q want %q", i, gotPaths[i], wantPaths[i])
|
||||
}
|
||||
}
|
||||
if requests[2].Body != "{\"drive_id\":\"root_drive\",\"is_read_only\":false,\"is_root_device\":true,\"path_on_host\":\"/rootfs\",\"cache_type\":\"Unsafe\",\"io_engine\":\"Sync\"}" {
|
||||
t.Fatalf("root drive body mismatch: got %q", requests[2].Body)
|
||||
}
|
||||
if requests[4].Body != "{\"ipv4_address\":\"169.254.169.254\",\"network_interfaces\":[\"net0\"],\"version\":\"V2\"}" {
|
||||
t.Fatalf("mmds config body mismatch: got %q", requests[4].Body)
|
||||
}
|
||||
}
|
||||
|
||||
func startUnixSocketServer(t *testing.T, handler http.HandlerFunc) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,16 @@ func configureMachine(ctx context.Context, client *apiClient, paths machinePaths
|
|||
if err := client.PutNetworkInterface(ctx, network); err != nil {
|
||||
return fmt.Errorf("put network interface: %w", err)
|
||||
}
|
||||
if spec.MMDS != nil {
|
||||
if err := client.PutMMDSConfig(ctx, *spec.MMDS); err != nil {
|
||||
return fmt.Errorf("put mmds config: %w", err)
|
||||
}
|
||||
if spec.MMDS.Data != nil {
|
||||
if err := client.PutMMDS(ctx, spec.MMDS.Data); err != nil {
|
||||
return fmt.Errorf("put mmds payload: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := client.PutEntropy(ctx); err != nil {
|
||||
return fmt.Errorf("put entropy device: %w", err)
|
||||
}
|
||||
|
|
@ -97,12 +107,14 @@ func stageMachineFiles(spec MachineSpec, paths machinePaths) (MachineSpec, error
|
|||
|
||||
rootFSPath, err := stagedFileName(spec.RootFSPath)
|
||||
if err != nil {
|
||||
return MachineSpec{}, fmt.Errorf("rootfs path: %w", err)
|
||||
return MachineSpec{}, fmt.Errorf("root drive path: %w", err)
|
||||
}
|
||||
if err := linkMachineFile(spec.RootFSPath, filepath.Join(paths.ChrootRootDir, rootFSPath)); err != nil {
|
||||
return MachineSpec{}, fmt.Errorf("link rootfs into jail: %w", err)
|
||||
return MachineSpec{}, fmt.Errorf("link root drive into jail: %w", err)
|
||||
}
|
||||
staged.RootFSPath = rootFSPath
|
||||
staged.RootDrive = spec.rootDrive()
|
||||
staged.RootDrive.Path = rootFSPath
|
||||
|
||||
staged.Drives = make([]DriveSpec, len(spec.Drives))
|
||||
for i, drive := range spec.Drives {
|
||||
|
|
@ -174,6 +186,8 @@ func additionalDriveRequests(spec MachineSpec) []driveRequest {
|
|||
IsReadOnly: drive.ReadOnly,
|
||||
IsRootDevice: false,
|
||||
PathOnHost: drive.Path,
|
||||
CacheType: drive.CacheType,
|
||||
IOEngine: drive.IOEngine,
|
||||
})
|
||||
}
|
||||
return requests
|
||||
|
|
@ -249,11 +263,14 @@ func linkMachineFile(source string, target string) error {
|
|||
}
|
||||
|
||||
func rootDriveRequest(spec MachineSpec) driveRequest {
|
||||
root := spec.rootDrive()
|
||||
return driveRequest{
|
||||
DriveID: defaultRootDriveID,
|
||||
IsReadOnly: false,
|
||||
DriveID: root.ID,
|
||||
IsReadOnly: root.ReadOnly,
|
||||
IsRootDevice: true,
|
||||
PathOnHost: spec.RootFSPath,
|
||||
PathOnHost: root.Path,
|
||||
CacheType: root.CacheType,
|
||||
IOEngine: root.IOEngine,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,16 +16,50 @@ type MachineSpec struct {
|
|||
MemoryMiB int64
|
||||
KernelImagePath string
|
||||
RootFSPath string
|
||||
RootDrive DriveSpec
|
||||
KernelArgs string
|
||||
Drives []DriveSpec
|
||||
MMDS *MMDSSpec
|
||||
Vsock *VsockSpec
|
||||
}
|
||||
|
||||
// DriveSpec describes an additional guest block device.
|
||||
type DriveSpec struct {
|
||||
ID string
|
||||
Path string
|
||||
ReadOnly bool
|
||||
ID string
|
||||
Path string
|
||||
ReadOnly bool
|
||||
CacheType DriveCacheType
|
||||
IOEngine DriveIOEngine
|
||||
}
|
||||
|
||||
type DriveCacheType string
|
||||
|
||||
const (
|
||||
DriveCacheTypeUnsafe DriveCacheType = "Unsafe"
|
||||
DriveCacheTypeWriteback DriveCacheType = "Writeback"
|
||||
)
|
||||
|
||||
type DriveIOEngine string
|
||||
|
||||
const (
|
||||
DriveIOEngineSync DriveIOEngine = "Sync"
|
||||
DriveIOEngineAsync DriveIOEngine = "Async"
|
||||
)
|
||||
|
||||
type MMDSVersion string
|
||||
|
||||
const (
|
||||
MMDSVersionV1 MMDSVersion = "V1"
|
||||
MMDSVersionV2 MMDSVersion = "V2"
|
||||
)
|
||||
|
||||
// MMDSSpec describes the MMDS network configuration and initial payload.
|
||||
type MMDSSpec struct {
|
||||
NetworkInterfaces []string
|
||||
Version MMDSVersion
|
||||
IPv4Address string
|
||||
IMDSCompat bool
|
||||
Data any
|
||||
}
|
||||
|
||||
// VsockSpec describes a single host-guest vsock device.
|
||||
|
|
@ -49,17 +83,22 @@ func (s MachineSpec) Validate() error {
|
|||
if strings.TrimSpace(s.KernelImagePath) == "" {
|
||||
return fmt.Errorf("machine kernel image path is required")
|
||||
}
|
||||
if strings.TrimSpace(s.RootFSPath) == "" {
|
||||
return fmt.Errorf("machine rootfs path is required")
|
||||
}
|
||||
if filepath.Base(strings.TrimSpace(string(s.ID))) != strings.TrimSpace(string(s.ID)) {
|
||||
return fmt.Errorf("machine id %q must not contain path separators", s.ID)
|
||||
}
|
||||
if err := s.rootDrive().Validate(); err != nil {
|
||||
return fmt.Errorf("root drive: %w", err)
|
||||
}
|
||||
for i, drive := range s.Drives {
|
||||
if err := drive.Validate(); err != nil {
|
||||
return fmt.Errorf("drive %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
if s.MMDS != nil {
|
||||
if err := s.MMDS.Validate(); err != nil {
|
||||
return fmt.Errorf("mmds: %w", err)
|
||||
}
|
||||
}
|
||||
if s.Vsock != nil {
|
||||
if err := s.Vsock.Validate(); err != nil {
|
||||
return fmt.Errorf("vsock: %w", err)
|
||||
|
|
@ -70,11 +109,39 @@ func (s MachineSpec) Validate() error {
|
|||
|
||||
// Validate reports whether the drive specification is usable.
|
||||
func (d DriveSpec) Validate() error {
|
||||
if strings.TrimSpace(d.Path) == "" {
|
||||
return fmt.Errorf("drive path is required")
|
||||
}
|
||||
if strings.TrimSpace(d.ID) == "" {
|
||||
return fmt.Errorf("drive id is required")
|
||||
}
|
||||
if strings.TrimSpace(d.Path) == "" {
|
||||
return fmt.Errorf("drive path is required")
|
||||
switch d.CacheType {
|
||||
case "", DriveCacheTypeUnsafe, DriveCacheTypeWriteback:
|
||||
default:
|
||||
return fmt.Errorf("unsupported drive cache type %q", d.CacheType)
|
||||
}
|
||||
switch d.IOEngine {
|
||||
case "", DriveIOEngineSync, DriveIOEngineAsync:
|
||||
default:
|
||||
return fmt.Errorf("unsupported drive io engine %q", d.IOEngine)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate reports whether the MMDS configuration is usable.
|
||||
func (m MMDSSpec) Validate() error {
|
||||
if len(m.NetworkInterfaces) == 0 {
|
||||
return fmt.Errorf("mmds network interfaces are required")
|
||||
}
|
||||
switch m.Version {
|
||||
case "", MMDSVersionV1, MMDSVersionV2:
|
||||
default:
|
||||
return fmt.Errorf("unsupported mmds version %q", m.Version)
|
||||
}
|
||||
for i, iface := range m.NetworkInterfaces {
|
||||
if strings.TrimSpace(iface) == "" {
|
||||
return fmt.Errorf("mmds network_interfaces[%d] is required", i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -92,3 +159,14 @@ func (v VsockSpec) Validate() error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s MachineSpec) rootDrive() DriveSpec {
|
||||
root := s.RootDrive
|
||||
if strings.TrimSpace(root.ID) == "" {
|
||||
root.ID = defaultRootDriveID
|
||||
}
|
||||
if strings.TrimSpace(root.Path) == "" {
|
||||
root.Path = s.RootFSPath
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ type Service interface {
|
|||
Health(context.Context) (*contracthost.HealthResponse, error)
|
||||
GetStorageReport(context.Context) (*contracthost.GetStorageReportResponse, error)
|
||||
CreateSnapshot(context.Context, contracthost.MachineID, contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error)
|
||||
UploadSnapshot(context.Context, contracthost.SnapshotID, contracthost.UploadSnapshotRequest) (*contracthost.UploadSnapshotResponse, error)
|
||||
ListSnapshots(context.Context, contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error)
|
||||
GetSnapshot(context.Context, contracthost.SnapshotID) (*contracthost.GetSnapshotResponse, error)
|
||||
DeleteSnapshotByID(context.Context, contracthost.SnapshotID) error
|
||||
|
|
@ -278,6 +279,25 @@ func (h *Handler) handleSnapshot(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if len(parts) == 2 && parts[1] == "upload" {
|
||||
if r.Method != http.MethodPost {
|
||||
writeMethodNotAllowed(w)
|
||||
return
|
||||
}
|
||||
var req contracthost.UploadSnapshotRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
response, err := h.service.UploadSnapshot(r.Context(), snapshotID, req)
|
||||
if err != nil {
|
||||
writeError(w, statusForError(err), err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, response)
|
||||
return
|
||||
}
|
||||
|
||||
writeError(w, http.StatusNotFound, fmt.Errorf("route not found"))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ type ArtifactRecord struct {
|
|||
type MachineRecord struct {
|
||||
ID contracthost.MachineID
|
||||
Artifact contracthost.ArtifactRef
|
||||
GuestConfig *contracthost.GuestConfig
|
||||
SystemVolumeID contracthost.VolumeID
|
||||
UserVolumeIDs []contracthost.VolumeID
|
||||
RuntimeHost string
|
||||
|
|
@ -71,11 +72,21 @@ type SnapshotRecord struct {
|
|||
MemFilePath string
|
||||
StateFilePath string
|
||||
DiskPaths []string
|
||||
Artifacts []SnapshotArtifactRecord
|
||||
SourceRuntimeHost string
|
||||
SourceTapDevice string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type SnapshotArtifactRecord struct {
|
||||
ID string
|
||||
Kind contracthost.SnapshotArtifactKind
|
||||
Name string
|
||||
LocalPath string
|
||||
SizeBytes int64
|
||||
SHA256Hex string
|
||||
}
|
||||
|
||||
type PublishedPortRecord struct {
|
||||
ID contracthost.PublishedPortID
|
||||
MachineID contracthost.MachineID
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue