feat: remove wakeup path, return on create, host managed ssh-keygen, ack nonce dep

This commit is contained in:
Harivansh Rathi 2026-04-12 20:49:52 +00:00
parent 0e4b18f10b
commit 4a9dc91ebf
13 changed files with 423 additions and 170 deletions

View file

@ -5,8 +5,11 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"golang.org/x/sync/errgroup"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
@ -52,13 +55,34 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
}
}()
artifact, err := d.ensureArtifact(ctx, req.Artifact)
if err != nil {
return nil, err
}
userVolumes, err := d.loadAttachableUserVolumes(ctx, req.MachineID, req.UserVolumeIDs)
if err != nil {
var (
artifact *model.ArtifactRecord
userVolumes []model.VolumeRecord
guestHostKey *guestSSHHostKeyPair
readyNonce string
)
group, groupCtx := errgroup.WithContext(ctx)
group.Go(func() error {
var err error
artifact, err = d.ensureArtifact(groupCtx, req.Artifact)
return err
})
group.Go(func() error {
var err error
userVolumes, err = d.loadAttachableUserVolumes(groupCtx, req.MachineID, req.UserVolumeIDs)
return err
})
group.Go(func() error {
var err error
guestHostKey, err = generateGuestSSHHostKeyPair(groupCtx)
return err
})
group.Go(func() error {
var err error
readyNonce, err = newGuestReadyNonce()
return err
})
if err := group.Wait(); err != nil {
return nil, err
}
@ -86,8 +110,11 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
if err := d.injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil {
return nil, fmt.Errorf("inject guest config for %q: %w", req.MachineID, err)
}
if err := injectGuestSSHHostKey(ctx, systemVolumePath, guestHostKey); err != nil {
return nil, fmt.Errorf("inject guest ssh host key for %q: %w", req.MachineID, err)
}
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig)
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig, readyNonce)
if err != nil {
return nil, err
}
@ -135,19 +162,21 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
}
record := model.MachineRecord{
ID: req.MachineID,
Artifact: req.Artifact,
GuestConfig: cloneGuestConfig(guestConfig),
SystemVolumeID: systemVolumeRecord.ID,
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
RuntimeHost: state.RuntimeHost,
TapDevice: state.TapName,
Ports: defaultMachinePorts(),
Phase: contracthost.MachinePhaseStarting,
PID: state.PID,
SocketPath: state.SocketPath,
CreatedAt: now,
StartedAt: state.StartedAt,
ID: req.MachineID,
Artifact: req.Artifact,
GuestConfig: cloneGuestConfig(guestConfig),
SystemVolumeID: systemVolumeRecord.ID,
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
RuntimeHost: state.RuntimeHost,
TapDevice: state.TapName,
Ports: defaultMachinePorts(),
GuestSSHPublicKey: strings.TrimSpace(guestHostKey.PublicKey),
GuestReadyNonce: readyNonce,
Phase: contracthost.MachinePhaseStarting,
PID: state.PID,
SocketPath: state.SocketPath,
CreatedAt: now,
StartedAt: state.StartedAt,
}
if err := d.store.CreateMachine(ctx, record); err != nil {
for _, volume := range userVolumes {
@ -159,12 +188,17 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
return nil, err
}
recordReady, err := d.completeMachineStartup(ctx, &record, *state)
if err != nil {
return nil, err
}
removeSystemVolumeOnFailure = false
clearOperation = true
return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil
return &contracthost.CreateMachineResponse{Machine: machineToContract(*recordReady)}, nil
}
func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string, guestConfig *contracthost.GuestConfig) (firecracker.MachineSpec, error) {
func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string, guestConfig *contracthost.GuestConfig, readyNonce string) (firecracker.MachineSpec, error) {
drives := make([]firecracker.DriveSpec, 0, len(userVolumes))
for i, volume := range userVolumes {
drives = append(drives, firecracker.DriveSpec{
@ -176,7 +210,7 @@ func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *mo
})
}
mmds, err := d.guestMetadataSpec(machineID, guestConfig)
mmds, err := d.guestMetadataSpec(machineID, guestConfig, readyNonce)
if err != nil {
return firecracker.MachineSpec{}, err
}
@ -221,10 +255,14 @@ func (d *Daemon) ensureArtifact(ctx context.Context, ref contracthost.ArtifactRe
kernelPath := filepath.Join(dir, "kernel")
rootFSPath := filepath.Join(dir, "rootfs")
if err := downloadFile(ctx, ref.KernelImageURL, kernelPath); err != nil {
return nil, err
}
if err := downloadFile(ctx, ref.RootFSURL, rootFSPath); err != nil {
group, groupCtx := errgroup.WithContext(ctx)
group.Go(func() error {
return downloadFile(groupCtx, ref.KernelImageURL, kernelPath)
})
group.Go(func() error {
return downloadFile(groupCtx, ref.RootFSURL, rootFSPath)
})
if err := group.Wait(); err != nil {
return nil, err
}

View file

@ -50,7 +50,7 @@ type Daemon struct {
injectGuestConfig func(context.Context, string, *contracthost.GuestConfig) error
syncGuestFilesystem func(context.Context, string) error
shutdownGuest func(context.Context, string) error
personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) error
personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) (*guestReadyResult, error)
locksMu sync.Mutex
machineLocks map[contracthost.MachineID]*sync.Mutex

View file

@ -152,7 +152,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
t.Fatalf("create machine: %v", err)
}
if response.Machine.Phase != contracthost.MachinePhaseStarting {
if response.Machine.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("machine phase mismatch: got %q", response.Machine.Phase)
}
if response.Machine.RuntimeHost != "127.0.0.1" {
@ -230,7 +230,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
if machine.SystemVolumeID != "vm-1-system" {
t.Fatalf("system volume mismatch: got %q", machine.SystemVolumeID)
}
if machine.Phase != contracthost.MachinePhaseStarting {
if machine.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("stored machine phase mismatch: got %q", machine.Phase)
}
if machine.GuestConfig == nil || len(machine.GuestConfig.AuthorizedKeys) == 0 {
@ -401,7 +401,7 @@ func TestGetMachineReconcilesStartingMachineBeforeRunning(t *testing.T) {
t.Cleanup(func() { hostDaemon.stopMachineRelays("vm-starting") })
personalized := false
hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, state firecracker.MachineState) error {
hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, state firecracker.MachineState) (*guestReadyResult, error) {
personalized = true
if record.ID != "vm-starting" {
t.Fatalf("personalized machine mismatch: got %q", record.ID)
@ -409,9 +409,15 @@ func TestGetMachineReconcilesStartingMachineBeforeRunning(t *testing.T) {
if state.RuntimeHost != "127.0.0.1" || state.PID != 4321 {
t.Fatalf("personalized state mismatch: %#v", state)
}
return nil
guestSSHPublicKey := strings.TrimSpace(record.GuestSSHPublicKey)
if guestSSHPublicKey == "" {
guestSSHPublicKey = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3"
}
return &guestReadyResult{
ReadyNonce: record.GuestReadyNonce,
GuestSSHPublicKey: guestSSHPublicKey,
}, nil
}
stubGuestSSHPublicKeyReader(hostDaemon)
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
ID: "vm-starting",
@ -455,9 +461,9 @@ func TestListMachinesDoesNotReconcileStartingMachines(t *testing.T) {
if err != nil {
t.Fatalf("create daemon: %v", err)
}
hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) error {
hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) (*guestReadyResult, error) {
t.Fatalf("ListMachines should not reconcile guest personalization")
return nil
return nil, nil
}
hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) {
t.Fatalf("ListMachines should not read guest ssh public key")
@ -492,7 +498,7 @@ func TestListMachinesDoesNotReconcileStartingMachines(t *testing.T) {
}
}
func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) {
func TestReconcileStartingMachineFailsWhenHandshakeFails(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
@ -511,11 +517,8 @@ func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) {
if err != nil {
t.Fatalf("create daemon: %v", err)
}
hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) error {
return errors.New("vsock EOF")
}
hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) {
return "", errors.New("Permission denied")
hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) (*guestReadyResult, error) {
return nil, errors.New("vsock EOF")
}
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
@ -542,14 +545,14 @@ func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) {
if err != nil {
t.Fatalf("get machine: %v", err)
}
if record.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("machine phase = %q, want %q", record.Phase, contracthost.MachinePhaseRunning)
if record.Phase != contracthost.MachinePhaseFailed {
t.Fatalf("machine phase = %q, want %q", record.Phase, contracthost.MachinePhaseFailed)
}
if record.GuestSSHPublicKey != "ssh-ed25519 AAAAExistingHostKey" {
t.Fatalf("guest ssh public key = %q, want preserved value", record.GuestSSHPublicKey)
if !strings.Contains(record.Error, "vsock EOF") {
t.Fatalf("failure reason = %q, want vsock error", record.Error)
}
if len(runtime.deleteCalls) != 0 {
t.Fatalf("runtime delete calls = %d, want 0", len(runtime.deleteCalls))
if len(runtime.deleteCalls) != 1 {
t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls))
}
}
@ -756,7 +759,7 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
if response.Machine.ID != "restored" {
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
}
if response.Machine.Phase != contracthost.MachinePhaseStarting {
if response.Machine.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
}
if runtime.bootCalls != 1 {
@ -1013,7 +1016,7 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
if response.Machine.ID != "restored" {
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
}
if response.Machine.Phase != contracthost.MachinePhaseStarting {
if response.Machine.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
}
if runtime.bootCalls != 1 {
@ -1033,7 +1036,7 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
if err != nil {
t.Fatalf("get restored machine: %v", err)
}
if machine.Phase != contracthost.MachinePhaseStarting {
if machine.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("restored machine phase mismatch: got %q", machine.Phase)
}
if machine.GuestConfig == nil || machine.GuestConfig.Hostname != "restored-shell" {
@ -1126,7 +1129,7 @@ func TestRestoreSnapshotBootsWithFreshNetworkWhenSourceNetworkInUseOnHost(t *tes
if err != nil {
t.Fatalf("restore snapshot error = %v, want success", err)
}
if response.Machine.Phase != contracthost.MachinePhaseStarting {
if response.Machine.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
}
if runtime.bootCalls != 1 {
@ -1254,8 +1257,15 @@ func TestGuestKernelArgsRemovesPCIOffWhenPCIEnabled(t *testing.T) {
}
func stubGuestSSHPublicKeyReader(hostDaemon *Daemon) {
hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) {
return "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3", nil
hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, _ firecracker.MachineState) (*guestReadyResult, error) {
guestSSHPublicKey := strings.TrimSpace(record.GuestSSHPublicKey)
if guestSSHPublicKey == "" {
guestSSHPublicKey = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3"
}
return &guestReadyResult{
ReadyNonce: record.GuestReadyNonce,
GuestSSHPublicKey: guestSSHPublicKey,
}, nil
}
}

View file

@ -348,6 +348,42 @@ func (d *Daemon) writeBackendSSHPublicKey(privateKeyPath string, publicKeyPath s
return nil
}
type guestSSHHostKeyPair struct {
PrivateKey []byte
PublicKey string
}
func generateGuestSSHHostKeyPair(ctx context.Context) (*guestSSHHostKeyPair, error) {
stagingDir, err := os.MkdirTemp("", "guest-ssh-hostkey-*")
if err != nil {
return nil, fmt.Errorf("create guest ssh host key staging dir: %w", err)
}
defer func() {
_ = os.RemoveAll(stagingDir)
}()
privateKeyPath := filepath.Join(stagingDir, "ssh_host_ed25519_key")
command := exec.CommandContext(ctx, "ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-C", "", "-f", privateKeyPath)
output, err := command.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("generate guest ssh host keypair: %w: %s", err, strings.TrimSpace(string(output)))
}
privateKey, err := os.ReadFile(privateKeyPath)
if err != nil {
return nil, fmt.Errorf("read guest ssh host private key %q: %w", privateKeyPath, err)
}
publicKey, err := os.ReadFile(privateKeyPath + ".pub")
if err != nil {
return nil, fmt.Errorf("read guest ssh host public key %q: %w", privateKeyPath+".pub", err)
}
return &guestSSHHostKeyPair{
PrivateKey: privateKey,
PublicKey: strings.TrimSpace(string(publicKey)),
}, nil
}
func fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
@ -441,6 +477,41 @@ func injectGuestConfig(ctx context.Context, imagePath string, config *contractho
return nil
}
func injectGuestSSHHostKey(ctx context.Context, imagePath string, keyPair *guestSSHHostKeyPair) error {
if keyPair == nil {
return fmt.Errorf("guest ssh host keypair is required")
}
if strings.TrimSpace(keyPair.PublicKey) == "" {
return fmt.Errorf("guest ssh host public key is required")
}
stagingDir, err := os.MkdirTemp(filepath.Dir(imagePath), "guest-ssh-hostkey-*")
if err != nil {
return fmt.Errorf("create guest ssh host key staging dir: %w", err)
}
defer func() {
_ = os.RemoveAll(stagingDir)
}()
privateKeyPath := filepath.Join(stagingDir, "ssh_host_ed25519_key")
if err := os.WriteFile(privateKeyPath, keyPair.PrivateKey, 0o600); err != nil {
return fmt.Errorf("write guest ssh host private key staging file: %w", err)
}
if err := replaceExt4File(ctx, imagePath, privateKeyPath, "/etc/ssh/ssh_host_ed25519_key"); err != nil {
return err
}
publicKeyPath := privateKeyPath + ".pub"
if err := os.WriteFile(publicKeyPath, []byte(strings.TrimSpace(keyPair.PublicKey)+"\n"), 0o644); err != nil {
return fmt.Errorf("write guest ssh host public key staging file: %w", err)
}
if err := replaceExt4File(ctx, imagePath, publicKeyPath, "/etc/ssh/ssh_host_ed25519_key.pub"); err != nil {
return err
}
return nil
}
func injectMachineIdentity(ctx context.Context, imagePath string, machineID contracthost.MachineID) error {
machineName := strings.TrimSpace(string(machineID))
if machineName == "" {

View file

@ -20,7 +20,7 @@ func (d *Daemon) reconfigureGuestIdentityOverSSH(ctx context.Context, runtimeHos
if machineName == "" {
return fmt.Errorf("machine id is required")
}
mmds, err := d.guestMetadataSpec(machineID, guestConfig)
mmds, err := d.guestMetadataSpec(machineID, guestConfig, "")
if err != nil {
return err
}

View file

@ -25,6 +25,7 @@ type guestMetadataPayload struct {
Version string `json:"version"`
MachineID string `json:"machine_id"`
Hostname string `json:"hostname"`
ReadyNonce string `json:"ready_nonce,omitempty"`
AuthorizedKeys []string `json:"authorized_keys,omitempty"`
TrustedUserCAKeys []string `json:"trusted_user_ca_keys,omitempty"`
LoginWebhook *contracthost.GuestLoginWebhook `json:"login_webhook,omitempty"`
@ -55,7 +56,7 @@ func guestHostname(machineID contracthost.MachineID, guestConfig *contracthost.G
return strings.TrimSpace(string(machineID))
}
func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) (*firecracker.MMDSSpec, error) {
func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig, readyNonce string) (*firecracker.MMDSSpec, error) {
name := guestHostname(machineID, guestConfig)
if name == "" {
return nil, fmt.Errorf("machine id is required")
@ -67,6 +68,7 @@ func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig
Version: defaultMMDSPayloadVersion,
MachineID: name,
Hostname: name,
ReadyNonce: strings.TrimSpace(readyNonce),
AuthorizedKeys: nil,
TrustedUserCAKeys: nil,
},

View file

@ -6,6 +6,7 @@ import (
"crypto/sha256"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"net"
"path/filepath"
@ -13,20 +14,32 @@ import (
"strings"
"time"
contracthost "github.com/getcompanion-ai/computer-host/contract"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
const (
defaultGuestPersonalizationVsockID = "microagent-personalizer"
defaultGuestPersonalizationVsockName = "microagent-personalizer.vsock"
defaultGuestPersonalizationVsockPort = uint32(1024)
defaultGuestPersonalizationTimeout = 2 * time.Second
defaultGuestPersonalizationTimeout = 15 * time.Second
guestPersonalizationRetryInterval = 100 * time.Millisecond
minGuestVsockCID = uint32(3)
maxGuestVsockCID = uint32(1<<31 - 1)
)
type guestPersonalizationResponse struct {
Status string `json:"status"`
ReadyNonce string `json:"ready_nonce,omitempty"`
GuestSSHPublicKey string `json:"guest_ssh_public_key,omitempty"`
Error string `json:"error,omitempty"`
}
type guestReadyRequest struct {
ReadyNonce string `json:"ready_nonce,omitempty"`
}
func guestVsockSpec(machineID contracthost.MachineID) *firecracker.VsockSpec {
return &firecracker.VsockSpec{
ID: defaultGuestPersonalizationVsockID,
@ -41,53 +54,46 @@ func guestVsockCID(machineID contracthost.MachineID) uint32 {
return minGuestVsockCID + binary.BigEndian.Uint32(sum[:4])%space
}
func (d *Daemon) personalizeGuestConfig(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) error {
func (d *Daemon) personalizeGuestConfig(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) (*guestReadyResult, error) {
if record == nil {
return fmt.Errorf("machine record is required")
return nil, fmt.Errorf("machine record is required")
}
personalizeCtx, cancel := context.WithTimeout(ctx, defaultGuestPersonalizationTimeout)
defer cancel()
mmds, err := d.guestMetadataSpec(record.ID, record.GuestConfig)
response, err := sendGuestPersonalization(personalizeCtx, state, guestReadyRequest{
ReadyNonce: strings.TrimSpace(record.GuestReadyNonce),
})
if err != nil {
return err
return nil, fmt.Errorf("wait for guest ready over vsock: %w", err)
}
envelope, ok := mmds.Data.(guestMetadataEnvelope)
if !ok {
return fmt.Errorf("guest metadata payload has unexpected type %T", mmds.Data)
if !strings.EqualFold(strings.TrimSpace(response.Status), "ok") {
message := strings.TrimSpace(response.Error)
if message == "" {
message = fmt.Sprintf("unexpected guest personalization status %q", strings.TrimSpace(response.Status))
}
return nil, errors.New(message)
}
if err := d.runtime.PutMMDS(personalizeCtx, state, mmds.Data); err != nil {
return d.personalizeGuestConfigViaSSH(ctx, record, state, fmt.Errorf("reseed guest mmds: %w", err))
}
if err := sendGuestPersonalization(personalizeCtx, state, envelope.Latest.MetaData); err != nil {
return d.personalizeGuestConfigViaSSH(ctx, record, state, fmt.Errorf("apply guest config over vsock: %w", err))
}
return nil
return &guestReadyResult{
ReadyNonce: strings.TrimSpace(response.ReadyNonce),
GuestSSHPublicKey: strings.TrimSpace(response.GuestSSHPublicKey),
}, nil
}
func (d *Daemon) personalizeGuestConfigViaSSH(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState, primaryErr error) error {
fallbackErr := d.reconfigureGuestIdentity(ctx, state.RuntimeHost, record.ID, record.GuestConfig)
if fallbackErr == nil {
return nil
}
return fmt.Errorf("%w; ssh fallback failed: %v", primaryErr, fallbackErr)
}
func sendGuestPersonalization(ctx context.Context, state firecracker.MachineState, payload guestMetadataPayload) error {
func sendGuestPersonalization(ctx context.Context, state firecracker.MachineState, payload guestReadyRequest) (*guestPersonalizationResponse, error) {
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal guest personalization payload: %w", err)
return nil, fmt.Errorf("marshal guest personalization payload: %w", err)
}
vsockPath, err := guestVsockHostPath(state)
if err != nil {
return err
return nil, err
}
connection, err := (&net.Dialer{}).DialContext(ctx, "unix", vsockPath)
connection, err := dialGuestPersonalization(ctx, vsockPath)
if err != nil {
return fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err)
return nil, err
}
defer func() {
_ = connection.Close()
@ -96,27 +102,28 @@ func sendGuestPersonalization(ctx context.Context, state firecracker.MachineStat
reader := bufio.NewReader(connection)
if _, err := fmt.Fprintf(connection, "CONNECT %d\n", defaultGuestPersonalizationVsockPort); err != nil {
return fmt.Errorf("write vsock connect request: %w", err)
return nil, fmt.Errorf("write vsock connect request: %w", err)
}
response, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("read vsock connect response: %w", err)
return nil, fmt.Errorf("read vsock connect response: %w", err)
}
if !strings.HasPrefix(strings.TrimSpace(response), "OK ") {
return fmt.Errorf("unexpected vsock connect response %q", strings.TrimSpace(response))
return nil, fmt.Errorf("unexpected vsock connect response %q", strings.TrimSpace(response))
}
if _, err := connection.Write(append(payloadBytes, '\n')); err != nil {
return fmt.Errorf("write guest personalization payload: %w", err)
return nil, fmt.Errorf("write guest personalization payload: %w", err)
}
response, err = reader.ReadString('\n')
if err != nil {
return fmt.Errorf("read guest personalization response: %w", err)
return nil, fmt.Errorf("read guest personalization response: %w", err)
}
if strings.TrimSpace(response) != "OK" {
return fmt.Errorf("unexpected guest personalization response %q", strings.TrimSpace(response))
var payloadResponse guestPersonalizationResponse
if err := json.Unmarshal([]byte(strings.TrimSpace(response)), &payloadResponse); err != nil {
return nil, fmt.Errorf("decode guest personalization response %q: %w", strings.TrimSpace(response), err)
}
return nil
return &payloadResponse, nil
}
func guestVsockHostPath(state firecracker.MachineState) (string, error) {
@ -133,3 +140,25 @@ func setConnectionDeadline(ctx context.Context, connection net.Conn) {
}
_ = connection.SetDeadline(time.Now().Add(defaultGuestPersonalizationTimeout))
}
func dialGuestPersonalization(ctx context.Context, vsockPath string) (net.Conn, error) {
dialer := &net.Dialer{}
for {
connection, err := dialer.DialContext(ctx, "unix", vsockPath)
if err == nil {
return connection, nil
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return nil, fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err)
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(guestPersonalizationRetryInterval):
}
}
}

View file

@ -10,6 +10,8 @@ import (
"strings"
"time"
"golang.org/x/sync/errgroup"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
@ -50,7 +52,11 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
}
if record.Phase == contracthost.MachinePhaseStarting {
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
reconciled, err := d.reconcileMachine(ctx, id)
if err != nil {
return nil, err
}
return &contracthost.GetMachineResponse{Machine: machineToContract(*reconciled)}, nil
}
if record.Phase != contracthost.MachinePhaseStopped {
return nil, fmt.Errorf("machine %q is not startable from phase %q", id, record.Phase)
@ -71,21 +77,38 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
}
}()
systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID)
if err != nil {
return nil, err
}
artifact, err := d.store.GetArtifact(ctx, record.Artifact)
if err != nil {
return nil, err
}
userVolumes, err := d.loadAttachableUserVolumes(ctx, id, record.UserVolumeIDs)
if err != nil {
var (
systemVolume *model.VolumeRecord
artifact *model.ArtifactRecord
userVolumes []model.VolumeRecord
readyNonce string
)
group, groupCtx := errgroup.WithContext(ctx)
group.Go(func() error {
var err error
systemVolume, err = d.store.GetVolume(groupCtx, record.SystemVolumeID)
return err
})
group.Go(func() error {
var err error
artifact, err = d.store.GetArtifact(groupCtx, record.Artifact)
return err
})
group.Go(func() error {
var err error
userVolumes, err = d.loadAttachableUserVolumes(groupCtx, id, record.UserVolumeIDs)
return err
})
group.Go(func() error {
var err error
readyNonce, err = newGuestReadyNonce()
return err
})
if err := group.Wait(); err != nil {
return nil, err
}
repairDirtyFilesystem(systemVolume.Path)
spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path, record.GuestConfig)
spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path, record.GuestConfig, readyNonce)
if err != nil {
return nil, err
}
@ -100,7 +123,7 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
record.RuntimeHost = state.RuntimeHost
record.TapDevice = state.TapName
record.Ports = defaultMachinePorts()
record.GuestSSHPublicKey = ""
record.GuestReadyNonce = readyNonce
record.Phase = contracthost.MachinePhaseStarting
record.Error = ""
record.PID = state.PID
@ -112,6 +135,11 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
return nil, err
}
record, err = d.completeMachineStartup(ctx, record, *state)
if err != nil {
return nil, err
}
clearOperation = true
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
}
@ -376,44 +404,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
return nil, err
}
if record.Phase == contracthost.MachinePhaseStarting {
if state.Phase != firecracker.PhaseRunning {
return d.failMachineStartup(ctx, record, state.Error)
}
ready, err := guestPortsReady(ctx, state.RuntimeHost, defaultMachinePorts())
if err != nil {
return nil, err
}
if !ready {
return record, nil
}
if err := d.personalizeGuest(ctx, record, *state); err != nil {
fmt.Fprintf(os.Stderr, "warning: guest personalization for %q failed: %v\n", record.ID, err)
}
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
if err != nil {
fmt.Fprintf(os.Stderr, "warning: read guest ssh public key for %q failed: %v\n", record.ID, err)
guestSSHPublicKey = record.GuestSSHPublicKey
}
record.RuntimeHost = state.RuntimeHost
record.TapDevice = state.TapName
record.Ports = defaultMachinePorts()
record.GuestSSHPublicKey = guestSSHPublicKey
record.Phase = contracthost.MachinePhaseRunning
record.Error = ""
record.PID = state.PID
record.SocketPath = state.SocketPath
record.StartedAt = state.StartedAt
if err := d.store.UpdateMachine(ctx, *record); err != nil {
return nil, err
}
if err := d.ensureMachineRelays(ctx, record); err != nil {
return d.failMachineStartup(ctx, record, err.Error())
}
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
d.stopMachineRelays(record.ID)
return d.failMachineStartup(ctx, record, err.Error())
}
return record, nil
return d.completeMachineStartup(ctx, record, *state)
}
if state.Phase == firecracker.PhaseRunning {
if err := d.ensureMachineRelays(ctx, record); err != nil {
@ -450,7 +441,7 @@ func (d *Daemon) failMachineStartup(ctx context.Context, record *model.MachineRe
record.Phase = contracthost.MachinePhaseFailed
record.Error = strings.TrimSpace(failureReason)
record.Ports = defaultMachinePorts()
record.GuestSSHPublicKey = ""
record.GuestReadyNonce = ""
record.PID = 0
record.SocketPath = ""
record.RuntimeHost = ""
@ -511,6 +502,7 @@ func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRec
record.Phase = contracthost.MachinePhaseStopped
record.Error = ""
record.GuestReadyNonce = ""
record.PID = 0
record.SocketPath = ""
record.RuntimeHost = ""

View file

@ -299,7 +299,7 @@ func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T)
assertOperationCount(t, baseStore, 1)
}
func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
func TestStartMachineTransitionsToRunningWithHandshake(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
@ -386,16 +386,16 @@ func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
if err != nil {
t.Fatalf("StartMachine error = %v", err)
}
if response.Machine.Phase != contracthost.MachinePhaseStarting {
t.Fatalf("response machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting)
if response.Machine.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("response machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseRunning)
}
machine, err := baseStore.GetMachine(context.Background(), "vm-start")
if err != nil {
t.Fatalf("get machine: %v", err)
}
if machine.Phase != contracthost.MachinePhaseStarting {
t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseStarting)
if machine.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseRunning)
}
if machine.RuntimeHost != "127.0.0.1" || machine.TapDevice != "fctap-start" {
t.Fatalf("machine runtime state mismatch, got runtime_host=%q tap=%q", machine.RuntimeHost, machine.TapDevice)
@ -408,7 +408,7 @@ func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
}
}
func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
func TestRestoreSnapshotTransitionsToRunningWithHandshake(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
@ -510,8 +510,8 @@ func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T
if err != nil {
t.Fatalf("RestoreSnapshot returned error: %v", err)
}
if response.Machine.Phase != contracthost.MachinePhaseStarting {
t.Fatalf("restored machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting)
if response.Machine.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("restored machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseRunning)
}
if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); err != nil {
t.Fatalf("restored system volume record should exist: %v", err)

View file

@ -245,10 +245,33 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
return nil, err
}
defer cleanupRestoreArtifacts()
artifact, err := d.ensureArtifact(ctx, req.Artifact)
if err != nil {
var (
artifact *model.ArtifactRecord
guestHostKey *guestSSHHostKeyPair
readyNonce string
)
group, groupCtx := errgroup.WithContext(ctx)
group.Go(func() error {
var err error
artifact, err = d.ensureArtifact(groupCtx, req.Artifact)
if err != nil {
return fmt.Errorf("ensure artifact for restore: %w", err)
}
return nil
})
group.Go(func() error {
var err error
guestHostKey, err = generateGuestSSHHostKeyPair(groupCtx)
return err
})
group.Go(func() error {
var err error
readyNonce, err = newGuestReadyNonce()
return err
})
if err := group.Wait(); err != nil {
clearOperation = true
return nil, fmt.Errorf("ensure artifact for restore: %w", err)
return nil, err
}
// COW-copy system disk from snapshot to new machine's disk dir.
@ -280,6 +303,10 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
clearOperation = true
return nil, fmt.Errorf("inject guest config for restore: %w", err)
}
if err := injectGuestSSHHostKey(ctx, newSystemDiskPath, guestHostKey); err != nil {
clearOperation = true
return nil, fmt.Errorf("inject guest ssh host key for restore: %w", err)
}
type restoredUserVolume struct {
ID contracthost.VolumeID
@ -313,7 +340,7 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
Path: volume.Path,
})
}
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig)
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig, readyNonce)
if err != nil {
clearOperation = true
return nil, fmt.Errorf("build machine spec for restore: %w", err)
@ -376,7 +403,8 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
RuntimeHost: machineState.RuntimeHost,
TapDevice: machineState.TapName,
Ports: defaultMachinePorts(),
GuestSSHPublicKey: "",
GuestSSHPublicKey: strings.TrimSpace(guestHostKey.PublicKey),
GuestReadyNonce: readyNonce,
Phase: contracthost.MachinePhaseStarting,
PID: machineState.PID,
SocketPath: machineState.SocketPath,
@ -393,10 +421,15 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
return nil, err
}
record, err := d.completeMachineStartup(ctx, &machineRecord, *machineState)
if err != nil {
return nil, err
}
removeMachineDiskDirOnFailure = false
clearOperation = true
return &contracthost.RestoreSnapshotResponse{
Machine: machineToContract(machineRecord),
Machine: machineToContract(*record),
}, nil
}

View file

@ -0,0 +1,77 @@
package daemon
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"strings"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
type guestReadyResult struct {
ReadyNonce string
GuestSSHPublicKey string
}
func newGuestReadyNonce() (string, error) {
var bytes [16]byte
if _, err := rand.Read(bytes[:]); err != nil {
return "", fmt.Errorf("generate guest ready nonce: %w", err)
}
return hex.EncodeToString(bytes[:]), nil
}
func (d *Daemon) completeMachineStartup(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) (*model.MachineRecord, error) {
if record == nil {
return nil, fmt.Errorf("machine record is required")
}
if state.Phase != firecracker.PhaseRunning {
failureReason := strings.TrimSpace(state.Error)
if failureReason == "" {
failureReason = "machine did not reach running phase"
}
return d.failMachineStartup(ctx, record, failureReason)
}
ready, err := d.personalizeGuest(ctx, record, state)
if err != nil {
return d.failMachineStartup(ctx, record, err.Error())
}
expectedNonce := strings.TrimSpace(record.GuestReadyNonce)
receivedNonce := strings.TrimSpace(ready.ReadyNonce)
if expectedNonce != "" && receivedNonce != expectedNonce {
return d.failMachineStartup(ctx, record, "guest ready nonce mismatch")
}
expectedGuestSSHPublicKey := strings.TrimSpace(record.GuestSSHPublicKey)
guestSSHPublicKey := strings.TrimSpace(ready.GuestSSHPublicKey)
if guestSSHPublicKey == "" {
if expectedGuestSSHPublicKey == "" {
return d.failMachineStartup(ctx, record, "guest ready response missing ssh host key")
}
guestSSHPublicKey = expectedGuestSSHPublicKey
}
if expectedGuestSSHPublicKey != "" && guestSSHPublicKey != expectedGuestSSHPublicKey {
return d.failMachineStartup(ctx, record, "guest ssh host key mismatch")
}
record.RuntimeHost = state.RuntimeHost
record.TapDevice = state.TapName
record.Ports = defaultMachinePorts()
record.GuestSSHPublicKey = guestSSHPublicKey
record.GuestReadyNonce = ""
record.Phase = contracthost.MachinePhaseRunning
record.Error = ""
record.PID = state.PID
record.SocketPath = state.SocketPath
record.StartedAt = state.StartedAt
if err := d.store.UpdateMachine(ctx, *record); err != nil {
return nil, err
}
return record, nil
}

View file

@ -8,15 +8,15 @@ import (
)
type vmConfig struct {
BootSource vmBootSource `json:"boot-source"`
Drives []vmDrive `json:"drives"`
MachineConfig vmMachineConfig `json:"machine-config"`
NetworkInterfaces []vmNetworkIface `json:"network-interfaces"`
Vsock *vmVsock `json:"vsock,omitempty"`
Logger *vmLogger `json:"logger,omitempty"`
MMDSConfig *vmMMDSConfig `json:"mmds-config,omitempty"`
Entropy *vmEntropy `json:"entropy,omitempty"`
Serial *vmSerial `json:"serial,omitempty"`
BootSource vmBootSource `json:"boot-source"`
Drives []vmDrive `json:"drives"`
MachineConfig vmMachineConfig `json:"machine-config"`
NetworkInterfaces []vmNetworkIface `json:"network-interfaces"`
Vsock *vmVsock `json:"vsock,omitempty"`
Logger *vmLogger `json:"logger,omitempty"`
MMDSConfig *vmMMDSConfig `json:"mmds-config,omitempty"`
Entropy *vmEntropy `json:"entropy,omitempty"`
Serial *vmSerial `json:"serial,omitempty"`
}
type vmBootSource struct {

View file

@ -36,6 +36,7 @@ type MachineRecord struct {
TapDevice string
Ports []contracthost.MachinePort
GuestSSHPublicKey string
GuestReadyNonce string
Phase contracthost.MachinePhase
Error string
PID int