diff --git a/internal/daemon/create.go b/internal/daemon/create.go index b5c54f3..28bf8f9 100644 --- a/internal/daemon/create.go +++ b/internal/daemon/create.go @@ -5,8 +5,11 @@ import ( "fmt" "os" "path/filepath" + "strings" "time" + "golang.org/x/sync/errgroup" + "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/store" @@ -52,13 +55,34 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi } }() - artifact, err := d.ensureArtifact(ctx, req.Artifact) - if err != nil { - return nil, err - } - - userVolumes, err := d.loadAttachableUserVolumes(ctx, req.MachineID, req.UserVolumeIDs) - if err != nil { + var ( + artifact *model.ArtifactRecord + userVolumes []model.VolumeRecord + guestHostKey *guestSSHHostKeyPair + readyNonce string + ) + group, groupCtx := errgroup.WithContext(ctx) + group.Go(func() error { + var err error + artifact, err = d.ensureArtifact(groupCtx, req.Artifact) + return err + }) + group.Go(func() error { + var err error + userVolumes, err = d.loadAttachableUserVolumes(groupCtx, req.MachineID, req.UserVolumeIDs) + return err + }) + group.Go(func() error { + var err error + guestHostKey, err = generateGuestSSHHostKeyPair(groupCtx) + return err + }) + group.Go(func() error { + var err error + readyNonce, err = newGuestReadyNonce() + return err + }) + if err := group.Wait(); err != nil { return nil, err } @@ -86,8 +110,11 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi if err := d.injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil { return nil, fmt.Errorf("inject guest config for %q: %w", req.MachineID, err) } + if err := injectGuestSSHHostKey(ctx, systemVolumePath, guestHostKey); err != nil { + return nil, fmt.Errorf("inject guest ssh host key for %q: %w", req.MachineID, err) + } - spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig) + spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig, readyNonce) if err != nil { return nil, err } @@ -135,19 +162,21 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi } record := model.MachineRecord{ - ID: req.MachineID, - Artifact: req.Artifact, - GuestConfig: cloneGuestConfig(guestConfig), - SystemVolumeID: systemVolumeRecord.ID, - UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...), - RuntimeHost: state.RuntimeHost, - TapDevice: state.TapName, - Ports: defaultMachinePorts(), - Phase: contracthost.MachinePhaseStarting, - PID: state.PID, - SocketPath: state.SocketPath, - CreatedAt: now, - StartedAt: state.StartedAt, + ID: req.MachineID, + Artifact: req.Artifact, + GuestConfig: cloneGuestConfig(guestConfig), + SystemVolumeID: systemVolumeRecord.ID, + UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...), + RuntimeHost: state.RuntimeHost, + TapDevice: state.TapName, + Ports: defaultMachinePorts(), + GuestSSHPublicKey: strings.TrimSpace(guestHostKey.PublicKey), + GuestReadyNonce: readyNonce, + Phase: contracthost.MachinePhaseStarting, + PID: state.PID, + SocketPath: state.SocketPath, + CreatedAt: now, + StartedAt: state.StartedAt, } if err := d.store.CreateMachine(ctx, record); err != nil { for _, volume := range userVolumes { @@ -159,12 +188,17 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi return nil, err } + recordReady, err := d.completeMachineStartup(ctx, &record, *state) + if err != nil { + return nil, err + } + removeSystemVolumeOnFailure = false clearOperation = true - return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil + return &contracthost.CreateMachineResponse{Machine: machineToContract(*recordReady)}, nil } -func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string, guestConfig *contracthost.GuestConfig) (firecracker.MachineSpec, error) { +func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string, guestConfig *contracthost.GuestConfig, readyNonce string) (firecracker.MachineSpec, error) { drives := make([]firecracker.DriveSpec, 0, len(userVolumes)) for i, volume := range userVolumes { drives = append(drives, firecracker.DriveSpec{ @@ -176,7 +210,7 @@ func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *mo }) } - mmds, err := d.guestMetadataSpec(machineID, guestConfig) + mmds, err := d.guestMetadataSpec(machineID, guestConfig, readyNonce) if err != nil { return firecracker.MachineSpec{}, err } @@ -221,10 +255,14 @@ func (d *Daemon) ensureArtifact(ctx context.Context, ref contracthost.ArtifactRe kernelPath := filepath.Join(dir, "kernel") rootFSPath := filepath.Join(dir, "rootfs") - if err := downloadFile(ctx, ref.KernelImageURL, kernelPath); err != nil { - return nil, err - } - if err := downloadFile(ctx, ref.RootFSURL, rootFSPath); err != nil { + group, groupCtx := errgroup.WithContext(ctx) + group.Go(func() error { + return downloadFile(groupCtx, ref.KernelImageURL, kernelPath) + }) + group.Go(func() error { + return downloadFile(groupCtx, ref.RootFSURL, rootFSPath) + }) + if err := group.Wait(); err != nil { return nil, err } diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 64190b2..8e457d9 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -50,7 +50,7 @@ type Daemon struct { injectGuestConfig func(context.Context, string, *contracthost.GuestConfig) error syncGuestFilesystem func(context.Context, string) error shutdownGuest func(context.Context, string) error - personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) error + personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) (*guestReadyResult, error) locksMu sync.Mutex machineLocks map[contracthost.MachineID]*sync.Mutex diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index c0d9a60..a500ea8 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -152,7 +152,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { t.Fatalf("create machine: %v", err) } - if response.Machine.Phase != contracthost.MachinePhaseStarting { + if response.Machine.Phase != contracthost.MachinePhaseRunning { t.Fatalf("machine phase mismatch: got %q", response.Machine.Phase) } if response.Machine.RuntimeHost != "127.0.0.1" { @@ -230,7 +230,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { if machine.SystemVolumeID != "vm-1-system" { t.Fatalf("system volume mismatch: got %q", machine.SystemVolumeID) } - if machine.Phase != contracthost.MachinePhaseStarting { + if machine.Phase != contracthost.MachinePhaseRunning { t.Fatalf("stored machine phase mismatch: got %q", machine.Phase) } if machine.GuestConfig == nil || len(machine.GuestConfig.AuthorizedKeys) == 0 { @@ -401,7 +401,7 @@ func TestGetMachineReconcilesStartingMachineBeforeRunning(t *testing.T) { t.Cleanup(func() { hostDaemon.stopMachineRelays("vm-starting") }) personalized := false - hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, state firecracker.MachineState) error { + hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, state firecracker.MachineState) (*guestReadyResult, error) { personalized = true if record.ID != "vm-starting" { t.Fatalf("personalized machine mismatch: got %q", record.ID) @@ -409,9 +409,15 @@ func TestGetMachineReconcilesStartingMachineBeforeRunning(t *testing.T) { if state.RuntimeHost != "127.0.0.1" || state.PID != 4321 { t.Fatalf("personalized state mismatch: %#v", state) } - return nil + guestSSHPublicKey := strings.TrimSpace(record.GuestSSHPublicKey) + if guestSSHPublicKey == "" { + guestSSHPublicKey = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3" + } + return &guestReadyResult{ + ReadyNonce: record.GuestReadyNonce, + GuestSSHPublicKey: guestSSHPublicKey, + }, nil } - stubGuestSSHPublicKeyReader(hostDaemon) if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{ ID: "vm-starting", @@ -455,9 +461,9 @@ func TestListMachinesDoesNotReconcileStartingMachines(t *testing.T) { if err != nil { t.Fatalf("create daemon: %v", err) } - hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) error { + hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) (*guestReadyResult, error) { t.Fatalf("ListMachines should not reconcile guest personalization") - return nil + return nil, nil } hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) { t.Fatalf("ListMachines should not read guest ssh public key") @@ -492,7 +498,7 @@ func TestListMachinesDoesNotReconcileStartingMachines(t *testing.T) { } } -func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) { +func TestReconcileStartingMachineFailsWhenHandshakeFails(t *testing.T) { root := t.TempDir() cfg := testConfig(root) fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) @@ -511,11 +517,8 @@ func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) { if err != nil { t.Fatalf("create daemon: %v", err) } - hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) error { - return errors.New("vsock EOF") - } - hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) { - return "", errors.New("Permission denied") + hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) (*guestReadyResult, error) { + return nil, errors.New("vsock EOF") } if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{ @@ -542,14 +545,14 @@ func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) { if err != nil { t.Fatalf("get machine: %v", err) } - if record.Phase != contracthost.MachinePhaseRunning { - t.Fatalf("machine phase = %q, want %q", record.Phase, contracthost.MachinePhaseRunning) + if record.Phase != contracthost.MachinePhaseFailed { + t.Fatalf("machine phase = %q, want %q", record.Phase, contracthost.MachinePhaseFailed) } - if record.GuestSSHPublicKey != "ssh-ed25519 AAAAExistingHostKey" { - t.Fatalf("guest ssh public key = %q, want preserved value", record.GuestSSHPublicKey) + if !strings.Contains(record.Error, "vsock EOF") { + t.Fatalf("failure reason = %q, want vsock error", record.Error) } - if len(runtime.deleteCalls) != 0 { - t.Fatalf("runtime delete calls = %d, want 0", len(runtime.deleteCalls)) + if len(runtime.deleteCalls) != 1 { + t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls)) } } @@ -756,7 +759,7 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) { if response.Machine.ID != "restored" { t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID) } - if response.Machine.Phase != contracthost.MachinePhaseStarting { + if response.Machine.Phase != contracthost.MachinePhaseRunning { t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase) } if runtime.bootCalls != 1 { @@ -1013,7 +1016,7 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { if response.Machine.ID != "restored" { t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID) } - if response.Machine.Phase != contracthost.MachinePhaseStarting { + if response.Machine.Phase != contracthost.MachinePhaseRunning { t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase) } if runtime.bootCalls != 1 { @@ -1033,7 +1036,7 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { if err != nil { t.Fatalf("get restored machine: %v", err) } - if machine.Phase != contracthost.MachinePhaseStarting { + if machine.Phase != contracthost.MachinePhaseRunning { t.Fatalf("restored machine phase mismatch: got %q", machine.Phase) } if machine.GuestConfig == nil || machine.GuestConfig.Hostname != "restored-shell" { @@ -1126,7 +1129,7 @@ func TestRestoreSnapshotBootsWithFreshNetworkWhenSourceNetworkInUseOnHost(t *tes if err != nil { t.Fatalf("restore snapshot error = %v, want success", err) } - if response.Machine.Phase != contracthost.MachinePhaseStarting { + if response.Machine.Phase != contracthost.MachinePhaseRunning { t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase) } if runtime.bootCalls != 1 { @@ -1254,8 +1257,15 @@ func TestGuestKernelArgsRemovesPCIOffWhenPCIEnabled(t *testing.T) { } func stubGuestSSHPublicKeyReader(hostDaemon *Daemon) { - hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) { - return "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3", nil + hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, _ firecracker.MachineState) (*guestReadyResult, error) { + guestSSHPublicKey := strings.TrimSpace(record.GuestSSHPublicKey) + if guestSSHPublicKey == "" { + guestSSHPublicKey = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3" + } + return &guestReadyResult{ + ReadyNonce: record.GuestReadyNonce, + GuestSSHPublicKey: guestSSHPublicKey, + }, nil } } diff --git a/internal/daemon/files.go b/internal/daemon/files.go index 91960d3..1a5f741 100644 --- a/internal/daemon/files.go +++ b/internal/daemon/files.go @@ -348,6 +348,42 @@ func (d *Daemon) writeBackendSSHPublicKey(privateKeyPath string, publicKeyPath s return nil } +type guestSSHHostKeyPair struct { + PrivateKey []byte + PublicKey string +} + +func generateGuestSSHHostKeyPair(ctx context.Context) (*guestSSHHostKeyPair, error) { + stagingDir, err := os.MkdirTemp("", "guest-ssh-hostkey-*") + if err != nil { + return nil, fmt.Errorf("create guest ssh host key staging dir: %w", err) + } + defer func() { + _ = os.RemoveAll(stagingDir) + }() + + privateKeyPath := filepath.Join(stagingDir, "ssh_host_ed25519_key") + command := exec.CommandContext(ctx, "ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-C", "", "-f", privateKeyPath) + output, err := command.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("generate guest ssh host keypair: %w: %s", err, strings.TrimSpace(string(output))) + } + + privateKey, err := os.ReadFile(privateKeyPath) + if err != nil { + return nil, fmt.Errorf("read guest ssh host private key %q: %w", privateKeyPath, err) + } + publicKey, err := os.ReadFile(privateKeyPath + ".pub") + if err != nil { + return nil, fmt.Errorf("read guest ssh host public key %q: %w", privateKeyPath+".pub", err) + } + + return &guestSSHHostKeyPair{ + PrivateKey: privateKey, + PublicKey: strings.TrimSpace(string(publicKey)), + }, nil +} + func fileExists(path string) bool { _, err := os.Stat(path) return err == nil @@ -441,6 +477,41 @@ func injectGuestConfig(ctx context.Context, imagePath string, config *contractho return nil } +func injectGuestSSHHostKey(ctx context.Context, imagePath string, keyPair *guestSSHHostKeyPair) error { + if keyPair == nil { + return fmt.Errorf("guest ssh host keypair is required") + } + if strings.TrimSpace(keyPair.PublicKey) == "" { + return fmt.Errorf("guest ssh host public key is required") + } + + stagingDir, err := os.MkdirTemp(filepath.Dir(imagePath), "guest-ssh-hostkey-*") + if err != nil { + return fmt.Errorf("create guest ssh host key staging dir: %w", err) + } + defer func() { + _ = os.RemoveAll(stagingDir) + }() + + privateKeyPath := filepath.Join(stagingDir, "ssh_host_ed25519_key") + if err := os.WriteFile(privateKeyPath, keyPair.PrivateKey, 0o600); err != nil { + return fmt.Errorf("write guest ssh host private key staging file: %w", err) + } + if err := replaceExt4File(ctx, imagePath, privateKeyPath, "/etc/ssh/ssh_host_ed25519_key"); err != nil { + return err + } + + publicKeyPath := privateKeyPath + ".pub" + if err := os.WriteFile(publicKeyPath, []byte(strings.TrimSpace(keyPair.PublicKey)+"\n"), 0o644); err != nil { + return fmt.Errorf("write guest ssh host public key staging file: %w", err) + } + if err := replaceExt4File(ctx, imagePath, publicKeyPath, "/etc/ssh/ssh_host_ed25519_key.pub"); err != nil { + return err + } + + return nil +} + func injectMachineIdentity(ctx context.Context, imagePath string, machineID contracthost.MachineID) error { machineName := strings.TrimSpace(string(machineID)) if machineName == "" { diff --git a/internal/daemon/guest_identity.go b/internal/daemon/guest_identity.go index 96e92e8..7adc820 100644 --- a/internal/daemon/guest_identity.go +++ b/internal/daemon/guest_identity.go @@ -20,7 +20,7 @@ func (d *Daemon) reconfigureGuestIdentityOverSSH(ctx context.Context, runtimeHos if machineName == "" { return fmt.Errorf("machine id is required") } - mmds, err := d.guestMetadataSpec(machineID, guestConfig) + mmds, err := d.guestMetadataSpec(machineID, guestConfig, "") if err != nil { return err } diff --git a/internal/daemon/guest_metadata.go b/internal/daemon/guest_metadata.go index 2df5a9a..fd95222 100644 --- a/internal/daemon/guest_metadata.go +++ b/internal/daemon/guest_metadata.go @@ -25,6 +25,7 @@ type guestMetadataPayload struct { Version string `json:"version"` MachineID string `json:"machine_id"` Hostname string `json:"hostname"` + ReadyNonce string `json:"ready_nonce,omitempty"` AuthorizedKeys []string `json:"authorized_keys,omitempty"` TrustedUserCAKeys []string `json:"trusted_user_ca_keys,omitempty"` LoginWebhook *contracthost.GuestLoginWebhook `json:"login_webhook,omitempty"` @@ -55,7 +56,7 @@ func guestHostname(machineID contracthost.MachineID, guestConfig *contracthost.G return strings.TrimSpace(string(machineID)) } -func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) (*firecracker.MMDSSpec, error) { +func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig, readyNonce string) (*firecracker.MMDSSpec, error) { name := guestHostname(machineID, guestConfig) if name == "" { return nil, fmt.Errorf("machine id is required") @@ -67,6 +68,7 @@ func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig Version: defaultMMDSPayloadVersion, MachineID: name, Hostname: name, + ReadyNonce: strings.TrimSpace(readyNonce), AuthorizedKeys: nil, TrustedUserCAKeys: nil, }, diff --git a/internal/daemon/guest_personalization.go b/internal/daemon/guest_personalization.go index b594cbd..9ff825f 100644 --- a/internal/daemon/guest_personalization.go +++ b/internal/daemon/guest_personalization.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/binary" "encoding/json" + "errors" "fmt" "net" "path/filepath" @@ -13,20 +14,32 @@ import ( "strings" "time" - contracthost "github.com/getcompanion-ai/computer-host/contract" "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/model" + contracthost "github.com/getcompanion-ai/computer-host/contract" ) const ( defaultGuestPersonalizationVsockID = "microagent-personalizer" defaultGuestPersonalizationVsockName = "microagent-personalizer.vsock" defaultGuestPersonalizationVsockPort = uint32(1024) - defaultGuestPersonalizationTimeout = 2 * time.Second + defaultGuestPersonalizationTimeout = 15 * time.Second + guestPersonalizationRetryInterval = 100 * time.Millisecond minGuestVsockCID = uint32(3) maxGuestVsockCID = uint32(1<<31 - 1) ) +type guestPersonalizationResponse struct { + Status string `json:"status"` + ReadyNonce string `json:"ready_nonce,omitempty"` + GuestSSHPublicKey string `json:"guest_ssh_public_key,omitempty"` + Error string `json:"error,omitempty"` +} + +type guestReadyRequest struct { + ReadyNonce string `json:"ready_nonce,omitempty"` +} + func guestVsockSpec(machineID contracthost.MachineID) *firecracker.VsockSpec { return &firecracker.VsockSpec{ ID: defaultGuestPersonalizationVsockID, @@ -41,53 +54,46 @@ func guestVsockCID(machineID contracthost.MachineID) uint32 { return minGuestVsockCID + binary.BigEndian.Uint32(sum[:4])%space } -func (d *Daemon) personalizeGuestConfig(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) error { +func (d *Daemon) personalizeGuestConfig(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) (*guestReadyResult, error) { if record == nil { - return fmt.Errorf("machine record is required") + return nil, fmt.Errorf("machine record is required") } personalizeCtx, cancel := context.WithTimeout(ctx, defaultGuestPersonalizationTimeout) defer cancel() - mmds, err := d.guestMetadataSpec(record.ID, record.GuestConfig) + response, err := sendGuestPersonalization(personalizeCtx, state, guestReadyRequest{ + ReadyNonce: strings.TrimSpace(record.GuestReadyNonce), + }) if err != nil { - return err + return nil, fmt.Errorf("wait for guest ready over vsock: %w", err) } - envelope, ok := mmds.Data.(guestMetadataEnvelope) - if !ok { - return fmt.Errorf("guest metadata payload has unexpected type %T", mmds.Data) + if !strings.EqualFold(strings.TrimSpace(response.Status), "ok") { + message := strings.TrimSpace(response.Error) + if message == "" { + message = fmt.Sprintf("unexpected guest personalization status %q", strings.TrimSpace(response.Status)) + } + return nil, errors.New(message) } - - if err := d.runtime.PutMMDS(personalizeCtx, state, mmds.Data); err != nil { - return d.personalizeGuestConfigViaSSH(ctx, record, state, fmt.Errorf("reseed guest mmds: %w", err)) - } - if err := sendGuestPersonalization(personalizeCtx, state, envelope.Latest.MetaData); err != nil { - return d.personalizeGuestConfigViaSSH(ctx, record, state, fmt.Errorf("apply guest config over vsock: %w", err)) - } - return nil + return &guestReadyResult{ + ReadyNonce: strings.TrimSpace(response.ReadyNonce), + GuestSSHPublicKey: strings.TrimSpace(response.GuestSSHPublicKey), + }, nil } -func (d *Daemon) personalizeGuestConfigViaSSH(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState, primaryErr error) error { - fallbackErr := d.reconfigureGuestIdentity(ctx, state.RuntimeHost, record.ID, record.GuestConfig) - if fallbackErr == nil { - return nil - } - return fmt.Errorf("%w; ssh fallback failed: %v", primaryErr, fallbackErr) -} - -func sendGuestPersonalization(ctx context.Context, state firecracker.MachineState, payload guestMetadataPayload) error { +func sendGuestPersonalization(ctx context.Context, state firecracker.MachineState, payload guestReadyRequest) (*guestPersonalizationResponse, error) { payloadBytes, err := json.Marshal(payload) if err != nil { - return fmt.Errorf("marshal guest personalization payload: %w", err) + return nil, fmt.Errorf("marshal guest personalization payload: %w", err) } vsockPath, err := guestVsockHostPath(state) if err != nil { - return err + return nil, err } - connection, err := (&net.Dialer{}).DialContext(ctx, "unix", vsockPath) + connection, err := dialGuestPersonalization(ctx, vsockPath) if err != nil { - return fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err) + return nil, err } defer func() { _ = connection.Close() @@ -96,27 +102,28 @@ func sendGuestPersonalization(ctx context.Context, state firecracker.MachineStat reader := bufio.NewReader(connection) if _, err := fmt.Fprintf(connection, "CONNECT %d\n", defaultGuestPersonalizationVsockPort); err != nil { - return fmt.Errorf("write vsock connect request: %w", err) + return nil, fmt.Errorf("write vsock connect request: %w", err) } response, err := reader.ReadString('\n') if err != nil { - return fmt.Errorf("read vsock connect response: %w", err) + return nil, fmt.Errorf("read vsock connect response: %w", err) } if !strings.HasPrefix(strings.TrimSpace(response), "OK ") { - return fmt.Errorf("unexpected vsock connect response %q", strings.TrimSpace(response)) + return nil, fmt.Errorf("unexpected vsock connect response %q", strings.TrimSpace(response)) } if _, err := connection.Write(append(payloadBytes, '\n')); err != nil { - return fmt.Errorf("write guest personalization payload: %w", err) + return nil, fmt.Errorf("write guest personalization payload: %w", err) } response, err = reader.ReadString('\n') if err != nil { - return fmt.Errorf("read guest personalization response: %w", err) + return nil, fmt.Errorf("read guest personalization response: %w", err) } - if strings.TrimSpace(response) != "OK" { - return fmt.Errorf("unexpected guest personalization response %q", strings.TrimSpace(response)) + var payloadResponse guestPersonalizationResponse + if err := json.Unmarshal([]byte(strings.TrimSpace(response)), &payloadResponse); err != nil { + return nil, fmt.Errorf("decode guest personalization response %q: %w", strings.TrimSpace(response), err) } - return nil + return &payloadResponse, nil } func guestVsockHostPath(state firecracker.MachineState) (string, error) { @@ -133,3 +140,25 @@ func setConnectionDeadline(ctx context.Context, connection net.Conn) { } _ = connection.SetDeadline(time.Now().Add(defaultGuestPersonalizationTimeout)) } + +func dialGuestPersonalization(ctx context.Context, vsockPath string) (net.Conn, error) { + dialer := &net.Dialer{} + for { + connection, err := dialer.DialContext(ctx, "unix", vsockPath) + if err == nil { + return connection, nil + } + if ctx.Err() != nil { + return nil, ctx.Err() + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return nil, fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err) + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(guestPersonalizationRetryInterval): + } + } +} diff --git a/internal/daemon/lifecycle.go b/internal/daemon/lifecycle.go index beef737..9ee5487 100644 --- a/internal/daemon/lifecycle.go +++ b/internal/daemon/lifecycle.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "golang.org/x/sync/errgroup" + "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/store" @@ -50,7 +52,11 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (* return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil } if record.Phase == contracthost.MachinePhaseStarting { - return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil + reconciled, err := d.reconcileMachine(ctx, id) + if err != nil { + return nil, err + } + return &contracthost.GetMachineResponse{Machine: machineToContract(*reconciled)}, nil } if record.Phase != contracthost.MachinePhaseStopped { return nil, fmt.Errorf("machine %q is not startable from phase %q", id, record.Phase) @@ -71,21 +77,38 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (* } }() - systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID) - if err != nil { - return nil, err - } - artifact, err := d.store.GetArtifact(ctx, record.Artifact) - if err != nil { - return nil, err - } - userVolumes, err := d.loadAttachableUserVolumes(ctx, id, record.UserVolumeIDs) - if err != nil { + var ( + systemVolume *model.VolumeRecord + artifact *model.ArtifactRecord + userVolumes []model.VolumeRecord + readyNonce string + ) + group, groupCtx := errgroup.WithContext(ctx) + group.Go(func() error { + var err error + systemVolume, err = d.store.GetVolume(groupCtx, record.SystemVolumeID) + return err + }) + group.Go(func() error { + var err error + artifact, err = d.store.GetArtifact(groupCtx, record.Artifact) + return err + }) + group.Go(func() error { + var err error + userVolumes, err = d.loadAttachableUserVolumes(groupCtx, id, record.UserVolumeIDs) + return err + }) + group.Go(func() error { + var err error + readyNonce, err = newGuestReadyNonce() + return err + }) + if err := group.Wait(); err != nil { return nil, err } repairDirtyFilesystem(systemVolume.Path) - - spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path, record.GuestConfig) + spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path, record.GuestConfig, readyNonce) if err != nil { return nil, err } @@ -100,7 +123,7 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (* record.RuntimeHost = state.RuntimeHost record.TapDevice = state.TapName record.Ports = defaultMachinePorts() - record.GuestSSHPublicKey = "" + record.GuestReadyNonce = readyNonce record.Phase = contracthost.MachinePhaseStarting record.Error = "" record.PID = state.PID @@ -112,6 +135,11 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (* return nil, err } + record, err = d.completeMachineStartup(ctx, record, *state) + if err != nil { + return nil, err + } + clearOperation = true return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil } @@ -376,44 +404,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma return nil, err } if record.Phase == contracthost.MachinePhaseStarting { - if state.Phase != firecracker.PhaseRunning { - return d.failMachineStartup(ctx, record, state.Error) - } - ready, err := guestPortsReady(ctx, state.RuntimeHost, defaultMachinePorts()) - if err != nil { - return nil, err - } - if !ready { - return record, nil - } - if err := d.personalizeGuest(ctx, record, *state); err != nil { - fmt.Fprintf(os.Stderr, "warning: guest personalization for %q failed: %v\n", record.ID, err) - } - guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost) - if err != nil { - fmt.Fprintf(os.Stderr, "warning: read guest ssh public key for %q failed: %v\n", record.ID, err) - guestSSHPublicKey = record.GuestSSHPublicKey - } - record.RuntimeHost = state.RuntimeHost - record.TapDevice = state.TapName - record.Ports = defaultMachinePorts() - record.GuestSSHPublicKey = guestSSHPublicKey - record.Phase = contracthost.MachinePhaseRunning - record.Error = "" - record.PID = state.PID - record.SocketPath = state.SocketPath - record.StartedAt = state.StartedAt - if err := d.store.UpdateMachine(ctx, *record); err != nil { - return nil, err - } - if err := d.ensureMachineRelays(ctx, record); err != nil { - return d.failMachineStartup(ctx, record, err.Error()) - } - if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil { - d.stopMachineRelays(record.ID) - return d.failMachineStartup(ctx, record, err.Error()) - } - return record, nil + return d.completeMachineStartup(ctx, record, *state) } if state.Phase == firecracker.PhaseRunning { if err := d.ensureMachineRelays(ctx, record); err != nil { @@ -450,7 +441,7 @@ func (d *Daemon) failMachineStartup(ctx context.Context, record *model.MachineRe record.Phase = contracthost.MachinePhaseFailed record.Error = strings.TrimSpace(failureReason) record.Ports = defaultMachinePorts() - record.GuestSSHPublicKey = "" + record.GuestReadyNonce = "" record.PID = 0 record.SocketPath = "" record.RuntimeHost = "" @@ -511,6 +502,7 @@ func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRec record.Phase = contracthost.MachinePhaseStopped record.Error = "" + record.GuestReadyNonce = "" record.PID = 0 record.SocketPath = "" record.RuntimeHost = "" diff --git a/internal/daemon/review_regressions_test.go b/internal/daemon/review_regressions_test.go index b346706..3fe8361 100644 --- a/internal/daemon/review_regressions_test.go +++ b/internal/daemon/review_regressions_test.go @@ -299,7 +299,7 @@ func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T) assertOperationCount(t, baseStore, 1) } -func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) { +func TestStartMachineTransitionsToRunningWithHandshake(t *testing.T) { root := t.TempDir() cfg := testConfig(root) baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath) @@ -386,16 +386,16 @@ func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) { if err != nil { t.Fatalf("StartMachine error = %v", err) } - if response.Machine.Phase != contracthost.MachinePhaseStarting { - t.Fatalf("response machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting) + if response.Machine.Phase != contracthost.MachinePhaseRunning { + t.Fatalf("response machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseRunning) } machine, err := baseStore.GetMachine(context.Background(), "vm-start") if err != nil { t.Fatalf("get machine: %v", err) } - if machine.Phase != contracthost.MachinePhaseStarting { - t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseStarting) + if machine.Phase != contracthost.MachinePhaseRunning { + t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseRunning) } if machine.RuntimeHost != "127.0.0.1" || machine.TapDevice != "fctap-start" { t.Fatalf("machine runtime state mismatch, got runtime_host=%q tap=%q", machine.RuntimeHost, machine.TapDevice) @@ -408,7 +408,7 @@ func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) { } } -func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T) { +func TestRestoreSnapshotTransitionsToRunningWithHandshake(t *testing.T) { root := t.TempDir() cfg := testConfig(root) baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath) @@ -510,8 +510,8 @@ func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T if err != nil { t.Fatalf("RestoreSnapshot returned error: %v", err) } - if response.Machine.Phase != contracthost.MachinePhaseStarting { - t.Fatalf("restored machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting) + if response.Machine.Phase != contracthost.MachinePhaseRunning { + t.Fatalf("restored machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseRunning) } if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); err != nil { t.Fatalf("restored system volume record should exist: %v", err) diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go index f1fb276..fcb67e4 100644 --- a/internal/daemon/snapshot.go +++ b/internal/daemon/snapshot.go @@ -245,10 +245,33 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn return nil, err } defer cleanupRestoreArtifacts() - artifact, err := d.ensureArtifact(ctx, req.Artifact) - if err != nil { + var ( + artifact *model.ArtifactRecord + guestHostKey *guestSSHHostKeyPair + readyNonce string + ) + group, groupCtx := errgroup.WithContext(ctx) + group.Go(func() error { + var err error + artifact, err = d.ensureArtifact(groupCtx, req.Artifact) + if err != nil { + return fmt.Errorf("ensure artifact for restore: %w", err) + } + return nil + }) + group.Go(func() error { + var err error + guestHostKey, err = generateGuestSSHHostKeyPair(groupCtx) + return err + }) + group.Go(func() error { + var err error + readyNonce, err = newGuestReadyNonce() + return err + }) + if err := group.Wait(); err != nil { clearOperation = true - return nil, fmt.Errorf("ensure artifact for restore: %w", err) + return nil, err } // COW-copy system disk from snapshot to new machine's disk dir. @@ -280,6 +303,10 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn clearOperation = true return nil, fmt.Errorf("inject guest config for restore: %w", err) } + if err := injectGuestSSHHostKey(ctx, newSystemDiskPath, guestHostKey); err != nil { + clearOperation = true + return nil, fmt.Errorf("inject guest ssh host key for restore: %w", err) + } type restoredUserVolume struct { ID contracthost.VolumeID @@ -313,7 +340,7 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn Path: volume.Path, }) } - spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig) + spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig, readyNonce) if err != nil { clearOperation = true return nil, fmt.Errorf("build machine spec for restore: %w", err) @@ -376,7 +403,8 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn RuntimeHost: machineState.RuntimeHost, TapDevice: machineState.TapName, Ports: defaultMachinePorts(), - GuestSSHPublicKey: "", + GuestSSHPublicKey: strings.TrimSpace(guestHostKey.PublicKey), + GuestReadyNonce: readyNonce, Phase: contracthost.MachinePhaseStarting, PID: machineState.PID, SocketPath: machineState.SocketPath, @@ -393,10 +421,15 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn return nil, err } + record, err := d.completeMachineStartup(ctx, &machineRecord, *machineState) + if err != nil { + return nil, err + } + removeMachineDiskDirOnFailure = false clearOperation = true return &contracthost.RestoreSnapshotResponse{ - Machine: machineToContract(machineRecord), + Machine: machineToContract(*record), }, nil } diff --git a/internal/daemon/startup.go b/internal/daemon/startup.go new file mode 100644 index 0000000..b141618 --- /dev/null +++ b/internal/daemon/startup.go @@ -0,0 +1,77 @@ +package daemon + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "strings" + + "github.com/getcompanion-ai/computer-host/internal/firecracker" + "github.com/getcompanion-ai/computer-host/internal/model" + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +type guestReadyResult struct { + ReadyNonce string + GuestSSHPublicKey string +} + +func newGuestReadyNonce() (string, error) { + var bytes [16]byte + if _, err := rand.Read(bytes[:]); err != nil { + return "", fmt.Errorf("generate guest ready nonce: %w", err) + } + return hex.EncodeToString(bytes[:]), nil +} + +func (d *Daemon) completeMachineStartup(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) (*model.MachineRecord, error) { + if record == nil { + return nil, fmt.Errorf("machine record is required") + } + if state.Phase != firecracker.PhaseRunning { + failureReason := strings.TrimSpace(state.Error) + if failureReason == "" { + failureReason = "machine did not reach running phase" + } + return d.failMachineStartup(ctx, record, failureReason) + } + + ready, err := d.personalizeGuest(ctx, record, state) + if err != nil { + return d.failMachineStartup(ctx, record, err.Error()) + } + + expectedNonce := strings.TrimSpace(record.GuestReadyNonce) + receivedNonce := strings.TrimSpace(ready.ReadyNonce) + if expectedNonce != "" && receivedNonce != expectedNonce { + return d.failMachineStartup(ctx, record, "guest ready nonce mismatch") + } + + expectedGuestSSHPublicKey := strings.TrimSpace(record.GuestSSHPublicKey) + guestSSHPublicKey := strings.TrimSpace(ready.GuestSSHPublicKey) + if guestSSHPublicKey == "" { + if expectedGuestSSHPublicKey == "" { + return d.failMachineStartup(ctx, record, "guest ready response missing ssh host key") + } + guestSSHPublicKey = expectedGuestSSHPublicKey + } + if expectedGuestSSHPublicKey != "" && guestSSHPublicKey != expectedGuestSSHPublicKey { + return d.failMachineStartup(ctx, record, "guest ssh host key mismatch") + } + + record.RuntimeHost = state.RuntimeHost + record.TapDevice = state.TapName + record.Ports = defaultMachinePorts() + record.GuestSSHPublicKey = guestSSHPublicKey + record.GuestReadyNonce = "" + record.Phase = contracthost.MachinePhaseRunning + record.Error = "" + record.PID = state.PID + record.SocketPath = state.SocketPath + record.StartedAt = state.StartedAt + if err := d.store.UpdateMachine(ctx, *record); err != nil { + return nil, err + } + return record, nil +} diff --git a/internal/firecracker/configfile.go b/internal/firecracker/configfile.go index f6a7805..50b11ee 100644 --- a/internal/firecracker/configfile.go +++ b/internal/firecracker/configfile.go @@ -8,15 +8,15 @@ import ( ) type vmConfig struct { - BootSource vmBootSource `json:"boot-source"` - Drives []vmDrive `json:"drives"` - MachineConfig vmMachineConfig `json:"machine-config"` - NetworkInterfaces []vmNetworkIface `json:"network-interfaces"` - Vsock *vmVsock `json:"vsock,omitempty"` - Logger *vmLogger `json:"logger,omitempty"` - MMDSConfig *vmMMDSConfig `json:"mmds-config,omitempty"` - Entropy *vmEntropy `json:"entropy,omitempty"` - Serial *vmSerial `json:"serial,omitempty"` + BootSource vmBootSource `json:"boot-source"` + Drives []vmDrive `json:"drives"` + MachineConfig vmMachineConfig `json:"machine-config"` + NetworkInterfaces []vmNetworkIface `json:"network-interfaces"` + Vsock *vmVsock `json:"vsock,omitempty"` + Logger *vmLogger `json:"logger,omitempty"` + MMDSConfig *vmMMDSConfig `json:"mmds-config,omitempty"` + Entropy *vmEntropy `json:"entropy,omitempty"` + Serial *vmSerial `json:"serial,omitempty"` } type vmBootSource struct { diff --git a/internal/model/types.go b/internal/model/types.go index ad4dd9f..c3c09b3 100644 --- a/internal/model/types.go +++ b/internal/model/types.go @@ -36,6 +36,7 @@ type MachineRecord struct { TapDevice string Ports []contracthost.MachinePort GuestSSHPublicKey string + GuestReadyNonce string Phase contracthost.MachinePhase Error string PID int