chore: disk removal review fixes

This commit is contained in:
Harivansh Rathi 2026-04-11 21:06:37 +00:00
parent d0f0530ca2
commit 0e4b18f10b
6 changed files with 230 additions and 35 deletions

View file

@ -69,15 +69,6 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
if err := cloneDiskFile(artifact.RootFSPath, systemVolumePath, d.config.DiskCloneMode); err != nil {
return nil, fmt.Errorf("clone rootfs for %q: %w", req.MachineID, err)
}
if err := os.Truncate(systemVolumePath, defaultGuestDiskSizeBytes); err != nil {
return nil, fmt.Errorf("expand system volume for %q: %w", req.MachineID, err)
}
if err := injectMachineIdentity(ctx, systemVolumePath, req.MachineID); err != nil {
return nil, fmt.Errorf("inject machine identity for %q: %w", req.MachineID, err)
}
if err := injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil {
return nil, fmt.Errorf("inject guest config for %q: %w", req.MachineID, err)
}
removeSystemVolumeOnFailure := true
defer func() {
if !removeSystemVolumeOnFailure {
@ -86,6 +77,15 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
_ = os.Remove(systemVolumePath)
_ = os.RemoveAll(filepath.Dir(systemVolumePath))
}()
if err := os.Truncate(systemVolumePath, defaultGuestDiskSizeBytes); err != nil {
return nil, fmt.Errorf("expand system volume for %q: %w", req.MachineID, err)
}
if err := d.injectMachineIdentity(ctx, systemVolumePath, req.MachineID); err != nil {
return nil, fmt.Errorf("inject machine identity for %q: %w", req.MachineID, err)
}
if err := d.injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil {
return nil, fmt.Errorf("inject guest config for %q: %w", req.MachineID, err)
}
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig)
if err != nil {

View file

@ -46,6 +46,8 @@ type Daemon struct {
reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error
readGuestSSHPublicKey func(context.Context, string) (string, error)
injectMachineIdentity func(context.Context, string, contracthost.MachineID) error
injectGuestConfig func(context.Context, string, *contracthost.GuestConfig) error
syncGuestFilesystem func(context.Context, string) error
shutdownGuest func(context.Context, string) error
personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) error
@ -86,6 +88,8 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
runtime: runtime,
reconfigureGuestIdentity: nil,
readGuestSSHPublicKey: nil,
injectMachineIdentity: nil,
injectGuestConfig: nil,
personalizeGuest: nil,
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
artifactLocks: make(map[string]*sync.Mutex),
@ -94,6 +98,8 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
}
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
daemon.readGuestSSHPublicKey = readGuestSSHPublicKey
daemon.injectMachineIdentity = injectMachineIdentity
daemon.injectGuestConfig = injectGuestConfig
daemon.syncGuestFilesystem = daemon.syncGuestFilesystemOverSSH
daemon.shutdownGuest = daemon.issueGuestPoweroff
daemon.personalizeGuest = daemon.personalizeGuestConfig

View file

@ -263,6 +263,48 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
}
}
// TestCreateMachineCleansSystemVolumeOnInjectFailure verifies that when the
// machine-identity injection step fails during CreateMachine, the staged
// system volume file and its parent directory are both removed.
func TestCreateMachineCleansSystemVolumeOnInjectFailure(t *testing.T) {
	tempRoot := t.TempDir()
	cfg := testConfig(tempRoot)

	fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}
	daemon, err := New(cfg, fileStore, &fakeRuntime{})
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}

	// Force the identity-injection seam to fail so the cleanup path must run.
	daemon.injectMachineIdentity = func(context.Context, string, contracthost.MachineID) error {
		return errors.New("inject failed")
	}

	artifacts := map[string][]byte{
		"/kernel": []byte("kernel-image"),
		"/rootfs": buildTestExt4ImageBytes(t),
	}
	artifactServer := newRestoreArtifactServer(t, artifacts)
	defer artifactServer.Close()

	request := contracthost.CreateMachineRequest{
		MachineID: "vm-inject-fail",
		Artifact: contracthost.ArtifactRef{
			KernelImageURL: artifactServer.URL + "/kernel",
			RootFSURL:      artifactServer.URL + "/rootfs",
		},
	}
	_, err = daemon.CreateMachine(context.Background(), request)
	if err == nil || !strings.Contains(err.Error(), "inject machine identity") {
		t.Fatalf("CreateMachine error = %v, want inject machine identity failure", err)
	}

	// Both the volume file and its containing directory must be gone.
	volumePath := daemon.systemVolumePath("vm-inject-fail")
	if _, statErr := os.Stat(volumePath); !os.IsNotExist(statErr) {
		t.Fatalf("system volume should be cleaned up, stat err = %v", statErr)
	}
	if _, statErr := os.Stat(filepath.Dir(volumePath)); !os.IsNotExist(statErr) {
		t.Fatalf("system volume dir should be cleaned up, stat err = %v", statErr)
	}
}
func TestStopMachineSyncsGuestFilesystemBeforeDelete(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
@ -337,7 +379,7 @@ func TestStopMachineSyncsGuestFilesystemBeforeDelete(t *testing.T) {
}
}
func TestReconcileStartingMachinePersonalizesBeforeRunning(t *testing.T) {
func TestGetMachineReconcilesStartingMachineBeforeRunning(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
@ -386,10 +428,6 @@ func TestReconcileStartingMachinePersonalizesBeforeRunning(t *testing.T) {
t.Fatalf("create machine: %v", err)
}
if err := hostDaemon.Reconcile(context.Background()); err != nil {
t.Fatalf("Reconcile returned error: %v", err)
}
response, err := hostDaemon.GetMachine(context.Background(), "vm-starting")
if err != nil {
t.Fatalf("GetMachine returned error: %v", err)
@ -515,6 +553,92 @@ func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) {
}
}
// TestShutdownGuestCleanChecksGuestStateAfterPoweroffTimeout verifies that a
// poweroff command timing out is still treated as a clean shutdown when the
// runtime reports the guest has already stopped.
func TestShutdownGuestCleanChecksGuestStateAfterPoweroffTimeout(t *testing.T) {
	tempRoot := t.TempDir()
	cfg := testConfig(tempRoot)

	fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}

	// The fake runtime reports the guest as already stopped on inspection.
	stoppedRuntime := &fakeRuntime{
		inspectOverride: func(state firecracker.MachineState) (*firecracker.MachineState, error) {
			state.Phase = firecracker.PhaseStopped
			state.PID = 0
			return &state, nil
		},
	}
	daemon, err := New(cfg, fileStore, stoppedRuntime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}

	// Simulate the poweroff request itself timing out.
	daemon.shutdownGuest = func(context.Context, string) error {
		return context.DeadlineExceeded
	}

	createdAt := time.Now().UTC()
	record := &model.MachineRecord{
		ID:          "vm-timeout",
		RuntimeHost: "172.16.0.2",
		TapDevice:   "fctap-timeout",
		Phase:       contracthost.MachinePhaseRunning,
		PID:         1234,
		SocketPath:  filepath.Join(tempRoot, "runtime", "vm-timeout.sock"),
		Ports:       defaultMachinePorts(),
		CreatedAt:   createdAt,
		StartedAt:   &createdAt,
	}

	if ok := daemon.shutdownGuestClean(context.Background(), record); !ok {
		t.Fatal("shutdownGuestClean should treat a timed-out poweroff as success when the VM is already stopped")
	}
}
// TestShutdownGuestCleanRespectsContextCancellation verifies that a cancelled
// context makes shutdownGuestClean report failure quickly instead of polling
// for the full stop window.
func TestShutdownGuestCleanRespectsContextCancellation(t *testing.T) {
	tempRoot := t.TempDir()
	cfg := testConfig(tempRoot)

	fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}

	// The fake runtime keeps reporting the guest as running, so only the
	// cancelled context can end the wait.
	runningRuntime := &fakeRuntime{
		inspectOverride: func(state firecracker.MachineState) (*firecracker.MachineState, error) {
			state.Phase = firecracker.PhaseRunning
			return &state, nil
		},
	}
	daemon, err := New(cfg, fileStore, runningRuntime)
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	daemon.shutdownGuest = func(context.Context, string) error { return nil }

	createdAt := time.Now().UTC()
	record := &model.MachineRecord{
		ID:          "vm-cancel",
		RuntimeHost: "172.16.0.2",
		TapDevice:   "fctap-cancel",
		Phase:       contracthost.MachinePhaseRunning,
		PID:         1234,
		SocketPath:  filepath.Join(tempRoot, "runtime", "vm-cancel.sock"),
		Ports:       defaultMachinePorts(),
		CreatedAt:   createdAt,
		StartedAt:   &createdAt,
	}

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	began := time.Now()
	if ok := daemon.shutdownGuestClean(ctx, record); ok {
		t.Fatal("shutdownGuestClean should not report a clean shutdown after cancellation")
	}
	if elapsed := time.Since(began); elapsed > time.Second {
		t.Fatalf("shutdownGuestClean took %v after cancellation, want fast return", elapsed)
	}
}
func TestNewEnsuresBackendSSHKeyPair(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)

View file

@ -10,14 +10,14 @@ import (
"strings"
"time"
contracthost "github.com/getcompanion-ai/computer-host/contract"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func (d *Daemon) GetMachine(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) {
record, err := d.store.GetMachine(ctx, id)
record, err := d.reconcileMachine(ctx, id)
if err != nil {
return nil, err
}
@ -527,26 +527,37 @@ func (d *Daemon) shutdownGuestClean(ctx context.Context, record *model.MachineRe
defer cancel()
if err := d.shutdownGuest(shutdownCtx, record.RuntimeHost); err != nil {
fmt.Fprintf(os.Stderr, "warning: guest poweroff for %q failed: %v\n", record.ID, err)
return false
if ctx.Err() != nil {
return false
}
if shutdownCtx.Err() == nil && !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
fmt.Fprintf(os.Stderr, "warning: guest poweroff for %q failed: %v\n", record.ID, err)
return false
}
fmt.Fprintf(os.Stderr, "warning: guest poweroff for %q timed out before confirmation; checking whether shutdown is already in progress: %v\n", record.ID, err)
}
deadline := time.After(defaultGuestStopTimeout)
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
for {
state, err := d.runtime.Inspect(machineToRuntimeState(*record))
if err != nil {
return false
}
if state.Phase != firecracker.PhaseRunning {
return true
}
select {
case <-deadline:
case <-ctx.Done():
return false
case <-shutdownCtx.Done():
if ctx.Err() != nil {
return false
}
fmt.Fprintf(os.Stderr, "warning: guest %q did not exit within stop window; forcing teardown\n", record.ID)
return false
case <-ticker.C:
state, err := d.runtime.Inspect(machineToRuntimeState(*record))
if err != nil {
return false
}
if state.Phase != firecracker.PhaseRunning {
return true
}
}
}
}

View file

@ -773,6 +773,58 @@ func TestRestoreSnapshotCleansStagingArtifactsAfterDownloadFailure(t *testing.T)
}
}
// TestRestoreSnapshotCleansMachineDiskDirOnInjectFailure verifies that a
// failed identity injection during RestoreSnapshot removes the machine's
// staged disk directory.
func TestRestoreSnapshotCleansMachineDiskDirOnInjectFailure(t *testing.T) {
	tempRoot := t.TempDir()
	cfg := testConfig(tempRoot)

	fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
	if err != nil {
		t.Fatalf("create file store: %v", err)
	}
	daemon, err := New(cfg, fileStore, &fakeRuntime{})
	if err != nil {
		t.Fatalf("create daemon: %v", err)
	}
	stubGuestSSHPublicKeyReader(daemon)

	// Force the identity-injection seam to fail so the cleanup path must run.
	daemon.injectMachineIdentity = func(context.Context, string, contracthost.MachineID) error {
		return errors.New("inject failed")
	}

	artifactServer := newRestoreArtifactServer(t, map[string][]byte{
		"/kernel": []byte("kernel"),
		"/rootfs": []byte("rootfs"),
		"/system": buildTestExt4ImageBytes(t),
	})
	defer artifactServer.Close()

	request := contracthost.RestoreSnapshotRequest{
		MachineID: "restored-inject-fail",
		Artifact: contracthost.ArtifactRef{
			KernelImageURL: artifactServer.URL + "/kernel",
			RootFSURL:      artifactServer.URL + "/rootfs",
		},
		Snapshot: &contracthost.DurableSnapshotSpec{
			SnapshotID:        "snap-inject-fail",
			MachineID:         "source",
			ImageID:           "image-1",
			SourceRuntimeHost: "172.16.0.2",
			SourceTapDevice:   "fctap0",
			Artifacts: []contracthost.SnapshotArtifact{
				{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", DownloadURL: artifactServer.URL + "/system"},
			},
		},
	}
	_, err = daemon.RestoreSnapshot(context.Background(), "snap-inject-fail", request)
	if err == nil || !strings.Contains(err.Error(), "inject machine identity for restore") {
		t.Fatalf("RestoreSnapshot error = %v, want inject machine identity failure", err)
	}

	// The entire per-machine disk directory must have been removed.
	machineDiskDir := filepath.Join(cfg.MachineDisksDir, "restored-inject-fail")
	if _, statErr := os.Stat(machineDiskDir); !os.IsNotExist(statErr) {
		t.Fatalf("machine disk dir should be cleaned up, stat err = %v", statErr)
	}
}
func TestReconcileUsesReconciledMachineStateForPublishedPorts(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)

View file

@ -256,6 +256,13 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
if err := os.MkdirAll(filepath.Dir(newSystemDiskPath), 0o755); err != nil {
return nil, fmt.Errorf("create machine disk dir: %w", err)
}
removeMachineDiskDirOnFailure := true
defer func() {
if !removeMachineDiskDirOnFailure {
return
}
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
}()
systemDiskPath, ok := restoredArtifacts["system.img"]
if !ok {
clearOperation = true
@ -265,11 +272,11 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
clearOperation = true
return nil, fmt.Errorf("copy system disk for restore: %w", err)
}
if err := injectMachineIdentity(ctx, newSystemDiskPath, req.MachineID); err != nil {
if err := d.injectMachineIdentity(ctx, newSystemDiskPath, req.MachineID); err != nil {
clearOperation = true
return nil, fmt.Errorf("inject machine identity for restore: %w", err)
}
if err := injectGuestConfig(ctx, newSystemDiskPath, guestConfig); err != nil {
if err := d.injectGuestConfig(ctx, newSystemDiskPath, guestConfig); err != nil {
clearOperation = true
return nil, fmt.Errorf("inject guest config for restore: %w", err)
}
@ -308,19 +315,16 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
}
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig)
if err != nil {
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("build machine spec for restore: %w", err)
}
usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID)
if err != nil {
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, err
}
machineState, err := d.runtime.Boot(ctx, spec, usedNetworks)
if err != nil {
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("boot restored machine: %w", err)
}
@ -338,7 +342,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
CreatedAt: now,
}); err != nil {
_ = d.runtime.Delete(ctx, *machineState)
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("create system volume record for restore: %w", err)
}
@ -358,7 +361,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
}
_ = d.store.DeleteVolume(context.Background(), systemVolumeID)
_ = d.runtime.Delete(ctx, *machineState)
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, fmt.Errorf("create restored user volume record %q: %w", volume.ID, err)
}
@ -387,11 +389,11 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
}
_ = d.store.DeleteVolume(context.Background(), systemVolumeID)
_ = d.runtime.Delete(ctx, *machineState)
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
clearOperation = true
return nil, err
}
removeMachineDiskDirOnFailure = false
clearOperation = true
return &contracthost.RestoreSnapshotResponse{
Machine: machineToContract(machineRecord),