fix: reconcile race fix, disk regression for snapshot deletion

This commit is contained in:
Harivansh Rathi 2026-04-13 02:56:54 +00:00
parent 09d9e7c23b
commit 218cc3fecb
11 changed files with 193 additions and 123 deletions

View file

@ -1,6 +1,6 @@
## computer-host ## computer-host
<img width="3588" height="1184" alt="Gemini_Generated_Image_yxb12yyxb12yyxb1" src="https://github.com/user-attachments/assets/f7e6d927-568f-4a94-99a9-4664d1fc43f5" /> <img width="3588" height="1184" alt="Gemini_Generated_Image_yxb12yyxb12yyxb1" src="https://github.com/user-attachments/assets/ccba52b9-0229-44e6-a4c5-a34816041284" />
computer-host is a daemon runtime for managing Firecracker microVMs computer-host is a daemon runtime for managing Firecracker microVMs

View file

@ -118,6 +118,20 @@ func (d *Daemon) Health(ctx context.Context) (*contracthost.HealthResponse, erro
} }
func (d *Daemon) lockMachine(machineID contracthost.MachineID) func() { func (d *Daemon) lockMachine(machineID contracthost.MachineID) func() {
lock := d.machineLock(machineID)
lock.Lock()
return lock.Unlock
}
func (d *Daemon) tryLockMachine(machineID contracthost.MachineID) (func(), bool) {
lock := d.machineLock(machineID)
if !lock.TryLock() {
return nil, false
}
return lock.Unlock, true
}
func (d *Daemon) machineLock(machineID contracthost.MachineID) *sync.Mutex {
d.locksMu.Lock() d.locksMu.Lock()
lock, ok := d.machineLocks[machineID] lock, ok := d.machineLocks[machineID]
if !ok { if !ok {
@ -125,9 +139,7 @@ func (d *Daemon) lockMachine(machineID contracthost.MachineID) func() {
d.machineLocks[machineID] = lock d.machineLocks[machineID] = lock
} }
d.locksMu.Unlock() d.locksMu.Unlock()
return lock
lock.Lock()
return lock.Unlock
} }
func (d *Daemon) lockArtifact(key string) func() { func (d *Daemon) lockArtifact(key string) func() {

View file

@ -944,6 +944,50 @@ func TestGetSnapshotArtifactReturnsLocalArtifactPath(t *testing.T) {
} }
} }
func TestDeleteSnapshotByIDRemovesDiskOnlySnapshotDirectory(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
if err != nil {
t.Fatalf("create file store: %v", err)
}
hostDaemon, err := New(cfg, fileStore, &fakeRuntime{})
if err != nil {
t.Fatalf("create daemon: %v", err)
}
snapshotDir := filepath.Join(root, "snapshots", "snap-delete")
if err := os.MkdirAll(snapshotDir, 0o755); err != nil {
t.Fatalf("create snapshot dir: %v", err)
}
systemPath := filepath.Join(snapshotDir, "system.img")
if err := os.WriteFile(systemPath, []byte("disk"), 0o644); err != nil {
t.Fatalf("write system disk: %v", err)
}
if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{
ID: "snap-delete",
MachineID: "source",
DiskPaths: []string{systemPath},
Artifacts: []model.SnapshotArtifactRecord{
{ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", LocalPath: systemPath, SizeBytes: 4},
},
CreatedAt: time.Now().UTC(),
}); err != nil {
t.Fatalf("create snapshot: %v", err)
}
if err := hostDaemon.DeleteSnapshotByID(context.Background(), "snap-delete"); err != nil {
t.Fatalf("DeleteSnapshotByID returned error: %v", err)
}
if _, err := os.Stat(snapshotDir); !os.IsNotExist(err) {
t.Fatalf("snapshot dir should be removed, stat error: %v", err)
}
if _, err := fileStore.GetSnapshot(context.Background(), "snap-delete"); err != store.ErrNotFound {
t.Fatalf("snapshot should be removed from store, got: %v", err)
}
}
func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
root := t.TempDir() root := t.TempDir()
cfg := testConfig(root) cfg := testConfig(root)

View file

@ -523,7 +523,7 @@ func injectGuestSSHHostKey(ctx context.Context, imagePath string, keyPair *guest
if err := os.WriteFile(privateKeyPath, keyPair.PrivateKey, 0o600); err != nil { if err := os.WriteFile(privateKeyPath, keyPair.PrivateKey, 0o600); err != nil {
return fmt.Errorf("write guest ssh host private key staging file: %w", err) return fmt.Errorf("write guest ssh host private key staging file: %w", err)
} }
if err := replaceExt4File(ctx, imagePath, privateKeyPath, "/etc/ssh/ssh_host_ed25519_key"); err != nil { if err := replaceExt4FileMode(ctx, imagePath, privateKeyPath, "/etc/ssh/ssh_host_ed25519_key", "100600"); err != nil {
return err return err
} }
@ -531,7 +531,7 @@ func injectGuestSSHHostKey(ctx context.Context, imagePath string, keyPair *guest
if err := os.WriteFile(publicKeyPath, []byte(strings.TrimSpace(keyPair.PublicKey)+"\n"), 0o644); err != nil { if err := os.WriteFile(publicKeyPath, []byte(strings.TrimSpace(keyPair.PublicKey)+"\n"), 0o644); err != nil {
return fmt.Errorf("write guest ssh host public key staging file: %w", err) return fmt.Errorf("write guest ssh host public key staging file: %w", err)
} }
if err := replaceExt4File(ctx, imagePath, publicKeyPath, "/etc/ssh/ssh_host_ed25519_key.pub"); err != nil { if err := replaceExt4FileMode(ctx, imagePath, publicKeyPath, "/etc/ssh/ssh_host_ed25519_key.pub", "100644"); err != nil {
return err return err
} }
@ -543,6 +543,7 @@ func injectMachineIdentity(ctx context.Context, imagePath string, machineID cont
if machineName == "" { if machineName == "" {
return fmt.Errorf("machine_id is required") return fmt.Errorf("machine_id is required")
} }
hostname := "agentcomputer"
stagingDir, err := os.MkdirTemp(filepath.Dir(imagePath), "machine-identity-*") stagingDir, err := os.MkdirTemp(filepath.Dir(imagePath), "machine-identity-*")
if err != nil { if err != nil {
@ -553,11 +554,11 @@ func injectMachineIdentity(ctx context.Context, imagePath string, machineID cont
}() }()
identityFiles := map[string]string{ identityFiles := map[string]string{
"/etc/microagent/machine-name": machineName + "\n", "/etc/microagent/machine-name": hostname + "\n",
"/etc/hostname": machineName + "\n", "/etc/hostname": hostname + "\n",
"/etc/hosts": fmt.Sprintf( "/etc/hosts": fmt.Sprintf(
"127.0.0.1 localhost\n127.0.1.1 %s\n::1 localhost ip6-localhost ip6-loopback\nff02::1 ip6-allnodes\nff02::2 ip6-allrouters\n", "127.0.0.1 localhost\n127.0.1.1 %s\n::1 localhost ip6-localhost ip6-loopback\nff02::1 ip6-allnodes\nff02::2 ip6-allrouters\n",
machineName, hostname,
), ),
} }
@ -576,10 +577,19 @@ func injectMachineIdentity(ctx context.Context, imagePath string, machineID cont
} }
func replaceExt4File(ctx context.Context, imagePath string, sourcePath string, targetPath string) error { func replaceExt4File(ctx context.Context, imagePath string, sourcePath string, targetPath string) error {
return replaceExt4FileMode(ctx, imagePath, sourcePath, targetPath, "")
}
func replaceExt4FileMode(ctx context.Context, imagePath string, sourcePath string, targetPath string, mode string) error {
_ = runDebugFS(ctx, imagePath, fmt.Sprintf("rm %s", targetPath)) _ = runDebugFS(ctx, imagePath, fmt.Sprintf("rm %s", targetPath))
if err := runDebugFS(ctx, imagePath, fmt.Sprintf("write %s %s", sourcePath, targetPath)); err != nil { if err := runDebugFS(ctx, imagePath, fmt.Sprintf("write %s %s", sourcePath, targetPath)); err != nil {
return fmt.Errorf("inject %q into %q: %w", targetPath, imagePath, err) return fmt.Errorf("inject %q into %q: %w", targetPath, imagePath, err)
} }
if mode != "" {
if err := runDebugFS(ctx, imagePath, fmt.Sprintf("set_inode_field %s mode 0%s", targetPath, mode)); err != nil {
return fmt.Errorf("set mode on %q in %q: %w", targetPath, imagePath, err)
}
}
return nil return nil
} }

View file

@ -81,15 +81,15 @@ func TestInjectMachineIdentityWritesHostnameFiles(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("read hostname: %v", err) t.Fatalf("read hostname: %v", err)
} }
if hostname != "kiruru\n" { if hostname != "agentcomputer\n" {
t.Fatalf("hostname mismatch: got %q want %q", hostname, "kiruru\n") t.Fatalf("hostname mismatch: got %q want %q", hostname, "agentcomputer\n")
} }
hosts, err := readExt4File(imagePath, "/etc/hosts") hosts, err := readExt4File(imagePath, "/etc/hosts")
if err != nil { if err != nil {
t.Fatalf("read hosts: %v", err) t.Fatalf("read hosts: %v", err)
} }
if !strings.Contains(hosts, "127.0.1.1 kiruru") { if !strings.Contains(hosts, "127.0.1.1 agentcomputer") {
t.Fatalf("hosts missing machine name: %q", hosts) t.Fatalf("hosts missing machine name: %q", hosts)
} }
} }

View file

@ -23,7 +23,7 @@ const (
defaultGuestPersonalizationVsockID = "microagent-personalizer" defaultGuestPersonalizationVsockID = "microagent-personalizer"
defaultGuestPersonalizationVsockName = "microagent-personalizer.vsock" defaultGuestPersonalizationVsockName = "microagent-personalizer.vsock"
defaultGuestPersonalizationVsockPort = uint32(1024) defaultGuestPersonalizationVsockPort = uint32(1024)
defaultGuestPersonalizationTimeout = 15 * time.Second defaultGuestPersonalizationTimeout = 30 * time.Second
guestPersonalizationRetryInterval = 100 * time.Millisecond guestPersonalizationRetryInterval = 100 * time.Millisecond
minGuestVsockCID = uint32(3) minGuestVsockCID = uint32(3)
maxGuestVsockCID = uint32(1<<31 - 1) maxGuestVsockCID = uint32(1<<31 - 1)
@ -91,9 +91,34 @@ func sendGuestPersonalization(ctx context.Context, state firecracker.MachineStat
if err != nil { if err != nil {
return nil, err return nil, err
} }
connection, err := dialGuestPersonalization(ctx, vsockPath)
var lastErr error
for {
if ctx.Err() != nil {
if lastErr != nil {
return nil, lastErr
}
return nil, ctx.Err()
}
resp, err := tryGuestPersonalization(ctx, vsockPath, payloadBytes)
if err == nil {
return resp, nil
}
lastErr = err
select {
case <-ctx.Done():
return nil, lastErr
case <-time.After(guestPersonalizationRetryInterval):
}
}
}
func tryGuestPersonalization(ctx context.Context, vsockPath string, payloadBytes []byte) (*guestPersonalizationResponse, error) {
connection, err := (&net.Dialer{}).DialContext(ctx, "unix", vsockPath)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err)
} }
defer func() { defer func() {
_ = connection.Close() _ = connection.Close()
@ -140,25 +165,3 @@ func setConnectionDeadline(ctx context.Context, connection net.Conn) {
} }
_ = connection.SetDeadline(time.Now().Add(defaultGuestPersonalizationTimeout)) _ = connection.SetDeadline(time.Now().Add(defaultGuestPersonalizationTimeout))
} }
func dialGuestPersonalization(ctx context.Context, vsockPath string) (net.Conn, error) {
dialer := &net.Dialer{}
for {
connection, err := dialer.DialContext(ctx, "unix", vsockPath)
if err == nil {
return connection, nil
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return nil, fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err)
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(guestPersonalizationRetryInterval):
}
}
}

View file

@ -52,6 +52,9 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
} }
if record.Phase == contracthost.MachinePhaseStarting { if record.Phase == contracthost.MachinePhaseStarting {
// reconcileMachine acquires the machine lock, so we must release
// ours first to avoid self-deadlock.
unlock()
reconciled, err := d.reconcileMachine(ctx, id) reconciled, err := d.reconcileMachine(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
@ -220,6 +223,12 @@ func (d *Daemon) Reconcile(ctx context.Context) error {
return err return err
} }
for _, operation := range operations { for _, operation := range operations {
unlock, ok := d.tryLockMachine(operation.MachineID)
if !ok {
continue
}
unlock()
switch operation.Type { switch operation.Type {
case model.MachineOperationCreate: case model.MachineOperationCreate:
if err := d.reconcileCreate(ctx, operation.MachineID); err != nil { if err := d.reconcileCreate(ctx, operation.MachineID); err != nil {

View file

@ -1,46 +0,0 @@
package daemon
import (
"context"
"fmt"
"net"
"strconv"
"strings"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func guestPortsReady(ctx context.Context, host string, ports []contracthost.MachinePort) (bool, error) {
host = strings.TrimSpace(host)
if host == "" {
return false, fmt.Errorf("guest runtime host is required")
}
for _, port := range ports {
probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout)
ready, err := guestPortReady(probeCtx, host, port)
cancel()
if err != nil {
return false, err
}
if !ready {
return false, nil
}
}
return true, nil
}
func guestPortReady(ctx context.Context, host string, port contracthost.MachinePort) (bool, error) {
address := net.JoinHostPort(host, strconv.Itoa(int(port.Port)))
dialer := net.Dialer{Timeout: defaultGuestDialTimeout}
connection, err := dialer.DialContext(ctx, string(port.Protocol), address)
if err == nil {
_ = connection.Close()
return true, nil
}
if ctx.Err() != nil {
return false, nil
}
return false, nil
}

View file

@ -5,7 +5,6 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
@ -57,19 +56,6 @@ func (s machineLookupErrorStore) GetMachine(context.Context, contracthost.Machin
return nil, s.err return nil, s.err
} }
type relayExhaustionStore struct {
hoststore.Store
extraMachines []model.MachineRecord
}
func (s relayExhaustionStore) ListMachines(ctx context.Context) ([]model.MachineRecord, error) {
machines, err := s.Store.ListMachines(ctx)
if err != nil {
return nil, err
}
return append(machines, s.extraMachines...), nil
}
type publishedPortResult struct { type publishedPortResult struct {
response *contracthost.CreatePublishedPortResponse response *contracthost.CreatePublishedPortResponse
err error err error
@ -255,6 +241,52 @@ func TestReconcileSnapshotPreservesArtifactsOnUnexpectedStoreError(t *testing.T)
assertOperationCount(t, baseStore, 1) assertOperationCount(t, baseStore, 1)
} }
func TestReconcileSkipsInFlightSnapshotOperationWhileMachineLocked(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
if err != nil {
t.Fatalf("create file store: %v", err)
}
hostDaemon, err := New(cfg, baseStore, &fakeRuntime{})
if err != nil {
t.Fatalf("create daemon: %v", err)
}
stubGuestSSHPublicKeyReader(hostDaemon)
snapshotID := contracthost.SnapshotID("snap-inflight")
operation := model.OperationRecord{
MachineID: "vm-1",
Type: model.MachineOperationSnapshot,
StartedAt: time.Now().UTC(),
SnapshotID: &snapshotID,
}
if err := baseStore.UpsertOperation(context.Background(), operation); err != nil {
t.Fatalf("upsert operation: %v", err)
}
snapshotDir := filepath.Join(cfg.SnapshotsDir, string(snapshotID))
if err := os.MkdirAll(snapshotDir, 0o755); err != nil {
t.Fatalf("create snapshot dir: %v", err)
}
markerPath := filepath.Join(snapshotDir, "keep.txt")
if err := os.WriteFile(markerPath, []byte("keep"), 0o644); err != nil {
t.Fatalf("write marker file: %v", err)
}
unlock := hostDaemon.lockMachine("vm-1")
defer unlock()
if err := hostDaemon.Reconcile(context.Background()); err != nil {
t.Fatalf("Reconcile returned error: %v", err)
}
if _, statErr := os.Stat(markerPath); statErr != nil {
t.Fatalf("in-flight snapshot artifacts should be preserved, stat error: %v", statErr)
}
assertOperationCount(t, baseStore, 1)
}
func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T) { func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T) {
root := t.TempDir() root := t.TempDir()
cfg := testConfig(root) cfg := testConfig(root)
@ -307,11 +339,6 @@ func TestStartMachineTransitionsToRunningWithHandshake(t *testing.T) {
t.Fatalf("create file store: %v", err) t.Fatalf("create file store: %v", err)
} }
exhaustedStore := relayExhaustionStore{
Store: baseStore,
extraMachines: exhaustedMachineRelayRecords(),
}
sshListener := listenTestPort(t, int(defaultSSHPort)) sshListener := listenTestPort(t, int(defaultSSHPort))
defer func() { defer func() {
_ = sshListener.Close() _ = sshListener.Close()
@ -334,7 +361,7 @@ func TestStartMachineTransitionsToRunningWithHandshake(t *testing.T) {
}, },
} }
hostDaemon, err := New(cfg, exhaustedStore, runtime) hostDaemon, err := New(cfg, baseStore, runtime)
if err != nil { if err != nil {
t.Fatalf("create daemon: %v", err) t.Fatalf("create daemon: %v", err)
} }
@ -416,11 +443,6 @@ func TestRestoreSnapshotTransitionsToRunningWithHandshake(t *testing.T) {
t.Fatalf("create file store: %v", err) t.Fatalf("create file store: %v", err)
} }
exhaustedStore := relayExhaustionStore{
Store: baseStore,
extraMachines: exhaustedMachineRelayRecords(),
}
startedAt := time.Unix(1700000300, 0).UTC() startedAt := time.Unix(1700000300, 0).UTC()
runtime := &fakeRuntime{ runtime := &fakeRuntime{
bootState: firecracker.MachineState{ bootState: firecracker.MachineState{
@ -434,7 +456,7 @@ func TestRestoreSnapshotTransitionsToRunningWithHandshake(t *testing.T) {
}, },
} }
hostDaemon, err := New(cfg, exhaustedStore, runtime) hostDaemon, err := New(cfg, baseStore, runtime)
if err != nil { if err != nil {
t.Fatalf("create daemon: %v", err) t.Fatalf("create daemon: %v", err)
} }
@ -909,19 +931,6 @@ func waitPublishedPortResult(t *testing.T, ch <-chan publishedPortResult) publis
} }
} }
func exhaustedMachineRelayRecords() []model.MachineRecord {
count := int(maxMachineSSHRelayPort-minMachineSSHRelayPort) + 1
machines := make([]model.MachineRecord, 0, count)
for i := 0; i < count; i++ {
machines = append(machines, model.MachineRecord{
ID: contracthost.MachineID(fmt.Sprintf("relay-exhausted-%d", i)),
Ports: buildMachinePorts(minMachineSSHRelayPort+uint16(i), minMachineVNCRelayPort+uint16(i), 0),
Phase: contracthost.MachinePhaseRunning,
})
}
return machines
}
func mustSHA256Hex(t *testing.T, payload []byte) string { func mustSHA256Hex(t *testing.T, payload []byte) string {
t.Helper() t.Helper()

View file

@ -489,7 +489,10 @@ func (d *Daemon) DeleteSnapshotByID(ctx context.Context, snapshotID contracthost
if err != nil { if err != nil {
return err return err
} }
snapshotDir := filepath.Dir(snap.MemFilePath) snapshotDir, ok := snapshotDirectory(*snap)
if !ok {
return fmt.Errorf("snapshot %q has no local artifact directory", snapshotID)
}
if err := os.RemoveAll(snapshotDir); err != nil { if err := os.RemoveAll(snapshotDir); err != nil {
return fmt.Errorf("remove snapshot dir %q: %w", snapshotDir, err) return fmt.Errorf("remove snapshot dir %q: %w", snapshotDir, err)
} }
@ -520,6 +523,25 @@ func snapshotArtifactsToContract(artifacts []model.SnapshotArtifactRecord) []con
return converted return converted
} }
func snapshotDirectory(snapshot model.SnapshotRecord) (string, bool) {
for _, artifact := range snapshot.Artifacts {
if path := strings.TrimSpace(artifact.LocalPath); path != "" {
return filepath.Dir(path), true
}
}
for _, diskPath := range snapshot.DiskPaths {
if path := strings.TrimSpace(diskPath); path != "" {
return filepath.Dir(path), true
}
}
for _, legacyPath := range []string{snapshot.MemFilePath, snapshot.StateFilePath} {
if path := strings.TrimSpace(legacyPath); path != "" {
return filepath.Dir(path), true
}
}
return "", false
}
func orderedRestoredUserDiskArtifacts(artifacts map[string]restoredSnapshotArtifact) []restoredSnapshotArtifact { func orderedRestoredUserDiskArtifacts(artifacts map[string]restoredSnapshotArtifact) []restoredSnapshotArtifact {
ordered := make([]restoredSnapshotArtifact, 0, len(artifacts)) ordered := make([]restoredSnapshotArtifact, 0, len(artifacts))
for name, artifact := range artifacts { for name, artifact := range artifacts {

View file

@ -73,5 +73,12 @@ func (d *Daemon) completeMachineStartup(ctx context.Context, record *model.Machi
if err := d.store.UpdateMachine(ctx, *record); err != nil { if err := d.store.UpdateMachine(ctx, *record); err != nil {
return nil, err return nil, err
} }
if err := d.ensureMachineRelays(ctx, record); err != nil {
return d.failMachineStartup(ctx, record, err.Error())
}
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
d.stopMachineRelays(record.ID)
return d.failMachineStartup(ctx, record, err.Error())
}
return record, nil return record, nil
} }