mirror of
https://github.com/getcompanion-ai/computer-host.git
synced 2026-04-15 06:04:38 +00:00
feat: remove wakeup path, return on create, host managed ssh-keygen, ack nonce dep
This commit is contained in:
parent
0e4b18f10b
commit
4a9dc91ebf
13 changed files with 423 additions and 170 deletions
|
|
@ -5,8 +5,11 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||
"github.com/getcompanion-ai/computer-host/internal/store"
|
||||
|
|
@ -52,13 +55,34 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
}
|
||||
}()
|
||||
|
||||
artifact, err := d.ensureArtifact(ctx, req.Artifact)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userVolumes, err := d.loadAttachableUserVolumes(ctx, req.MachineID, req.UserVolumeIDs)
|
||||
if err != nil {
|
||||
var (
|
||||
artifact *model.ArtifactRecord
|
||||
userVolumes []model.VolumeRecord
|
||||
guestHostKey *guestSSHHostKeyPair
|
||||
readyNonce string
|
||||
)
|
||||
group, groupCtx := errgroup.WithContext(ctx)
|
||||
group.Go(func() error {
|
||||
var err error
|
||||
artifact, err = d.ensureArtifact(groupCtx, req.Artifact)
|
||||
return err
|
||||
})
|
||||
group.Go(func() error {
|
||||
var err error
|
||||
userVolumes, err = d.loadAttachableUserVolumes(groupCtx, req.MachineID, req.UserVolumeIDs)
|
||||
return err
|
||||
})
|
||||
group.Go(func() error {
|
||||
var err error
|
||||
guestHostKey, err = generateGuestSSHHostKeyPair(groupCtx)
|
||||
return err
|
||||
})
|
||||
group.Go(func() error {
|
||||
var err error
|
||||
readyNonce, err = newGuestReadyNonce()
|
||||
return err
|
||||
})
|
||||
if err := group.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -86,8 +110,11 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
if err := d.injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil {
|
||||
return nil, fmt.Errorf("inject guest config for %q: %w", req.MachineID, err)
|
||||
}
|
||||
if err := injectGuestSSHHostKey(ctx, systemVolumePath, guestHostKey); err != nil {
|
||||
return nil, fmt.Errorf("inject guest ssh host key for %q: %w", req.MachineID, err)
|
||||
}
|
||||
|
||||
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig)
|
||||
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig, readyNonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -135,19 +162,21 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
}
|
||||
|
||||
record := model.MachineRecord{
|
||||
ID: req.MachineID,
|
||||
Artifact: req.Artifact,
|
||||
GuestConfig: cloneGuestConfig(guestConfig),
|
||||
SystemVolumeID: systemVolumeRecord.ID,
|
||||
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
|
||||
RuntimeHost: state.RuntimeHost,
|
||||
TapDevice: state.TapName,
|
||||
Ports: defaultMachinePorts(),
|
||||
Phase: contracthost.MachinePhaseStarting,
|
||||
PID: state.PID,
|
||||
SocketPath: state.SocketPath,
|
||||
CreatedAt: now,
|
||||
StartedAt: state.StartedAt,
|
||||
ID: req.MachineID,
|
||||
Artifact: req.Artifact,
|
||||
GuestConfig: cloneGuestConfig(guestConfig),
|
||||
SystemVolumeID: systemVolumeRecord.ID,
|
||||
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
|
||||
RuntimeHost: state.RuntimeHost,
|
||||
TapDevice: state.TapName,
|
||||
Ports: defaultMachinePorts(),
|
||||
GuestSSHPublicKey: strings.TrimSpace(guestHostKey.PublicKey),
|
||||
GuestReadyNonce: readyNonce,
|
||||
Phase: contracthost.MachinePhaseStarting,
|
||||
PID: state.PID,
|
||||
SocketPath: state.SocketPath,
|
||||
CreatedAt: now,
|
||||
StartedAt: state.StartedAt,
|
||||
}
|
||||
if err := d.store.CreateMachine(ctx, record); err != nil {
|
||||
for _, volume := range userVolumes {
|
||||
|
|
@ -159,12 +188,17 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
return nil, err
|
||||
}
|
||||
|
||||
recordReady, err := d.completeMachineStartup(ctx, &record, *state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
removeSystemVolumeOnFailure = false
|
||||
clearOperation = true
|
||||
return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil
|
||||
return &contracthost.CreateMachineResponse{Machine: machineToContract(*recordReady)}, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string, guestConfig *contracthost.GuestConfig) (firecracker.MachineSpec, error) {
|
||||
func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string, guestConfig *contracthost.GuestConfig, readyNonce string) (firecracker.MachineSpec, error) {
|
||||
drives := make([]firecracker.DriveSpec, 0, len(userVolumes))
|
||||
for i, volume := range userVolumes {
|
||||
drives = append(drives, firecracker.DriveSpec{
|
||||
|
|
@ -176,7 +210,7 @@ func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *mo
|
|||
})
|
||||
}
|
||||
|
||||
mmds, err := d.guestMetadataSpec(machineID, guestConfig)
|
||||
mmds, err := d.guestMetadataSpec(machineID, guestConfig, readyNonce)
|
||||
if err != nil {
|
||||
return firecracker.MachineSpec{}, err
|
||||
}
|
||||
|
|
@ -221,10 +255,14 @@ func (d *Daemon) ensureArtifact(ctx context.Context, ref contracthost.ArtifactRe
|
|||
|
||||
kernelPath := filepath.Join(dir, "kernel")
|
||||
rootFSPath := filepath.Join(dir, "rootfs")
|
||||
if err := downloadFile(ctx, ref.KernelImageURL, kernelPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := downloadFile(ctx, ref.RootFSURL, rootFSPath); err != nil {
|
||||
group, groupCtx := errgroup.WithContext(ctx)
|
||||
group.Go(func() error {
|
||||
return downloadFile(groupCtx, ref.KernelImageURL, kernelPath)
|
||||
})
|
||||
group.Go(func() error {
|
||||
return downloadFile(groupCtx, ref.RootFSURL, rootFSPath)
|
||||
})
|
||||
if err := group.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ type Daemon struct {
|
|||
injectGuestConfig func(context.Context, string, *contracthost.GuestConfig) error
|
||||
syncGuestFilesystem func(context.Context, string) error
|
||||
shutdownGuest func(context.Context, string) error
|
||||
personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) error
|
||||
personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) (*guestReadyResult, error)
|
||||
|
||||
locksMu sync.Mutex
|
||||
machineLocks map[contracthost.MachineID]*sync.Mutex
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
|||
t.Fatalf("create machine: %v", err)
|
||||
}
|
||||
|
||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
||||
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||
t.Fatalf("machine phase mismatch: got %q", response.Machine.Phase)
|
||||
}
|
||||
if response.Machine.RuntimeHost != "127.0.0.1" {
|
||||
|
|
@ -230,7 +230,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
|||
if machine.SystemVolumeID != "vm-1-system" {
|
||||
t.Fatalf("system volume mismatch: got %q", machine.SystemVolumeID)
|
||||
}
|
||||
if machine.Phase != contracthost.MachinePhaseStarting {
|
||||
if machine.Phase != contracthost.MachinePhaseRunning {
|
||||
t.Fatalf("stored machine phase mismatch: got %q", machine.Phase)
|
||||
}
|
||||
if machine.GuestConfig == nil || len(machine.GuestConfig.AuthorizedKeys) == 0 {
|
||||
|
|
@ -401,7 +401,7 @@ func TestGetMachineReconcilesStartingMachineBeforeRunning(t *testing.T) {
|
|||
t.Cleanup(func() { hostDaemon.stopMachineRelays("vm-starting") })
|
||||
|
||||
personalized := false
|
||||
hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, state firecracker.MachineState) error {
|
||||
hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, state firecracker.MachineState) (*guestReadyResult, error) {
|
||||
personalized = true
|
||||
if record.ID != "vm-starting" {
|
||||
t.Fatalf("personalized machine mismatch: got %q", record.ID)
|
||||
|
|
@ -409,9 +409,15 @@ func TestGetMachineReconcilesStartingMachineBeforeRunning(t *testing.T) {
|
|||
if state.RuntimeHost != "127.0.0.1" || state.PID != 4321 {
|
||||
t.Fatalf("personalized state mismatch: %#v", state)
|
||||
}
|
||||
return nil
|
||||
guestSSHPublicKey := strings.TrimSpace(record.GuestSSHPublicKey)
|
||||
if guestSSHPublicKey == "" {
|
||||
guestSSHPublicKey = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3"
|
||||
}
|
||||
return &guestReadyResult{
|
||||
ReadyNonce: record.GuestReadyNonce,
|
||||
GuestSSHPublicKey: guestSSHPublicKey,
|
||||
}, nil
|
||||
}
|
||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||
|
||||
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
|
||||
ID: "vm-starting",
|
||||
|
|
@ -455,9 +461,9 @@ func TestListMachinesDoesNotReconcileStartingMachines(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) error {
|
||||
hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) (*guestReadyResult, error) {
|
||||
t.Fatalf("ListMachines should not reconcile guest personalization")
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) {
|
||||
t.Fatalf("ListMachines should not read guest ssh public key")
|
||||
|
|
@ -492,7 +498,7 @@ func TestListMachinesDoesNotReconcileStartingMachines(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) {
|
||||
func TestReconcileStartingMachineFailsWhenHandshakeFails(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||
|
|
@ -511,11 +517,8 @@ func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) error {
|
||||
return errors.New("vsock EOF")
|
||||
}
|
||||
hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) {
|
||||
return "", errors.New("Permission denied")
|
||||
hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) (*guestReadyResult, error) {
|
||||
return nil, errors.New("vsock EOF")
|
||||
}
|
||||
|
||||
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
|
||||
|
|
@ -542,14 +545,14 @@ func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("get machine: %v", err)
|
||||
}
|
||||
if record.Phase != contracthost.MachinePhaseRunning {
|
||||
t.Fatalf("machine phase = %q, want %q", record.Phase, contracthost.MachinePhaseRunning)
|
||||
if record.Phase != contracthost.MachinePhaseFailed {
|
||||
t.Fatalf("machine phase = %q, want %q", record.Phase, contracthost.MachinePhaseFailed)
|
||||
}
|
||||
if record.GuestSSHPublicKey != "ssh-ed25519 AAAAExistingHostKey" {
|
||||
t.Fatalf("guest ssh public key = %q, want preserved value", record.GuestSSHPublicKey)
|
||||
if !strings.Contains(record.Error, "vsock EOF") {
|
||||
t.Fatalf("failure reason = %q, want vsock error", record.Error)
|
||||
}
|
||||
if len(runtime.deleteCalls) != 0 {
|
||||
t.Fatalf("runtime delete calls = %d, want 0", len(runtime.deleteCalls))
|
||||
if len(runtime.deleteCalls) != 1 {
|
||||
t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -756,7 +759,7 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
|
|||
if response.Machine.ID != "restored" {
|
||||
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
|
||||
}
|
||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
||||
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
|
||||
}
|
||||
if runtime.bootCalls != 1 {
|
||||
|
|
@ -1013,7 +1016,7 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
|
|||
if response.Machine.ID != "restored" {
|
||||
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
|
||||
}
|
||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
||||
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
|
||||
}
|
||||
if runtime.bootCalls != 1 {
|
||||
|
|
@ -1033,7 +1036,7 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("get restored machine: %v", err)
|
||||
}
|
||||
if machine.Phase != contracthost.MachinePhaseStarting {
|
||||
if machine.Phase != contracthost.MachinePhaseRunning {
|
||||
t.Fatalf("restored machine phase mismatch: got %q", machine.Phase)
|
||||
}
|
||||
if machine.GuestConfig == nil || machine.GuestConfig.Hostname != "restored-shell" {
|
||||
|
|
@ -1126,7 +1129,7 @@ func TestRestoreSnapshotBootsWithFreshNetworkWhenSourceNetworkInUseOnHost(t *tes
|
|||
if err != nil {
|
||||
t.Fatalf("restore snapshot error = %v, want success", err)
|
||||
}
|
||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
||||
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
|
||||
}
|
||||
if runtime.bootCalls != 1 {
|
||||
|
|
@ -1254,8 +1257,15 @@ func TestGuestKernelArgsRemovesPCIOffWhenPCIEnabled(t *testing.T) {
|
|||
}
|
||||
|
||||
func stubGuestSSHPublicKeyReader(hostDaemon *Daemon) {
|
||||
hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) {
|
||||
return "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3", nil
|
||||
hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, _ firecracker.MachineState) (*guestReadyResult, error) {
|
||||
guestSSHPublicKey := strings.TrimSpace(record.GuestSSHPublicKey)
|
||||
if guestSSHPublicKey == "" {
|
||||
guestSSHPublicKey = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3"
|
||||
}
|
||||
return &guestReadyResult{
|
||||
ReadyNonce: record.GuestReadyNonce,
|
||||
GuestSSHPublicKey: guestSSHPublicKey,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -348,6 +348,42 @@ func (d *Daemon) writeBackendSSHPublicKey(privateKeyPath string, publicKeyPath s
|
|||
return nil
|
||||
}
|
||||
|
||||
// guestSSHHostKeyPair holds an SSH host keypair that the host generates on
// behalf of a guest before the guest boots.
type guestSSHHostKeyPair struct {
	// PrivateKey is the private key file content, kept as raw bytes exactly
	// as read from disk.
	PrivateKey []byte

	// PublicKey is the public key line (the content of the ".pub" file) with
	// surrounding whitespace trimmed.
	PublicKey string
}
|
||||
|
||||
func generateGuestSSHHostKeyPair(ctx context.Context) (*guestSSHHostKeyPair, error) {
|
||||
stagingDir, err := os.MkdirTemp("", "guest-ssh-hostkey-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create guest ssh host key staging dir: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = os.RemoveAll(stagingDir)
|
||||
}()
|
||||
|
||||
privateKeyPath := filepath.Join(stagingDir, "ssh_host_ed25519_key")
|
||||
command := exec.CommandContext(ctx, "ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-C", "", "-f", privateKeyPath)
|
||||
output, err := command.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate guest ssh host keypair: %w: %s", err, strings.TrimSpace(string(output)))
|
||||
}
|
||||
|
||||
privateKey, err := os.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read guest ssh host private key %q: %w", privateKeyPath, err)
|
||||
}
|
||||
publicKey, err := os.ReadFile(privateKeyPath + ".pub")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read guest ssh host public key %q: %w", privateKeyPath+".pub", err)
|
||||
}
|
||||
|
||||
return &guestSSHHostKeyPair{
|
||||
PrivateKey: privateKey,
|
||||
PublicKey: strings.TrimSpace(string(publicKey)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
|
|
@ -441,6 +477,41 @@ func injectGuestConfig(ctx context.Context, imagePath string, config *contractho
|
|||
return nil
|
||||
}
|
||||
|
||||
func injectGuestSSHHostKey(ctx context.Context, imagePath string, keyPair *guestSSHHostKeyPair) error {
|
||||
if keyPair == nil {
|
||||
return fmt.Errorf("guest ssh host keypair is required")
|
||||
}
|
||||
if strings.TrimSpace(keyPair.PublicKey) == "" {
|
||||
return fmt.Errorf("guest ssh host public key is required")
|
||||
}
|
||||
|
||||
stagingDir, err := os.MkdirTemp(filepath.Dir(imagePath), "guest-ssh-hostkey-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create guest ssh host key staging dir: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = os.RemoveAll(stagingDir)
|
||||
}()
|
||||
|
||||
privateKeyPath := filepath.Join(stagingDir, "ssh_host_ed25519_key")
|
||||
if err := os.WriteFile(privateKeyPath, keyPair.PrivateKey, 0o600); err != nil {
|
||||
return fmt.Errorf("write guest ssh host private key staging file: %w", err)
|
||||
}
|
||||
if err := replaceExt4File(ctx, imagePath, privateKeyPath, "/etc/ssh/ssh_host_ed25519_key"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKeyPath := privateKeyPath + ".pub"
|
||||
if err := os.WriteFile(publicKeyPath, []byte(strings.TrimSpace(keyPair.PublicKey)+"\n"), 0o644); err != nil {
|
||||
return fmt.Errorf("write guest ssh host public key staging file: %w", err)
|
||||
}
|
||||
if err := replaceExt4File(ctx, imagePath, publicKeyPath, "/etc/ssh/ssh_host_ed25519_key.pub"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func injectMachineIdentity(ctx context.Context, imagePath string, machineID contracthost.MachineID) error {
|
||||
machineName := strings.TrimSpace(string(machineID))
|
||||
if machineName == "" {
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ func (d *Daemon) reconfigureGuestIdentityOverSSH(ctx context.Context, runtimeHos
|
|||
if machineName == "" {
|
||||
return fmt.Errorf("machine id is required")
|
||||
}
|
||||
mmds, err := d.guestMetadataSpec(machineID, guestConfig)
|
||||
mmds, err := d.guestMetadataSpec(machineID, guestConfig, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ type guestMetadataPayload struct {
|
|||
Version string `json:"version"`
|
||||
MachineID string `json:"machine_id"`
|
||||
Hostname string `json:"hostname"`
|
||||
ReadyNonce string `json:"ready_nonce,omitempty"`
|
||||
AuthorizedKeys []string `json:"authorized_keys,omitempty"`
|
||||
TrustedUserCAKeys []string `json:"trusted_user_ca_keys,omitempty"`
|
||||
LoginWebhook *contracthost.GuestLoginWebhook `json:"login_webhook,omitempty"`
|
||||
|
|
@ -55,7 +56,7 @@ func guestHostname(machineID contracthost.MachineID, guestConfig *contracthost.G
|
|||
return strings.TrimSpace(string(machineID))
|
||||
}
|
||||
|
||||
func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) (*firecracker.MMDSSpec, error) {
|
||||
func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig, readyNonce string) (*firecracker.MMDSSpec, error) {
|
||||
name := guestHostname(machineID, guestConfig)
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("machine id is required")
|
||||
|
|
@ -67,6 +68,7 @@ func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig
|
|||
Version: defaultMMDSPayloadVersion,
|
||||
MachineID: name,
|
||||
Hostname: name,
|
||||
ReadyNonce: strings.TrimSpace(readyNonce),
|
||||
AuthorizedKeys: nil,
|
||||
TrustedUserCAKeys: nil,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
|
|
@ -13,20 +14,32 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultGuestPersonalizationVsockID = "microagent-personalizer"
|
||||
defaultGuestPersonalizationVsockName = "microagent-personalizer.vsock"
|
||||
defaultGuestPersonalizationVsockPort = uint32(1024)
|
||||
defaultGuestPersonalizationTimeout = 2 * time.Second
|
||||
defaultGuestPersonalizationTimeout = 15 * time.Second
|
||||
guestPersonalizationRetryInterval = 100 * time.Millisecond
|
||||
minGuestVsockCID = uint32(3)
|
||||
maxGuestVsockCID = uint32(1<<31 - 1)
|
||||
)
|
||||
|
||||
// guestPersonalizationResponse is the JSON document the guest sends back over
// the personalization vsock in reply to a guestReadyRequest.
type guestPersonalizationResponse struct {
	// Status reports the outcome; the host treats "ok" (compared
	// case-insensitively) as success and anything else as failure.
	Status string `json:"status"`

	// ReadyNonce is the nonce reported by the guest.
	ReadyNonce string `json:"ready_nonce,omitempty"`

	// GuestSSHPublicKey is the SSH host public key reported by the guest.
	GuestSSHPublicKey string `json:"guest_ssh_public_key,omitempty"`

	// Error carries the guest's failure message when Status is not "ok".
	Error string `json:"error,omitempty"`
}
|
||||
|
||||
// guestReadyRequest is the JSON payload the host writes to the guest over the
// personalization vsock when waiting for the guest to report readiness.
type guestReadyRequest struct {
	// ReadyNonce is the nonce recorded for this machine start; it is omitted
	// from the JSON when empty.
	ReadyNonce string `json:"ready_nonce,omitempty"`
}
|
||||
|
||||
func guestVsockSpec(machineID contracthost.MachineID) *firecracker.VsockSpec {
|
||||
return &firecracker.VsockSpec{
|
||||
ID: defaultGuestPersonalizationVsockID,
|
||||
|
|
@ -41,53 +54,46 @@ func guestVsockCID(machineID contracthost.MachineID) uint32 {
|
|||
return minGuestVsockCID + binary.BigEndian.Uint32(sum[:4])%space
|
||||
}
|
||||
|
||||
func (d *Daemon) personalizeGuestConfig(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) error {
|
||||
func (d *Daemon) personalizeGuestConfig(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) (*guestReadyResult, error) {
|
||||
if record == nil {
|
||||
return fmt.Errorf("machine record is required")
|
||||
return nil, fmt.Errorf("machine record is required")
|
||||
}
|
||||
|
||||
personalizeCtx, cancel := context.WithTimeout(ctx, defaultGuestPersonalizationTimeout)
|
||||
defer cancel()
|
||||
|
||||
mmds, err := d.guestMetadataSpec(record.ID, record.GuestConfig)
|
||||
response, err := sendGuestPersonalization(personalizeCtx, state, guestReadyRequest{
|
||||
ReadyNonce: strings.TrimSpace(record.GuestReadyNonce),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("wait for guest ready over vsock: %w", err)
|
||||
}
|
||||
envelope, ok := mmds.Data.(guestMetadataEnvelope)
|
||||
if !ok {
|
||||
return fmt.Errorf("guest metadata payload has unexpected type %T", mmds.Data)
|
||||
if !strings.EqualFold(strings.TrimSpace(response.Status), "ok") {
|
||||
message := strings.TrimSpace(response.Error)
|
||||
if message == "" {
|
||||
message = fmt.Sprintf("unexpected guest personalization status %q", strings.TrimSpace(response.Status))
|
||||
}
|
||||
return nil, errors.New(message)
|
||||
}
|
||||
|
||||
if err := d.runtime.PutMMDS(personalizeCtx, state, mmds.Data); err != nil {
|
||||
return d.personalizeGuestConfigViaSSH(ctx, record, state, fmt.Errorf("reseed guest mmds: %w", err))
|
||||
}
|
||||
if err := sendGuestPersonalization(personalizeCtx, state, envelope.Latest.MetaData); err != nil {
|
||||
return d.personalizeGuestConfigViaSSH(ctx, record, state, fmt.Errorf("apply guest config over vsock: %w", err))
|
||||
}
|
||||
return nil
|
||||
return &guestReadyResult{
|
||||
ReadyNonce: strings.TrimSpace(response.ReadyNonce),
|
||||
GuestSSHPublicKey: strings.TrimSpace(response.GuestSSHPublicKey),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) personalizeGuestConfigViaSSH(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState, primaryErr error) error {
|
||||
fallbackErr := d.reconfigureGuestIdentity(ctx, state.RuntimeHost, record.ID, record.GuestConfig)
|
||||
if fallbackErr == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%w; ssh fallback failed: %v", primaryErr, fallbackErr)
|
||||
}
|
||||
|
||||
func sendGuestPersonalization(ctx context.Context, state firecracker.MachineState, payload guestMetadataPayload) error {
|
||||
func sendGuestPersonalization(ctx context.Context, state firecracker.MachineState, payload guestReadyRequest) (*guestPersonalizationResponse, error) {
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal guest personalization payload: %w", err)
|
||||
return nil, fmt.Errorf("marshal guest personalization payload: %w", err)
|
||||
}
|
||||
|
||||
vsockPath, err := guestVsockHostPath(state)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
connection, err := (&net.Dialer{}).DialContext(ctx, "unix", vsockPath)
|
||||
connection, err := dialGuestPersonalization(ctx, vsockPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err)
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = connection.Close()
|
||||
|
|
@ -96,27 +102,28 @@ func sendGuestPersonalization(ctx context.Context, state firecracker.MachineStat
|
|||
|
||||
reader := bufio.NewReader(connection)
|
||||
if _, err := fmt.Fprintf(connection, "CONNECT %d\n", defaultGuestPersonalizationVsockPort); err != nil {
|
||||
return fmt.Errorf("write vsock connect request: %w", err)
|
||||
return nil, fmt.Errorf("write vsock connect request: %w", err)
|
||||
}
|
||||
response, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("read vsock connect response: %w", err)
|
||||
return nil, fmt.Errorf("read vsock connect response: %w", err)
|
||||
}
|
||||
if !strings.HasPrefix(strings.TrimSpace(response), "OK ") {
|
||||
return fmt.Errorf("unexpected vsock connect response %q", strings.TrimSpace(response))
|
||||
return nil, fmt.Errorf("unexpected vsock connect response %q", strings.TrimSpace(response))
|
||||
}
|
||||
|
||||
if _, err := connection.Write(append(payloadBytes, '\n')); err != nil {
|
||||
return fmt.Errorf("write guest personalization payload: %w", err)
|
||||
return nil, fmt.Errorf("write guest personalization payload: %w", err)
|
||||
}
|
||||
response, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("read guest personalization response: %w", err)
|
||||
return nil, fmt.Errorf("read guest personalization response: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(response) != "OK" {
|
||||
return fmt.Errorf("unexpected guest personalization response %q", strings.TrimSpace(response))
|
||||
var payloadResponse guestPersonalizationResponse
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(response)), &payloadResponse); err != nil {
|
||||
return nil, fmt.Errorf("decode guest personalization response %q: %w", strings.TrimSpace(response), err)
|
||||
}
|
||||
return nil
|
||||
return &payloadResponse, nil
|
||||
}
|
||||
|
||||
func guestVsockHostPath(state firecracker.MachineState) (string, error) {
|
||||
|
|
@ -133,3 +140,25 @@ func setConnectionDeadline(ctx context.Context, connection net.Conn) {
|
|||
}
|
||||
_ = connection.SetDeadline(time.Now().Add(defaultGuestPersonalizationTimeout))
|
||||
}
|
||||
|
||||
func dialGuestPersonalization(ctx context.Context, vsockPath string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{}
|
||||
for {
|
||||
connection, err := dialer.DialContext(ctx, "unix", vsockPath)
|
||||
if err == nil {
|
||||
return connection, nil
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return nil, fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(guestPersonalizationRetryInterval):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||
"github.com/getcompanion-ai/computer-host/internal/store"
|
||||
|
|
@ -50,7 +52,11 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
|||
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
|
||||
}
|
||||
if record.Phase == contracthost.MachinePhaseStarting {
|
||||
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
|
||||
reconciled, err := d.reconcileMachine(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &contracthost.GetMachineResponse{Machine: machineToContract(*reconciled)}, nil
|
||||
}
|
||||
if record.Phase != contracthost.MachinePhaseStopped {
|
||||
return nil, fmt.Errorf("machine %q is not startable from phase %q", id, record.Phase)
|
||||
|
|
@ -71,21 +77,38 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
|||
}
|
||||
}()
|
||||
|
||||
systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
artifact, err := d.store.GetArtifact(ctx, record.Artifact)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userVolumes, err := d.loadAttachableUserVolumes(ctx, id, record.UserVolumeIDs)
|
||||
if err != nil {
|
||||
var (
|
||||
systemVolume *model.VolumeRecord
|
||||
artifact *model.ArtifactRecord
|
||||
userVolumes []model.VolumeRecord
|
||||
readyNonce string
|
||||
)
|
||||
group, groupCtx := errgroup.WithContext(ctx)
|
||||
group.Go(func() error {
|
||||
var err error
|
||||
systemVolume, err = d.store.GetVolume(groupCtx, record.SystemVolumeID)
|
||||
return err
|
||||
})
|
||||
group.Go(func() error {
|
||||
var err error
|
||||
artifact, err = d.store.GetArtifact(groupCtx, record.Artifact)
|
||||
return err
|
||||
})
|
||||
group.Go(func() error {
|
||||
var err error
|
||||
userVolumes, err = d.loadAttachableUserVolumes(groupCtx, id, record.UserVolumeIDs)
|
||||
return err
|
||||
})
|
||||
group.Go(func() error {
|
||||
var err error
|
||||
readyNonce, err = newGuestReadyNonce()
|
||||
return err
|
||||
})
|
||||
if err := group.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
repairDirtyFilesystem(systemVolume.Path)
|
||||
|
||||
spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path, record.GuestConfig)
|
||||
spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path, record.GuestConfig, readyNonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -100,7 +123,7 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
|||
record.RuntimeHost = state.RuntimeHost
|
||||
record.TapDevice = state.TapName
|
||||
record.Ports = defaultMachinePorts()
|
||||
record.GuestSSHPublicKey = ""
|
||||
record.GuestReadyNonce = readyNonce
|
||||
record.Phase = contracthost.MachinePhaseStarting
|
||||
record.Error = ""
|
||||
record.PID = state.PID
|
||||
|
|
@ -112,6 +135,11 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
|||
return nil, err
|
||||
}
|
||||
|
||||
record, err = d.completeMachineStartup(ctx, record, *state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clearOperation = true
|
||||
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
|
||||
}
|
||||
|
|
@ -376,44 +404,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
|
|||
return nil, err
|
||||
}
|
||||
if record.Phase == contracthost.MachinePhaseStarting {
|
||||
if state.Phase != firecracker.PhaseRunning {
|
||||
return d.failMachineStartup(ctx, record, state.Error)
|
||||
}
|
||||
ready, err := guestPortsReady(ctx, state.RuntimeHost, defaultMachinePorts())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ready {
|
||||
return record, nil
|
||||
}
|
||||
if err := d.personalizeGuest(ctx, record, *state); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: guest personalization for %q failed: %v\n", record.ID, err)
|
||||
}
|
||||
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: read guest ssh public key for %q failed: %v\n", record.ID, err)
|
||||
guestSSHPublicKey = record.GuestSSHPublicKey
|
||||
}
|
||||
record.RuntimeHost = state.RuntimeHost
|
||||
record.TapDevice = state.TapName
|
||||
record.Ports = defaultMachinePorts()
|
||||
record.GuestSSHPublicKey = guestSSHPublicKey
|
||||
record.Phase = contracthost.MachinePhaseRunning
|
||||
record.Error = ""
|
||||
record.PID = state.PID
|
||||
record.SocketPath = state.SocketPath
|
||||
record.StartedAt = state.StartedAt
|
||||
if err := d.store.UpdateMachine(ctx, *record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||
return d.failMachineStartup(ctx, record, err.Error())
|
||||
}
|
||||
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
||||
d.stopMachineRelays(record.ID)
|
||||
return d.failMachineStartup(ctx, record, err.Error())
|
||||
}
|
||||
return record, nil
|
||||
return d.completeMachineStartup(ctx, record, *state)
|
||||
}
|
||||
if state.Phase == firecracker.PhaseRunning {
|
||||
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||
|
|
@ -450,7 +441,7 @@ func (d *Daemon) failMachineStartup(ctx context.Context, record *model.MachineRe
|
|||
record.Phase = contracthost.MachinePhaseFailed
|
||||
record.Error = strings.TrimSpace(failureReason)
|
||||
record.Ports = defaultMachinePorts()
|
||||
record.GuestSSHPublicKey = ""
|
||||
record.GuestReadyNonce = ""
|
||||
record.PID = 0
|
||||
record.SocketPath = ""
|
||||
record.RuntimeHost = ""
|
||||
|
|
@ -511,6 +502,7 @@ func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRec
|
|||
|
||||
record.Phase = contracthost.MachinePhaseStopped
|
||||
record.Error = ""
|
||||
record.GuestReadyNonce = ""
|
||||
record.PID = 0
|
||||
record.SocketPath = ""
|
||||
record.RuntimeHost = ""
|
||||
|
|
|
|||
|
|
@ -299,7 +299,7 @@ func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T)
|
|||
assertOperationCount(t, baseStore, 1)
|
||||
}
|
||||
|
||||
func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
|
||||
func TestStartMachineTransitionsToRunningWithHandshake(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||
|
|
@ -386,16 +386,16 @@ func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("StartMachine error = %v", err)
|
||||
}
|
||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
||||
t.Fatalf("response machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting)
|
||||
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||
t.Fatalf("response machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseRunning)
|
||||
}
|
||||
|
||||
machine, err := baseStore.GetMachine(context.Background(), "vm-start")
|
||||
if err != nil {
|
||||
t.Fatalf("get machine: %v", err)
|
||||
}
|
||||
if machine.Phase != contracthost.MachinePhaseStarting {
|
||||
t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseStarting)
|
||||
if machine.Phase != contracthost.MachinePhaseRunning {
|
||||
t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseRunning)
|
||||
}
|
||||
if machine.RuntimeHost != "127.0.0.1" || machine.TapDevice != "fctap-start" {
|
||||
t.Fatalf("machine runtime state mismatch, got runtime_host=%q tap=%q", machine.RuntimeHost, machine.TapDevice)
|
||||
|
|
@ -408,7 +408,7 @@ func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
|
||||
func TestRestoreSnapshotTransitionsToRunningWithHandshake(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||
|
|
@ -510,8 +510,8 @@ func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T
|
|||
if err != nil {
|
||||
t.Fatalf("RestoreSnapshot returned error: %v", err)
|
||||
}
|
||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
||||
t.Fatalf("restored machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting)
|
||||
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||
t.Fatalf("restored machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseRunning)
|
||||
}
|
||||
if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); err != nil {
|
||||
t.Fatalf("restored system volume record should exist: %v", err)
|
||||
|
|
|
|||
|
|
@ -245,10 +245,33 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
|||
return nil, err
|
||||
}
|
||||
defer cleanupRestoreArtifacts()
|
||||
artifact, err := d.ensureArtifact(ctx, req.Artifact)
|
||||
if err != nil {
|
||||
var (
|
||||
artifact *model.ArtifactRecord
|
||||
guestHostKey *guestSSHHostKeyPair
|
||||
readyNonce string
|
||||
)
|
||||
group, groupCtx := errgroup.WithContext(ctx)
|
||||
group.Go(func() error {
|
||||
var err error
|
||||
artifact, err = d.ensureArtifact(groupCtx, req.Artifact)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ensure artifact for restore: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
group.Go(func() error {
|
||||
var err error
|
||||
guestHostKey, err = generateGuestSSHHostKeyPair(groupCtx)
|
||||
return err
|
||||
})
|
||||
group.Go(func() error {
|
||||
var err error
|
||||
readyNonce, err = newGuestReadyNonce()
|
||||
return err
|
||||
})
|
||||
if err := group.Wait(); err != nil {
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("ensure artifact for restore: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// COW-copy system disk from snapshot to new machine's disk dir.
|
||||
|
|
@ -280,6 +303,10 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
|||
clearOperation = true
|
||||
return nil, fmt.Errorf("inject guest config for restore: %w", err)
|
||||
}
|
||||
if err := injectGuestSSHHostKey(ctx, newSystemDiskPath, guestHostKey); err != nil {
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("inject guest ssh host key for restore: %w", err)
|
||||
}
|
||||
|
||||
type restoredUserVolume struct {
|
||||
ID contracthost.VolumeID
|
||||
|
|
@ -313,7 +340,7 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
|||
Path: volume.Path,
|
||||
})
|
||||
}
|
||||
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig)
|
||||
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig, readyNonce)
|
||||
if err != nil {
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("build machine spec for restore: %w", err)
|
||||
|
|
@ -376,7 +403,8 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
|||
RuntimeHost: machineState.RuntimeHost,
|
||||
TapDevice: machineState.TapName,
|
||||
Ports: defaultMachinePorts(),
|
||||
GuestSSHPublicKey: "",
|
||||
GuestSSHPublicKey: strings.TrimSpace(guestHostKey.PublicKey),
|
||||
GuestReadyNonce: readyNonce,
|
||||
Phase: contracthost.MachinePhaseStarting,
|
||||
PID: machineState.PID,
|
||||
SocketPath: machineState.SocketPath,
|
||||
|
|
@ -393,10 +421,15 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
|||
return nil, err
|
||||
}
|
||||
|
||||
record, err := d.completeMachineStartup(ctx, &machineRecord, *machineState)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
removeMachineDiskDirOnFailure = false
|
||||
clearOperation = true
|
||||
return &contracthost.RestoreSnapshotResponse{
|
||||
Machine: machineToContract(machineRecord),
|
||||
Machine: machineToContract(*record),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
77
internal/daemon/startup.go
Normal file
77
internal/daemon/startup.go
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||
)
|
||||
|
||||
// guestReadyResult holds the two values the host inspects once a guest
// reports ready.
// NOTE(review): presumably this is the value returned by personalizeGuest;
// confirm against its signature.
type guestReadyResult struct {
	// ReadyNonce is the nonce reported for the guest; startup compares a
	// field of this name against the nonce stored on the machine record.
	ReadyNonce string

	// GuestSSHPublicKey is the SSH host public key reported for the guest;
	// startup compares a field of this name against the key stored on the
	// machine record.
	GuestSSHPublicKey string
}
|
||||
|
||||
// newGuestReadyNonce returns a fresh random nonce: 16 bytes read from
// crypto/rand and hex-encoded into a 32-character lowercase string.
// It returns a wrapped error if the random source cannot be read.
func newGuestReadyNonce() (string, error) {
	raw := make([]byte, 16)
	_, err := rand.Read(raw)
	if err != nil {
		return "", fmt.Errorf("generate guest ready nonce: %w", err)
	}
	return hex.EncodeToString(raw), nil
}
|
||||
|
||||
func (d *Daemon) completeMachineStartup(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) (*model.MachineRecord, error) {
|
||||
if record == nil {
|
||||
return nil, fmt.Errorf("machine record is required")
|
||||
}
|
||||
if state.Phase != firecracker.PhaseRunning {
|
||||
failureReason := strings.TrimSpace(state.Error)
|
||||
if failureReason == "" {
|
||||
failureReason = "machine did not reach running phase"
|
||||
}
|
||||
return d.failMachineStartup(ctx, record, failureReason)
|
||||
}
|
||||
|
||||
ready, err := d.personalizeGuest(ctx, record, state)
|
||||
if err != nil {
|
||||
return d.failMachineStartup(ctx, record, err.Error())
|
||||
}
|
||||
|
||||
expectedNonce := strings.TrimSpace(record.GuestReadyNonce)
|
||||
receivedNonce := strings.TrimSpace(ready.ReadyNonce)
|
||||
if expectedNonce != "" && receivedNonce != expectedNonce {
|
||||
return d.failMachineStartup(ctx, record, "guest ready nonce mismatch")
|
||||
}
|
||||
|
||||
expectedGuestSSHPublicKey := strings.TrimSpace(record.GuestSSHPublicKey)
|
||||
guestSSHPublicKey := strings.TrimSpace(ready.GuestSSHPublicKey)
|
||||
if guestSSHPublicKey == "" {
|
||||
if expectedGuestSSHPublicKey == "" {
|
||||
return d.failMachineStartup(ctx, record, "guest ready response missing ssh host key")
|
||||
}
|
||||
guestSSHPublicKey = expectedGuestSSHPublicKey
|
||||
}
|
||||
if expectedGuestSSHPublicKey != "" && guestSSHPublicKey != expectedGuestSSHPublicKey {
|
||||
return d.failMachineStartup(ctx, record, "guest ssh host key mismatch")
|
||||
}
|
||||
|
||||
record.RuntimeHost = state.RuntimeHost
|
||||
record.TapDevice = state.TapName
|
||||
record.Ports = defaultMachinePorts()
|
||||
record.GuestSSHPublicKey = guestSSHPublicKey
|
||||
record.GuestReadyNonce = ""
|
||||
record.Phase = contracthost.MachinePhaseRunning
|
||||
record.Error = ""
|
||||
record.PID = state.PID
|
||||
record.SocketPath = state.SocketPath
|
||||
record.StartedAt = state.StartedAt
|
||||
if err := d.store.UpdateMachine(ctx, *record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue