diff --git a/internal/daemon/create.go b/internal/daemon/create.go index 7777e9c..b5c54f3 100644 --- a/internal/daemon/create.go +++ b/internal/daemon/create.go @@ -69,15 +69,6 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi if err := cloneDiskFile(artifact.RootFSPath, systemVolumePath, d.config.DiskCloneMode); err != nil { return nil, fmt.Errorf("clone rootfs for %q: %w", req.MachineID, err) } - if err := os.Truncate(systemVolumePath, defaultGuestDiskSizeBytes); err != nil { - return nil, fmt.Errorf("expand system volume for %q: %w", req.MachineID, err) - } - if err := injectMachineIdentity(ctx, systemVolumePath, req.MachineID); err != nil { - return nil, fmt.Errorf("inject machine identity for %q: %w", req.MachineID, err) - } - if err := injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil { - return nil, fmt.Errorf("inject guest config for %q: %w", req.MachineID, err) - } removeSystemVolumeOnFailure := true defer func() { if !removeSystemVolumeOnFailure { @@ -86,6 +77,15 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi _ = os.Remove(systemVolumePath) _ = os.RemoveAll(filepath.Dir(systemVolumePath)) }() + if err := os.Truncate(systemVolumePath, defaultGuestDiskSizeBytes); err != nil { + return nil, fmt.Errorf("expand system volume for %q: %w", req.MachineID, err) + } + if err := d.injectMachineIdentity(ctx, systemVolumePath, req.MachineID); err != nil { + return nil, fmt.Errorf("inject machine identity for %q: %w", req.MachineID, err) + } + if err := d.injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil { + return nil, fmt.Errorf("inject guest config for %q: %w", req.MachineID, err) + } spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig) if err != nil { diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 3fc4e88..64190b2 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -46,6 +46,8 @@ type Daemon struct { reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error readGuestSSHPublicKey func(context.Context, string) (string, error) + injectMachineIdentity func(context.Context, string, contracthost.MachineID) error + injectGuestConfig func(context.Context, string, *contracthost.GuestConfig) error syncGuestFilesystem func(context.Context, string) error shutdownGuest func(context.Context, string) error personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) error @@ -86,6 +88,8 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err runtime: runtime, reconfigureGuestIdentity: nil, readGuestSSHPublicKey: nil, + injectMachineIdentity: nil, + injectGuestConfig: nil, personalizeGuest: nil, machineLocks: make(map[contracthost.MachineID]*sync.Mutex), artifactLocks: make(map[string]*sync.Mutex), @@ -94,6 +98,8 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err } daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH daemon.readGuestSSHPublicKey = readGuestSSHPublicKey + daemon.injectMachineIdentity = injectMachineIdentity + daemon.injectGuestConfig = injectGuestConfig daemon.syncGuestFilesystem = daemon.syncGuestFilesystemOverSSH daemon.shutdownGuest = daemon.issueGuestPoweroff daemon.personalizeGuest = daemon.personalizeGuestConfig diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index d2f9929..c0d9a60 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -263,6 +263,48 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { } } +func TestCreateMachineCleansSystemVolumeOnInjectFailure(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + hostDaemon, err := New(cfg, fileStore, &fakeRuntime{}) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + hostDaemon.injectMachineIdentity = func(context.Context, string, contracthost.MachineID) error { + return errors.New("inject failed") + } + + server := newRestoreArtifactServer(t, map[string][]byte{ + "/kernel": []byte("kernel-image"), + "/rootfs": buildTestExt4ImageBytes(t), + }) + defer server.Close() + + _, err = hostDaemon.CreateMachine(context.Background(), contracthost.CreateMachineRequest{ + MachineID: "vm-inject-fail", + Artifact: contracthost.ArtifactRef{ + KernelImageURL: server.URL + "/kernel", + RootFSURL: server.URL + "/rootfs", + }, + }) + if err == nil || !strings.Contains(err.Error(), "inject machine identity") { + t.Fatalf("CreateMachine error = %v, want inject machine identity failure", err) + } + + systemVolumePath := hostDaemon.systemVolumePath("vm-inject-fail") + if _, statErr := os.Stat(systemVolumePath); !os.IsNotExist(statErr) { + t.Fatalf("system volume should be cleaned up, stat err = %v", statErr) + } + if _, statErr := os.Stat(filepath.Dir(systemVolumePath)); !os.IsNotExist(statErr) { + t.Fatalf("system volume dir should be cleaned up, stat err = %v", statErr) + } +} + func TestStopMachineSyncsGuestFilesystemBeforeDelete(t *testing.T) { root := t.TempDir() cfg := testConfig(root) @@ -337,7 +379,7 @@ func TestStopMachineSyncsGuestFilesystemBeforeDelete(t *testing.T) { } } -func TestReconcileStartingMachinePersonalizesBeforeRunning(t *testing.T) { +func TestGetMachineReconcilesStartingMachineBeforeRunning(t *testing.T) { root := t.TempDir() cfg := testConfig(root) fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) @@ -386,10 +428,6 @@ func TestReconcileStartingMachinePersonalizesBeforeRunning(t *testing.T) { t.Fatalf("create machine: %v", err) } - if err := hostDaemon.Reconcile(context.Background()); err != nil { - t.Fatalf("Reconcile returned error: %v", err) - } - response, err := hostDaemon.GetMachine(context.Background(), "vm-starting") if err != nil { t.Fatalf("GetMachine returned error: %v", err) @@ -515,6 +553,92 @@ func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) { } } +func TestShutdownGuestCleanChecksGuestStateAfterPoweroffTimeout(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + runtime := &fakeRuntime{ + inspectOverride: func(state firecracker.MachineState) (*firecracker.MachineState, error) { + state.Phase = firecracker.PhaseStopped + state.PID = 0 + return &state, nil + }, + } + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + hostDaemon.shutdownGuest = func(context.Context, string) error { + return context.DeadlineExceeded + } + + now := time.Now().UTC() + record := &model.MachineRecord{ + ID: "vm-timeout", + RuntimeHost: "172.16.0.2", + TapDevice: "fctap-timeout", + Phase: contracthost.MachinePhaseRunning, + PID: 1234, + SocketPath: filepath.Join(root, "runtime", "vm-timeout.sock"), + Ports: defaultMachinePorts(), + CreatedAt: now, + StartedAt: &now, + } + + if ok := hostDaemon.shutdownGuestClean(context.Background(), record); !ok { + t.Fatal("shutdownGuestClean should treat a timed-out poweroff as success when the VM is already stopped") + } +} + +func TestShutdownGuestCleanRespectsContextCancellation(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + runtime := &fakeRuntime{ + inspectOverride: func(state firecracker.MachineState) (*firecracker.MachineState, error) { + state.Phase = firecracker.PhaseRunning + return &state, nil + }, + } + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + hostDaemon.shutdownGuest = func(context.Context, string) error { return nil } + + now := time.Now().UTC() + record := &model.MachineRecord{ + ID: "vm-cancel", + RuntimeHost: "172.16.0.2", + TapDevice: "fctap-cancel", + Phase: contracthost.MachinePhaseRunning, + PID: 1234, + SocketPath: filepath.Join(root, "runtime", "vm-cancel.sock"), + Ports: defaultMachinePorts(), + CreatedAt: now, + StartedAt: &now, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + if ok := hostDaemon.shutdownGuestClean(ctx, record); ok { + t.Fatal("shutdownGuestClean should not report a clean shutdown after cancellation") + } + if elapsed := time.Since(start); elapsed > time.Second { + t.Fatalf("shutdownGuestClean took %v after cancellation, want fast return", elapsed) + } +} + func TestNewEnsuresBackendSSHKeyPair(t *testing.T) { root := t.TempDir() cfg := testConfig(root) diff --git a/internal/daemon/lifecycle.go b/internal/daemon/lifecycle.go index b96ce7c..beef737 100644 --- a/internal/daemon/lifecycle.go +++ b/internal/daemon/lifecycle.go @@ -10,14 +10,14 @@ import ( "strings" "time" - contracthost "github.com/getcompanion-ai/computer-host/contract" "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/store" + contracthost "github.com/getcompanion-ai/computer-host/contract" ) func (d *Daemon) GetMachine(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) { - record, err := d.store.GetMachine(ctx, id) + record, err := d.reconcileMachine(ctx, id) if err != nil { return nil, err } @@ -527,26 +527,37 @@ func (d *Daemon) shutdownGuestClean(ctx context.Context, record *model.MachineRe defer cancel() if err := d.shutdownGuest(shutdownCtx, record.RuntimeHost); err != nil { - fmt.Fprintf(os.Stderr, "warning: guest poweroff for %q failed: %v\n", record.ID, err) - return false + if ctx.Err() != nil { + return false + } + if shutdownCtx.Err() == nil && !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { + fmt.Fprintf(os.Stderr, "warning: guest poweroff for %q failed: %v\n", record.ID, err) + return false + } + fmt.Fprintf(os.Stderr, "warning: guest poweroff for %q timed out before confirmation; checking whether shutdown is already in progress: %v\n", record.ID, err) } - deadline := time.After(defaultGuestStopTimeout) ticker := time.NewTicker(250 * time.Millisecond) defer ticker.Stop() for { + state, err := d.runtime.Inspect(machineToRuntimeState(*record)) + if err != nil { + return false + } + if state.Phase != firecracker.PhaseRunning { + return true + } + select { - case <-deadline: + case <-ctx.Done(): + return false + case <-shutdownCtx.Done(): + if ctx.Err() != nil { + return false + } fmt.Fprintf(os.Stderr, "warning: guest %q did not exit within stop window; forcing teardown\n", record.ID) return false case <-ticker.C: - state, err := d.runtime.Inspect(machineToRuntimeState(*record)) - if err != nil { - return false - } - if state.Phase != firecracker.PhaseRunning { - return true - } } } } diff --git a/internal/daemon/review_regressions_test.go b/internal/daemon/review_regressions_test.go index 5df2c00..b346706 100644 --- a/internal/daemon/review_regressions_test.go +++ b/internal/daemon/review_regressions_test.go @@ -773,6 +773,58 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure(t *testing.T) } } +func TestRestoreSnapshotCleansMachineDiskDirOnInjectFailure(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + runtime := &fakeRuntime{} + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + stubGuestSSHPublicKeyReader(hostDaemon) + hostDaemon.injectMachineIdentity = func(context.Context, string, contracthost.MachineID) error { + return errors.New("inject failed") + } + + server := newRestoreArtifactServer(t, map[string][]byte{ + "/kernel": []byte("kernel"), + "/rootfs": []byte("rootfs"), + "/system": buildTestExt4ImageBytes(t), + }) + defer server.Close() + + _, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-inject-fail", contracthost.RestoreSnapshotRequest{ + MachineID: "restored-inject-fail", + Artifact: contracthost.ArtifactRef{ + KernelImageURL: server.URL + "/kernel", + RootFSURL: server.URL + "/rootfs", + }, + Snapshot: &contracthost.DurableSnapshotSpec{ + SnapshotID: "snap-inject-fail", + MachineID: "source", + ImageID: "image-1", + SourceRuntimeHost: "172.16.0.2", + SourceTapDevice: "fctap0", + Artifacts: []contracthost.SnapshotArtifact{ + {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: server.URL + "/system"}, + }, + }, + }) + if err == nil || !strings.Contains(err.Error(), "inject machine identity for restore") { + t.Fatalf("RestoreSnapshot error = %v, want inject machine identity failure", err) + } + + machineDiskDir := filepath.Join(cfg.MachineDisksDir, "restored-inject-fail") + if _, statErr := os.Stat(machineDiskDir); !os.IsNotExist(statErr) { + t.Fatalf("machine disk dir should be cleaned up, stat err = %v", statErr) + } +} + func TestReconcileUsesReconciledMachineStateForPublishedPorts(t *testing.T) { root := t.TempDir() cfg := testConfig(root) diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go index 1054194..f1fb276 100644 --- a/internal/daemon/snapshot.go +++ b/internal/daemon/snapshot.go @@ -256,6 +256,13 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn if err := os.MkdirAll(filepath.Dir(newSystemDiskPath), 0o755); err != nil { return nil, fmt.Errorf("create machine disk dir: %w", err) } + removeMachineDiskDirOnFailure := true + defer func() { + if !removeMachineDiskDirOnFailure { + return + } + _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) + }() systemDiskPath, ok := restoredArtifacts["system.img"] if !ok { clearOperation = true @@ -265,11 +272,11 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn clearOperation = true return nil, fmt.Errorf("copy system disk for restore: %w", err) } - if err := injectMachineIdentity(ctx, newSystemDiskPath, req.MachineID); err != nil { + if err := d.injectMachineIdentity(ctx, newSystemDiskPath, req.MachineID); err != nil { clearOperation = true return nil, fmt.Errorf("inject machine identity for restore: %w", err) } - if err := injectGuestConfig(ctx, newSystemDiskPath, guestConfig); err != nil { + if err := d.injectGuestConfig(ctx, newSystemDiskPath, guestConfig); err != nil { clearOperation = true return nil, fmt.Errorf("inject guest config for restore: %w", err) } @@ -308,19 +315,16 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn } spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig) if err != nil { - _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) clearOperation = true return nil, fmt.Errorf("build machine spec for restore: %w", err) } usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID) if err != nil { - _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) clearOperation = true return nil, err } machineState, err := d.runtime.Boot(ctx, spec, usedNetworks) if err != nil { - _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) clearOperation = true return nil, fmt.Errorf("boot restored machine: %w", err) } @@ -338,7 +342,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn CreatedAt: now, }); err != nil { _ = d.runtime.Delete(ctx, *machineState) - _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) clearOperation = true return nil, fmt.Errorf("create system volume record for restore: %w", err) } @@ -358,7 +361,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn } _ = d.store.DeleteVolume(context.Background(), systemVolumeID) _ = d.runtime.Delete(ctx, *machineState) - _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) clearOperation = true return nil, fmt.Errorf("create restored user volume record %q: %w", volume.ID, err) } @@ -387,11 +389,11 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn } _ = d.store.DeleteVolume(context.Background(), systemVolumeID) _ = d.runtime.Delete(ctx, *machineState) - _ = os.RemoveAll(filepath.Dir(newSystemDiskPath)) clearOperation = true return nil, err } + removeMachineDiskDirOnFailure = false clearOperation = true return &contracthost.RestoreSnapshotResponse{ Machine: machineToContract(machineRecord),