diff --git a/internal/daemon/lifecycle.go b/internal/daemon/lifecycle.go index ba1d7b7..e55e56a 100644 --- a/internal/daemon/lifecycle.go +++ b/internal/daemon/lifecycle.go @@ -233,15 +233,16 @@ func (d *Daemon) Reconcile(ctx context.Context) error { return err } for _, record := range records { - if _, err := d.reconcileMachine(ctx, record.ID); err != nil { + reconciled, err := d.reconcileMachine(ctx, record.ID) + if err != nil { return err } - if record.Phase == contracthost.MachinePhaseRunning { - if err := d.ensurePublishedPortsForMachine(ctx, record); err != nil { + if reconciled.Phase == contracthost.MachinePhaseRunning { + if err := d.ensurePublishedPortsForMachine(ctx, *reconciled); err != nil { return err } } else { - d.stopPublishedPortsForMachine(record.ID) + d.stopPublishedPortsForMachine(reconciled.ID) } } return nil diff --git a/internal/daemon/review_regressions_test.go b/internal/daemon/review_regressions_test.go index 22cbd1d..08c7082 100644 --- a/internal/daemon/review_regressions_test.go +++ b/internal/daemon/review_regressions_test.go @@ -3,6 +3,7 @@ package daemon import ( "context" "errors" + "net" "os" "path/filepath" "strings" @@ -10,6 +11,7 @@ import ( "testing" "time" + "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/model" hoststore "github.com/getcompanion-ai/computer-host/internal/store" contracthost "github.com/getcompanion-ai/computer-host/contract" @@ -57,6 +59,16 @@ type publishedPortResult struct { err error } +type failingInspectRuntime struct { + fakeRuntime +} + +func (r *failingInspectRuntime) Inspect(state firecracker.MachineState) (*firecracker.MachineState, error) { + state.Phase = firecracker.PhaseFailed + state.Error = "vm exited unexpectedly" + return &state, nil +} + func TestCreatePublishedPortSerializesHostPortAllocationAcrossMachines(t *testing.T) { root := t.TempDir() cfg := testConfig(root) @@ -267,6 +279,132 @@ func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T) assertOperationCount(t, baseStore, 1) } +func TestCreateSnapshotRejectsDuplicateSnapshotIDWithoutTouchingExistingArtifacts(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + hostDaemon, err := New(cfg, fileStore, &fakeRuntime{}) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + + machineID := contracthost.MachineID("vm-1") + snapshotID := contracthost.SnapshotID("snap-1") + if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{ + ID: machineID, + SystemVolumeID: hostDaemon.systemVolumeID(machineID), + RuntimeHost: "127.0.0.1", + SocketPath: filepath.Join(cfg.RuntimeDir, "machines", string(machineID), "root", "run", "firecracker.sock"), + Phase: contracthost.MachinePhaseRunning, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create machine: %v", err) + } + + snapshotDir := filepath.Join(cfg.SnapshotsDir, string(snapshotID)) + if err := os.MkdirAll(snapshotDir, 0o755); err != nil { + t.Fatalf("create snapshot dir: %v", err) + } + markerPath := filepath.Join(snapshotDir, "keep.txt") + if err := os.WriteFile(markerPath, []byte("keep"), 0o644); err != nil { + t.Fatalf("write marker file: %v", err) + } + if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ + ID: snapshotID, + MachineID: machineID, + MemFilePath: markerPath, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create snapshot: %v", err) + } + + _, err = hostDaemon.CreateSnapshot(context.Background(), machineID, contracthost.CreateSnapshotRequest{ + SnapshotID: snapshotID, + }) + if err == nil || !strings.Contains(err.Error(), "already exists") { + t.Fatalf("CreateSnapshot error = %v, want duplicate snapshot failure", err) + } + if _, statErr := os.Stat(markerPath); statErr != nil { + t.Fatalf("marker file should be preserved, stat error: %v", statErr) + } + assertOperationCount(t, fileStore, 0) +} + +func TestReconcileUsesReconciledMachineStateForPublishedPorts(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + hostDaemon, err := New(cfg, fileStore, &failingInspectRuntime{}) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + + t.Cleanup(func() { + hostDaemon.stopPublishedPortProxy("port-1") + }) + + reserved, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("reserve host port: %v", err) + } + hostPort := uint16(reserved.Addr().(*net.TCPAddr).Port) + if err := reserved.Close(); err != nil { + t.Fatalf("close reserved host port: %v", err) + } + + machineID := contracthost.MachineID("vm-1") + if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{ + ID: machineID, + RuntimeHost: "127.0.0.1", + SocketPath: filepath.Join(cfg.RuntimeDir, "machines", string(machineID), "root", "run", "firecracker.sock"), + TapDevice: "fctap0", + Phase: contracthost.MachinePhaseRunning, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create machine: %v", err) + } + if err := fileStore.CreatePublishedPort(context.Background(), model.PublishedPortRecord{ + ID: "port-1", + MachineID: machineID, + Port: 8080, + HostPort: hostPort, + Protocol: contracthost.PortProtocolTCP, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create published port: %v", err) + } + + if err := hostDaemon.Reconcile(context.Background()); err != nil { + t.Fatalf("Reconcile returned error: %v", err) + } + + updated, err := fileStore.GetMachine(context.Background(), machineID) + if err != nil { + t.Fatalf("get machine after reconcile: %v", err) + } + if updated.Phase != contracthost.MachinePhaseFailed { + t.Fatalf("machine phase = %q, want failed", updated.Phase) + } + if updated.RuntimeHost != "" { + t.Fatalf("machine runtime host = %q, want cleared", updated.RuntimeHost) + } + + hostDaemon.publishedPortsMu.Lock() + listenerCount := len(hostDaemon.publishedPortListeners) + hostDaemon.publishedPortsMu.Unlock() + if listenerCount != 0 { + t.Fatalf("published port listeners = %d, want 0", listenerCount) + } +} + func waitPublishedPortResult(t *testing.T, ch <-chan publishedPortResult) publishedPortResult { t.Helper() diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go index d868291..ac92944 100644 --- a/internal/daemon/snapshot.go +++ b/internal/daemon/snapshot.go @@ -42,6 +42,11 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach } snapshotID := req.SnapshotID + if _, err := d.store.GetSnapshot(ctx, snapshotID); err == nil { + return nil, fmt.Errorf("snapshot %q already exists", snapshotID) + } else if err != nil && err != store.ErrNotFound { + return nil, err + } if err := d.store.UpsertOperation(ctx, model.OperationRecord{ MachineID: machineID, @@ -60,7 +65,10 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach }() snapshotDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID)) - if err := os.MkdirAll(snapshotDir, 0o755); err != nil { + if err := os.Mkdir(snapshotDir, 0o755); err != nil { + if os.IsExist(err) { + return nil, fmt.Errorf("snapshot %q already exists", snapshotID) + } return nil, fmt.Errorf("create snapshot dir: %w", err) }