diff --git a/README.md b/README.md index 20eb45a..dffba85 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ ## computer-host -Gemini_Generated_Image_yxb12yyxb12yyxb1 +Gemini_Generated_Image_yxb12yyxb12yyxb1 computer-host is a daemon runtime for managing Firecracker microVMs diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index aeef596..2ecff19 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -118,6 +118,20 @@ func (d *Daemon) Health(ctx context.Context) (*contracthost.HealthResponse, erro } func (d *Daemon) lockMachine(machineID contracthost.MachineID) func() { + lock := d.machineLock(machineID) + lock.Lock() + return lock.Unlock +} + +func (d *Daemon) tryLockMachine(machineID contracthost.MachineID) (func(), bool) { + lock := d.machineLock(machineID) + if !lock.TryLock() { + return nil, false + } + return lock.Unlock, true +} + +func (d *Daemon) machineLock(machineID contracthost.MachineID) *sync.Mutex { d.locksMu.Lock() lock, ok := d.machineLocks[machineID] if !ok { @@ -125,9 +139,7 @@ func (d *Daemon) lockMachine(machineID contracthost.MachineID) func() { d.machineLocks[machineID] = lock } d.locksMu.Unlock() - - lock.Lock() - return lock.Unlock + return lock } func (d *Daemon) lockArtifact(key string) func() { diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index e29d840..2c7bf65 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -944,6 +944,50 @@ func TestGetSnapshotArtifactReturnsLocalArtifactPath(t *testing.T) { } } +func TestDeleteSnapshotByIDRemovesDiskOnlySnapshotDirectory(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + hostDaemon, err := New(cfg, fileStore, &fakeRuntime{}) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + + snapshotDir := filepath.Join(root, "snapshots", "snap-delete") + if err := os.MkdirAll(snapshotDir, 0o755); err != nil { + t.Fatalf("create snapshot dir: %v", err) + } + systemPath := filepath.Join(snapshotDir, "system.img") + if err := os.WriteFile(systemPath, []byte("disk"), 0o644); err != nil { + t.Fatalf("write system disk: %v", err) + } + if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ + ID: "snap-delete", + MachineID: "source", + DiskPaths: []string{systemPath}, + Artifacts: []model.SnapshotArtifactRecord{ + {ID: "disk-system", Kind: contracthost.SnapshotArtifactKindDisk, Name: "system.img", LocalPath: systemPath, SizeBytes: 4}, + }, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create snapshot: %v", err) + } + + if err := hostDaemon.DeleteSnapshotByID(context.Background(), "snap-delete"); err != nil { + t.Fatalf("DeleteSnapshotByID returned error: %v", err) + } + if _, err := os.Stat(snapshotDir); !os.IsNotExist(err) { + t.Fatalf("snapshot dir should be removed, stat error: %v", err) + } + if _, err := fileStore.GetSnapshot(context.Background(), "snap-delete"); err != store.ErrNotFound { + t.Fatalf("snapshot should be removed from store, got: %v", err) + } +} + func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { root := t.TempDir() cfg := testConfig(root) diff --git a/internal/daemon/files.go b/internal/daemon/files.go index 953640b..ad2ede4 100644 --- a/internal/daemon/files.go +++ b/internal/daemon/files.go @@ -523,7 +523,7 @@ func injectGuestSSHHostKey(ctx context.Context, imagePath string, keyPair *guest if err := os.WriteFile(privateKeyPath, keyPair.PrivateKey, 0o600); err != nil { return fmt.Errorf("write guest ssh host private key staging file: %w", err) } - if err := replaceExt4File(ctx, imagePath, privateKeyPath, "/etc/ssh/ssh_host_ed25519_key"); err != nil { + if err := replaceExt4FileMode(ctx, imagePath, privateKeyPath, "/etc/ssh/ssh_host_ed25519_key", "100600"); err != nil { return err } @@ -531,7 +531,7 @@ func injectGuestSSHHostKey(ctx context.Context, imagePath string, keyPair *guest if err := os.WriteFile(publicKeyPath, []byte(strings.TrimSpace(keyPair.PublicKey)+"\n"), 0o644); err != nil { return fmt.Errorf("write guest ssh host public key staging file: %w", err) } - if err := replaceExt4File(ctx, imagePath, publicKeyPath, "/etc/ssh/ssh_host_ed25519_key.pub"); err != nil { + if err := replaceExt4FileMode(ctx, imagePath, publicKeyPath, "/etc/ssh/ssh_host_ed25519_key.pub", "100644"); err != nil { return err } @@ -543,6 +543,7 @@ func injectMachineIdentity(ctx context.Context, imagePath string, machineID cont if machineName == "" { return fmt.Errorf("machine_id is required") } + hostname := "agentcomputer" stagingDir, err := os.MkdirTemp(filepath.Dir(imagePath), "machine-identity-*") if err != nil { @@ -553,11 +554,11 @@ func injectMachineIdentity(ctx context.Context, imagePath string, machineID cont }() identityFiles := map[string]string{ - "/etc/microagent/machine-name": machineName + "\n", - "/etc/hostname": machineName + "\n", + "/etc/microagent/machine-name": hostname + "\n", + "/etc/hostname": hostname + "\n", "/etc/hosts": fmt.Sprintf( "127.0.0.1 localhost\n127.0.1.1 %s\n::1 localhost ip6-localhost ip6-loopback\nff02::1 ip6-allnodes\nff02::2 ip6-allrouters\n", - machineName, + hostname, ), } @@ -576,10 +577,19 @@ func injectMachineIdentity(ctx context.Context, imagePath string, machineID cont } func replaceExt4File(ctx context.Context, imagePath string, sourcePath string, targetPath string) error { + return replaceExt4FileMode(ctx, imagePath, sourcePath, targetPath, "") +} + +func replaceExt4FileMode(ctx context.Context, imagePath string, sourcePath string, targetPath string, mode string) error { _ = runDebugFS(ctx, imagePath, fmt.Sprintf("rm %s", targetPath)) if err := runDebugFS(ctx, imagePath, fmt.Sprintf("write %s %s", sourcePath, targetPath)); err != nil { return fmt.Errorf("inject %q into %q: %w", targetPath, imagePath, err) } + if mode != "" { + if err := runDebugFS(ctx, imagePath, fmt.Sprintf("set_inode_field %s mode 0%s", targetPath, mode)); err != nil { + return fmt.Errorf("set mode on %q in %q: %w", targetPath, imagePath, err) + } + } return nil } diff --git a/internal/daemon/guest_config_test.go b/internal/daemon/guest_config_test.go index 08d9796..7d8cc80 100644 --- a/internal/daemon/guest_config_test.go +++ b/internal/daemon/guest_config_test.go @@ -81,15 +81,15 @@ func TestInjectMachineIdentityWritesHostnameFiles(t *testing.T) { if err != nil { t.Fatalf("read hostname: %v", err) } - if hostname != "kiruru\n" { - t.Fatalf("hostname mismatch: got %q want %q", hostname, "kiruru\n") + if hostname != "agentcomputer\n" { + t.Fatalf("hostname mismatch: got %q want %q", hostname, "agentcomputer\n") } hosts, err := readExt4File(imagePath, "/etc/hosts") if err != nil { t.Fatalf("read hosts: %v", err) } - if !strings.Contains(hosts, "127.0.1.1 kiruru") { + if !strings.Contains(hosts, "127.0.1.1 agentcomputer") { t.Fatalf("hosts missing machine name: %q", hosts) } } diff --git a/internal/daemon/guest_personalization.go b/internal/daemon/guest_personalization.go index 9ff825f..c0402cc 100644 --- a/internal/daemon/guest_personalization.go +++ b/internal/daemon/guest_personalization.go @@ -23,7 +23,7 @@ const ( defaultGuestPersonalizationVsockID = "microagent-personalizer" defaultGuestPersonalizationVsockName = "microagent-personalizer.vsock" defaultGuestPersonalizationVsockPort = uint32(1024) - defaultGuestPersonalizationTimeout = 15 * time.Second + defaultGuestPersonalizationTimeout = 30 * time.Second guestPersonalizationRetryInterval = 100 * time.Millisecond minGuestVsockCID = uint32(3) maxGuestVsockCID = uint32(1<<31 - 1) @@ -91,9 +91,34 @@ func sendGuestPersonalization(ctx context.Context, state firecracker.MachineStat if err != nil { return nil, err } - connection, err := dialGuestPersonalization(ctx, vsockPath) + + var lastErr error + for { + if ctx.Err() != nil { + if lastErr != nil { + return nil, lastErr + } + return nil, ctx.Err() + } + + resp, err := tryGuestPersonalization(ctx, vsockPath, payloadBytes) + if err == nil { + return resp, nil + } + lastErr = err + + select { + case <-ctx.Done(): + return nil, lastErr + case <-time.After(guestPersonalizationRetryInterval): + } + } +} + +func tryGuestPersonalization(ctx context.Context, vsockPath string, payloadBytes []byte) (*guestPersonalizationResponse, error) { + connection, err := (&net.Dialer{}).DialContext(ctx, "unix", vsockPath) if err != nil { - return nil, err + return nil, fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err) } defer func() { _ = connection.Close() @@ -140,25 +165,3 @@ func setConnectionDeadline(ctx context.Context, connection net.Conn) { } _ = connection.SetDeadline(time.Now().Add(defaultGuestPersonalizationTimeout)) } - -func dialGuestPersonalization(ctx context.Context, vsockPath string) (net.Conn, error) { - dialer := &net.Dialer{} - for { - connection, err := dialer.DialContext(ctx, "unix", vsockPath) - if err == nil { - return connection, nil - } - if ctx.Err() != nil { - return nil, ctx.Err() - } - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - return nil, fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err) - } - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(guestPersonalizationRetryInterval): - } - } -} diff --git a/internal/daemon/lifecycle.go b/internal/daemon/lifecycle.go index 9ee5487..f8c870e 100644 --- a/internal/daemon/lifecycle.go +++ b/internal/daemon/lifecycle.go @@ -52,6 +52,9 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (* return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil } if record.Phase == contracthost.MachinePhaseStarting { + // reconcileMachine acquires the machine lock, so we must release + // ours first to avoid self-deadlock. + unlock() reconciled, err := d.reconcileMachine(ctx, id) if err != nil { return nil, err @@ -220,6 +223,12 @@ func (d *Daemon) Reconcile(ctx context.Context) error { return err } for _, operation := range operations { + unlock, ok := d.tryLockMachine(operation.MachineID) + if !ok { + continue + } + unlock() + switch operation.Type { case model.MachineOperationCreate: if err := d.reconcileCreate(ctx, operation.MachineID); err != nil { diff --git a/internal/daemon/readiness.go b/internal/daemon/readiness.go deleted file mode 100644 index a1a629b..0000000 --- a/internal/daemon/readiness.go +++ /dev/null @@ -1,46 +0,0 @@ -package daemon - -import ( - "context" - "fmt" - "net" - "strconv" - "strings" - - contracthost "github.com/getcompanion-ai/computer-host/contract" -) - -func guestPortsReady(ctx context.Context, host string, ports []contracthost.MachinePort) (bool, error) { - host = strings.TrimSpace(host) - if host == "" { - return false, fmt.Errorf("guest runtime host is required") - } - - for _, port := range ports { - probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout) - ready, err := guestPortReady(probeCtx, host, port) - cancel() - if err != nil { - return false, err - } - if !ready { - return false, nil - } - } - return true, nil -} - -func guestPortReady(ctx context.Context, host string, port contracthost.MachinePort) (bool, error) { - address := net.JoinHostPort(host, strconv.Itoa(int(port.Port))) - dialer := net.Dialer{Timeout: defaultGuestDialTimeout} - - connection, err := dialer.DialContext(ctx, string(port.Protocol), address) - if err == nil { - _ = connection.Close() - return true, nil - } - if ctx.Err() != nil { - return false, nil - } - return false, nil -} diff --git a/internal/daemon/review_regressions_test.go b/internal/daemon/review_regressions_test.go index eb4eae0..fba072f 100644 --- a/internal/daemon/review_regressions_test.go +++ b/internal/daemon/review_regressions_test.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/hex" "errors" - "fmt" "net" "os" "path/filepath" @@ -57,19 +56,6 @@ func (s machineLookupErrorStore) GetMachine(context.Context, contracthost.Machin return nil, s.err } -type relayExhaustionStore struct { - hoststore.Store - extraMachines []model.MachineRecord -} - -func (s relayExhaustionStore) ListMachines(ctx context.Context) ([]model.MachineRecord, error) { - machines, err := s.Store.ListMachines(ctx) - if err != nil { - return nil, err - } - return append(machines, s.extraMachines...), nil -} - type publishedPortResult struct { response *contracthost.CreatePublishedPortResponse err error @@ -255,6 +241,52 @@ func TestReconcileSnapshotPreservesArtifactsOnUnexpectedStoreError(t *testing.T) assertOperationCount(t, baseStore, 1) } +func TestReconcileSkipsInFlightSnapshotOperationWhileMachineLocked(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + hostDaemon, err := New(cfg, baseStore, &fakeRuntime{}) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + stubGuestSSHPublicKeyReader(hostDaemon) + + snapshotID := contracthost.SnapshotID("snap-inflight") + operation := model.OperationRecord{ + MachineID: "vm-1", + Type: model.MachineOperationSnapshot, + StartedAt: time.Now().UTC(), + SnapshotID: &snapshotID, + } + if err := baseStore.UpsertOperation(context.Background(), operation); err != nil { + t.Fatalf("upsert operation: %v", err) + } + + snapshotDir := filepath.Join(cfg.SnapshotsDir, string(snapshotID)) + if err := os.MkdirAll(snapshotDir, 0o755); err != nil { + t.Fatalf("create snapshot dir: %v", err) + } + markerPath := filepath.Join(snapshotDir, "keep.txt") + if err := os.WriteFile(markerPath, []byte("keep"), 0o644); err != nil { + t.Fatalf("write marker file: %v", err) + } + + unlock := hostDaemon.lockMachine("vm-1") + defer unlock() + + if err := hostDaemon.Reconcile(context.Background()); err != nil { + t.Fatalf("Reconcile returned error: %v", err) + } + if _, statErr := os.Stat(markerPath); statErr != nil { + t.Fatalf("in-flight snapshot artifacts should be preserved, stat error: %v", statErr) + } + assertOperationCount(t, baseStore, 1) +} + func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T) { root := t.TempDir() cfg := testConfig(root) @@ -307,11 +339,6 @@ func TestStartMachineTransitionsToRunningWithHandshake(t *testing.T) { t.Fatalf("create file store: %v", err) } - exhaustedStore := relayExhaustionStore{ - Store: baseStore, - extraMachines: exhaustedMachineRelayRecords(), - } - sshListener := listenTestPort(t, int(defaultSSHPort)) defer func() { _ = sshListener.Close() @@ -334,7 +361,7 @@ func TestStartMachineTransitionsToRunningWithHandshake(t *testing.T) { }, } - hostDaemon, err := New(cfg, exhaustedStore, runtime) + hostDaemon, err := New(cfg, baseStore, runtime) if err != nil { t.Fatalf("create daemon: %v", err) } @@ -416,11 +443,6 @@ func TestRestoreSnapshotTransitionsToRunningWithHandshake(t *testing.T) { t.Fatalf("create file store: %v", err) } - exhaustedStore := relayExhaustionStore{ - Store: baseStore, - extraMachines: exhaustedMachineRelayRecords(), - } - startedAt := time.Unix(1700000300, 0).UTC() runtime := &fakeRuntime{ bootState: firecracker.MachineState{ @@ -434,7 +456,7 @@ func TestRestoreSnapshotTransitionsToRunningWithHandshake(t *testing.T) { }, } - hostDaemon, err := New(cfg, exhaustedStore, runtime) + hostDaemon, err := New(cfg, baseStore, runtime) if err != nil { t.Fatalf("create daemon: %v", err) } @@ -909,19 +931,6 @@ func waitPublishedPortResult(t *testing.T, ch <-chan publishedPortResult) publis } } -func exhaustedMachineRelayRecords() []model.MachineRecord { - count := int(maxMachineSSHRelayPort-minMachineSSHRelayPort) + 1 - machines := make([]model.MachineRecord, 0, count) - for i := 0; i < count; i++ { - machines = append(machines, model.MachineRecord{ - ID: contracthost.MachineID(fmt.Sprintf("relay-exhausted-%d", i)), - Ports: buildMachinePorts(minMachineSSHRelayPort+uint16(i), minMachineVNCRelayPort+uint16(i), 0), - Phase: contracthost.MachinePhaseRunning, - }) - } - return machines -} - func mustSHA256Hex(t *testing.T, payload []byte) string { t.Helper() diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go index fcb67e4..59764c6 100644 --- a/internal/daemon/snapshot.go +++ b/internal/daemon/snapshot.go @@ -489,7 +489,10 @@ func (d *Daemon) DeleteSnapshotByID(ctx context.Context, snapshotID contracthost if err != nil { return err } - snapshotDir := filepath.Dir(snap.MemFilePath) + snapshotDir, ok := snapshotDirectory(*snap) + if !ok { + return fmt.Errorf("snapshot %q has no local artifact directory", snapshotID) + } if err := os.RemoveAll(snapshotDir); err != nil { return fmt.Errorf("remove snapshot dir %q: %w", snapshotDir, err) } @@ -520,6 +523,25 @@ func snapshotArtifactsToContract(artifacts []model.SnapshotArtifactRecord) []con return converted } +func snapshotDirectory(snapshot model.SnapshotRecord) (string, bool) { + for _, artifact := range snapshot.Artifacts { + if path := strings.TrimSpace(artifact.LocalPath); path != "" { + return filepath.Dir(path), true + } + } + for _, diskPath := range snapshot.DiskPaths { + if path := strings.TrimSpace(diskPath); path != "" { + return filepath.Dir(path), true + } + } + for _, legacyPath := range []string{snapshot.MemFilePath, snapshot.StateFilePath} { + if path := strings.TrimSpace(legacyPath); path != "" { + return filepath.Dir(path), true + } + } + return "", false +} + func orderedRestoredUserDiskArtifacts(artifacts map[string]restoredSnapshotArtifact) []restoredSnapshotArtifact { ordered := make([]restoredSnapshotArtifact, 0, len(artifacts)) for name, artifact := range artifacts { diff --git a/internal/daemon/startup.go b/internal/daemon/startup.go index b141618..10cf994 100644 --- a/internal/daemon/startup.go +++ b/internal/daemon/startup.go @@ -73,5 +73,12 @@ func (d *Daemon) completeMachineStartup(ctx context.Context, record *model.Machi if err := d.store.UpdateMachine(ctx, *record); err != nil { return nil, err } + if err := d.ensureMachineRelays(ctx, record); err != nil { + return d.failMachineStartup(ctx, record, err.Error()) + } + if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil { + d.stopMachineRelays(record.ID) + return d.failMachineStartup(ctx, record, err.Error()) + } return record, nil }