mirror of
https://github.com/getcompanion-ai/computer-host.git
synced 2026-04-15 05:02:05 +00:00
feat: remove wakeup path, return on create, host managed ssh-keygen, ack nonce dep
This commit is contained in:
parent
0e4b18f10b
commit
4a9dc91ebf
13 changed files with 423 additions and 170 deletions
|
|
@ -5,8 +5,11 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/store"
|
"github.com/getcompanion-ai/computer-host/internal/store"
|
||||||
|
|
@ -52,13 +55,34 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
artifact, err := d.ensureArtifact(ctx, req.Artifact)
|
var (
|
||||||
if err != nil {
|
artifact *model.ArtifactRecord
|
||||||
return nil, err
|
userVolumes []model.VolumeRecord
|
||||||
}
|
guestHostKey *guestSSHHostKeyPair
|
||||||
|
readyNonce string
|
||||||
userVolumes, err := d.loadAttachableUserVolumes(ctx, req.MachineID, req.UserVolumeIDs)
|
)
|
||||||
if err != nil {
|
group, groupCtx := errgroup.WithContext(ctx)
|
||||||
|
group.Go(func() error {
|
||||||
|
var err error
|
||||||
|
artifact, err = d.ensureArtifact(groupCtx, req.Artifact)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
group.Go(func() error {
|
||||||
|
var err error
|
||||||
|
userVolumes, err = d.loadAttachableUserVolumes(groupCtx, req.MachineID, req.UserVolumeIDs)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
group.Go(func() error {
|
||||||
|
var err error
|
||||||
|
guestHostKey, err = generateGuestSSHHostKeyPair(groupCtx)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
group.Go(func() error {
|
||||||
|
var err error
|
||||||
|
readyNonce, err = newGuestReadyNonce()
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err := group.Wait(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -86,8 +110,11 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
||||||
if err := d.injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil {
|
if err := d.injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil {
|
||||||
return nil, fmt.Errorf("inject guest config for %q: %w", req.MachineID, err)
|
return nil, fmt.Errorf("inject guest config for %q: %w", req.MachineID, err)
|
||||||
}
|
}
|
||||||
|
if err := injectGuestSSHHostKey(ctx, systemVolumePath, guestHostKey); err != nil {
|
||||||
|
return nil, fmt.Errorf("inject guest ssh host key for %q: %w", req.MachineID, err)
|
||||||
|
}
|
||||||
|
|
||||||
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig)
|
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath, guestConfig, readyNonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -135,19 +162,21 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
||||||
}
|
}
|
||||||
|
|
||||||
record := model.MachineRecord{
|
record := model.MachineRecord{
|
||||||
ID: req.MachineID,
|
ID: req.MachineID,
|
||||||
Artifact: req.Artifact,
|
Artifact: req.Artifact,
|
||||||
GuestConfig: cloneGuestConfig(guestConfig),
|
GuestConfig: cloneGuestConfig(guestConfig),
|
||||||
SystemVolumeID: systemVolumeRecord.ID,
|
SystemVolumeID: systemVolumeRecord.ID,
|
||||||
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
|
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
|
||||||
RuntimeHost: state.RuntimeHost,
|
RuntimeHost: state.RuntimeHost,
|
||||||
TapDevice: state.TapName,
|
TapDevice: state.TapName,
|
||||||
Ports: defaultMachinePorts(),
|
Ports: defaultMachinePorts(),
|
||||||
Phase: contracthost.MachinePhaseStarting,
|
GuestSSHPublicKey: strings.TrimSpace(guestHostKey.PublicKey),
|
||||||
PID: state.PID,
|
GuestReadyNonce: readyNonce,
|
||||||
SocketPath: state.SocketPath,
|
Phase: contracthost.MachinePhaseStarting,
|
||||||
CreatedAt: now,
|
PID: state.PID,
|
||||||
StartedAt: state.StartedAt,
|
SocketPath: state.SocketPath,
|
||||||
|
CreatedAt: now,
|
||||||
|
StartedAt: state.StartedAt,
|
||||||
}
|
}
|
||||||
if err := d.store.CreateMachine(ctx, record); err != nil {
|
if err := d.store.CreateMachine(ctx, record); err != nil {
|
||||||
for _, volume := range userVolumes {
|
for _, volume := range userVolumes {
|
||||||
|
|
@ -159,12 +188,17 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
recordReady, err := d.completeMachineStartup(ctx, &record, *state)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
removeSystemVolumeOnFailure = false
|
removeSystemVolumeOnFailure = false
|
||||||
clearOperation = true
|
clearOperation = true
|
||||||
return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil
|
return &contracthost.CreateMachineResponse{Machine: machineToContract(*recordReady)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string, guestConfig *contracthost.GuestConfig) (firecracker.MachineSpec, error) {
|
func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string, guestConfig *contracthost.GuestConfig, readyNonce string) (firecracker.MachineSpec, error) {
|
||||||
drives := make([]firecracker.DriveSpec, 0, len(userVolumes))
|
drives := make([]firecracker.DriveSpec, 0, len(userVolumes))
|
||||||
for i, volume := range userVolumes {
|
for i, volume := range userVolumes {
|
||||||
drives = append(drives, firecracker.DriveSpec{
|
drives = append(drives, firecracker.DriveSpec{
|
||||||
|
|
@ -176,7 +210,7 @@ func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *mo
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
mmds, err := d.guestMetadataSpec(machineID, guestConfig)
|
mmds, err := d.guestMetadataSpec(machineID, guestConfig, readyNonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return firecracker.MachineSpec{}, err
|
return firecracker.MachineSpec{}, err
|
||||||
}
|
}
|
||||||
|
|
@ -221,10 +255,14 @@ func (d *Daemon) ensureArtifact(ctx context.Context, ref contracthost.ArtifactRe
|
||||||
|
|
||||||
kernelPath := filepath.Join(dir, "kernel")
|
kernelPath := filepath.Join(dir, "kernel")
|
||||||
rootFSPath := filepath.Join(dir, "rootfs")
|
rootFSPath := filepath.Join(dir, "rootfs")
|
||||||
if err := downloadFile(ctx, ref.KernelImageURL, kernelPath); err != nil {
|
group, groupCtx := errgroup.WithContext(ctx)
|
||||||
return nil, err
|
group.Go(func() error {
|
||||||
}
|
return downloadFile(groupCtx, ref.KernelImageURL, kernelPath)
|
||||||
if err := downloadFile(ctx, ref.RootFSURL, rootFSPath); err != nil {
|
})
|
||||||
|
group.Go(func() error {
|
||||||
|
return downloadFile(groupCtx, ref.RootFSURL, rootFSPath)
|
||||||
|
})
|
||||||
|
if err := group.Wait(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ type Daemon struct {
|
||||||
injectGuestConfig func(context.Context, string, *contracthost.GuestConfig) error
|
injectGuestConfig func(context.Context, string, *contracthost.GuestConfig) error
|
||||||
syncGuestFilesystem func(context.Context, string) error
|
syncGuestFilesystem func(context.Context, string) error
|
||||||
shutdownGuest func(context.Context, string) error
|
shutdownGuest func(context.Context, string) error
|
||||||
personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) error
|
personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) (*guestReadyResult, error)
|
||||||
|
|
||||||
locksMu sync.Mutex
|
locksMu sync.Mutex
|
||||||
machineLocks map[contracthost.MachineID]*sync.Mutex
|
machineLocks map[contracthost.MachineID]*sync.Mutex
|
||||||
|
|
|
||||||
|
|
@ -152,7 +152,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
||||||
t.Fatalf("create machine: %v", err)
|
t.Fatalf("create machine: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||||
t.Fatalf("machine phase mismatch: got %q", response.Machine.Phase)
|
t.Fatalf("machine phase mismatch: got %q", response.Machine.Phase)
|
||||||
}
|
}
|
||||||
if response.Machine.RuntimeHost != "127.0.0.1" {
|
if response.Machine.RuntimeHost != "127.0.0.1" {
|
||||||
|
|
@ -230,7 +230,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
||||||
if machine.SystemVolumeID != "vm-1-system" {
|
if machine.SystemVolumeID != "vm-1-system" {
|
||||||
t.Fatalf("system volume mismatch: got %q", machine.SystemVolumeID)
|
t.Fatalf("system volume mismatch: got %q", machine.SystemVolumeID)
|
||||||
}
|
}
|
||||||
if machine.Phase != contracthost.MachinePhaseStarting {
|
if machine.Phase != contracthost.MachinePhaseRunning {
|
||||||
t.Fatalf("stored machine phase mismatch: got %q", machine.Phase)
|
t.Fatalf("stored machine phase mismatch: got %q", machine.Phase)
|
||||||
}
|
}
|
||||||
if machine.GuestConfig == nil || len(machine.GuestConfig.AuthorizedKeys) == 0 {
|
if machine.GuestConfig == nil || len(machine.GuestConfig.AuthorizedKeys) == 0 {
|
||||||
|
|
@ -401,7 +401,7 @@ func TestGetMachineReconcilesStartingMachineBeforeRunning(t *testing.T) {
|
||||||
t.Cleanup(func() { hostDaemon.stopMachineRelays("vm-starting") })
|
t.Cleanup(func() { hostDaemon.stopMachineRelays("vm-starting") })
|
||||||
|
|
||||||
personalized := false
|
personalized := false
|
||||||
hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, state firecracker.MachineState) error {
|
hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, state firecracker.MachineState) (*guestReadyResult, error) {
|
||||||
personalized = true
|
personalized = true
|
||||||
if record.ID != "vm-starting" {
|
if record.ID != "vm-starting" {
|
||||||
t.Fatalf("personalized machine mismatch: got %q", record.ID)
|
t.Fatalf("personalized machine mismatch: got %q", record.ID)
|
||||||
|
|
@ -409,9 +409,15 @@ func TestGetMachineReconcilesStartingMachineBeforeRunning(t *testing.T) {
|
||||||
if state.RuntimeHost != "127.0.0.1" || state.PID != 4321 {
|
if state.RuntimeHost != "127.0.0.1" || state.PID != 4321 {
|
||||||
t.Fatalf("personalized state mismatch: %#v", state)
|
t.Fatalf("personalized state mismatch: %#v", state)
|
||||||
}
|
}
|
||||||
return nil
|
guestSSHPublicKey := strings.TrimSpace(record.GuestSSHPublicKey)
|
||||||
|
if guestSSHPublicKey == "" {
|
||||||
|
guestSSHPublicKey = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3"
|
||||||
|
}
|
||||||
|
return &guestReadyResult{
|
||||||
|
ReadyNonce: record.GuestReadyNonce,
|
||||||
|
GuestSSHPublicKey: guestSSHPublicKey,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
|
||||||
|
|
||||||
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
|
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
|
||||||
ID: "vm-starting",
|
ID: "vm-starting",
|
||||||
|
|
@ -455,9 +461,9 @@ func TestListMachinesDoesNotReconcileStartingMachines(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) error {
|
hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) (*guestReadyResult, error) {
|
||||||
t.Fatalf("ListMachines should not reconcile guest personalization")
|
t.Fatalf("ListMachines should not reconcile guest personalization")
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) {
|
hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) {
|
||||||
t.Fatalf("ListMachines should not read guest ssh public key")
|
t.Fatalf("ListMachines should not read guest ssh public key")
|
||||||
|
|
@ -492,7 +498,7 @@ func TestListMachinesDoesNotReconcileStartingMachines(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) {
|
func TestReconcileStartingMachineFailsWhenHandshakeFails(t *testing.T) {
|
||||||
root := t.TempDir()
|
root := t.TempDir()
|
||||||
cfg := testConfig(root)
|
cfg := testConfig(root)
|
||||||
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||||
|
|
@ -511,11 +517,8 @@ func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) error {
|
hostDaemon.personalizeGuest = func(context.Context, *model.MachineRecord, firecracker.MachineState) (*guestReadyResult, error) {
|
||||||
return errors.New("vsock EOF")
|
return nil, errors.New("vsock EOF")
|
||||||
}
|
|
||||||
hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) {
|
|
||||||
return "", errors.New("Permission denied")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
|
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
|
||||||
|
|
@ -542,14 +545,14 @@ func TestReconcileStartingMachineIgnoresPersonalizationFailures(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("get machine: %v", err)
|
t.Fatalf("get machine: %v", err)
|
||||||
}
|
}
|
||||||
if record.Phase != contracthost.MachinePhaseRunning {
|
if record.Phase != contracthost.MachinePhaseFailed {
|
||||||
t.Fatalf("machine phase = %q, want %q", record.Phase, contracthost.MachinePhaseRunning)
|
t.Fatalf("machine phase = %q, want %q", record.Phase, contracthost.MachinePhaseFailed)
|
||||||
}
|
}
|
||||||
if record.GuestSSHPublicKey != "ssh-ed25519 AAAAExistingHostKey" {
|
if !strings.Contains(record.Error, "vsock EOF") {
|
||||||
t.Fatalf("guest ssh public key = %q, want preserved value", record.GuestSSHPublicKey)
|
t.Fatalf("failure reason = %q, want vsock error", record.Error)
|
||||||
}
|
}
|
||||||
if len(runtime.deleteCalls) != 0 {
|
if len(runtime.deleteCalls) != 1 {
|
||||||
t.Fatalf("runtime delete calls = %d, want 0", len(runtime.deleteCalls))
|
t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -756,7 +759,7 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
|
||||||
if response.Machine.ID != "restored" {
|
if response.Machine.ID != "restored" {
|
||||||
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
|
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
|
||||||
}
|
}
|
||||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||||
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
|
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
|
||||||
}
|
}
|
||||||
if runtime.bootCalls != 1 {
|
if runtime.bootCalls != 1 {
|
||||||
|
|
@ -1013,7 +1016,7 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
|
||||||
if response.Machine.ID != "restored" {
|
if response.Machine.ID != "restored" {
|
||||||
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
|
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
|
||||||
}
|
}
|
||||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||||
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
|
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
|
||||||
}
|
}
|
||||||
if runtime.bootCalls != 1 {
|
if runtime.bootCalls != 1 {
|
||||||
|
|
@ -1033,7 +1036,7 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("get restored machine: %v", err)
|
t.Fatalf("get restored machine: %v", err)
|
||||||
}
|
}
|
||||||
if machine.Phase != contracthost.MachinePhaseStarting {
|
if machine.Phase != contracthost.MachinePhaseRunning {
|
||||||
t.Fatalf("restored machine phase mismatch: got %q", machine.Phase)
|
t.Fatalf("restored machine phase mismatch: got %q", machine.Phase)
|
||||||
}
|
}
|
||||||
if machine.GuestConfig == nil || machine.GuestConfig.Hostname != "restored-shell" {
|
if machine.GuestConfig == nil || machine.GuestConfig.Hostname != "restored-shell" {
|
||||||
|
|
@ -1126,7 +1129,7 @@ func TestRestoreSnapshotBootsWithFreshNetworkWhenSourceNetworkInUseOnHost(t *tes
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("restore snapshot error = %v, want success", err)
|
t.Fatalf("restore snapshot error = %v, want success", err)
|
||||||
}
|
}
|
||||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||||
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
|
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
|
||||||
}
|
}
|
||||||
if runtime.bootCalls != 1 {
|
if runtime.bootCalls != 1 {
|
||||||
|
|
@ -1254,8 +1257,15 @@ func TestGuestKernelArgsRemovesPCIOffWhenPCIEnabled(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func stubGuestSSHPublicKeyReader(hostDaemon *Daemon) {
|
func stubGuestSSHPublicKeyReader(hostDaemon *Daemon) {
|
||||||
hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) {
|
hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, _ firecracker.MachineState) (*guestReadyResult, error) {
|
||||||
return "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3", nil
|
guestSSHPublicKey := strings.TrimSpace(record.GuestSSHPublicKey)
|
||||||
|
if guestSSHPublicKey == "" {
|
||||||
|
guestSSHPublicKey = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3"
|
||||||
|
}
|
||||||
|
return &guestReadyResult{
|
||||||
|
ReadyNonce: record.GuestReadyNonce,
|
||||||
|
GuestSSHPublicKey: guestSSHPublicKey,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -348,6 +348,42 @@ func (d *Daemon) writeBackendSSHPublicKey(privateKeyPath string, publicKeyPath s
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type guestSSHHostKeyPair struct {
|
||||||
|
PrivateKey []byte
|
||||||
|
PublicKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateGuestSSHHostKeyPair(ctx context.Context) (*guestSSHHostKeyPair, error) {
|
||||||
|
stagingDir, err := os.MkdirTemp("", "guest-ssh-hostkey-*")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create guest ssh host key staging dir: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = os.RemoveAll(stagingDir)
|
||||||
|
}()
|
||||||
|
|
||||||
|
privateKeyPath := filepath.Join(stagingDir, "ssh_host_ed25519_key")
|
||||||
|
command := exec.CommandContext(ctx, "ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-C", "", "-f", privateKeyPath)
|
||||||
|
output, err := command.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate guest ssh host keypair: %w: %s", err, strings.TrimSpace(string(output)))
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey, err := os.ReadFile(privateKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read guest ssh host private key %q: %w", privateKeyPath, err)
|
||||||
|
}
|
||||||
|
publicKey, err := os.ReadFile(privateKeyPath + ".pub")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read guest ssh host public key %q: %w", privateKeyPath+".pub", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &guestSSHHostKeyPair{
|
||||||
|
PrivateKey: privateKey,
|
||||||
|
PublicKey: strings.TrimSpace(string(publicKey)),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func fileExists(path string) bool {
|
func fileExists(path string) bool {
|
||||||
_, err := os.Stat(path)
|
_, err := os.Stat(path)
|
||||||
return err == nil
|
return err == nil
|
||||||
|
|
@ -441,6 +477,41 @@ func injectGuestConfig(ctx context.Context, imagePath string, config *contractho
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func injectGuestSSHHostKey(ctx context.Context, imagePath string, keyPair *guestSSHHostKeyPair) error {
|
||||||
|
if keyPair == nil {
|
||||||
|
return fmt.Errorf("guest ssh host keypair is required")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(keyPair.PublicKey) == "" {
|
||||||
|
return fmt.Errorf("guest ssh host public key is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
stagingDir, err := os.MkdirTemp(filepath.Dir(imagePath), "guest-ssh-hostkey-*")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create guest ssh host key staging dir: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = os.RemoveAll(stagingDir)
|
||||||
|
}()
|
||||||
|
|
||||||
|
privateKeyPath := filepath.Join(stagingDir, "ssh_host_ed25519_key")
|
||||||
|
if err := os.WriteFile(privateKeyPath, keyPair.PrivateKey, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write guest ssh host private key staging file: %w", err)
|
||||||
|
}
|
||||||
|
if err := replaceExt4File(ctx, imagePath, privateKeyPath, "/etc/ssh/ssh_host_ed25519_key"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyPath := privateKeyPath + ".pub"
|
||||||
|
if err := os.WriteFile(publicKeyPath, []byte(strings.TrimSpace(keyPair.PublicKey)+"\n"), 0o644); err != nil {
|
||||||
|
return fmt.Errorf("write guest ssh host public key staging file: %w", err)
|
||||||
|
}
|
||||||
|
if err := replaceExt4File(ctx, imagePath, publicKeyPath, "/etc/ssh/ssh_host_ed25519_key.pub"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func injectMachineIdentity(ctx context.Context, imagePath string, machineID contracthost.MachineID) error {
|
func injectMachineIdentity(ctx context.Context, imagePath string, machineID contracthost.MachineID) error {
|
||||||
machineName := strings.TrimSpace(string(machineID))
|
machineName := strings.TrimSpace(string(machineID))
|
||||||
if machineName == "" {
|
if machineName == "" {
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ func (d *Daemon) reconfigureGuestIdentityOverSSH(ctx context.Context, runtimeHos
|
||||||
if machineName == "" {
|
if machineName == "" {
|
||||||
return fmt.Errorf("machine id is required")
|
return fmt.Errorf("machine id is required")
|
||||||
}
|
}
|
||||||
mmds, err := d.guestMetadataSpec(machineID, guestConfig)
|
mmds, err := d.guestMetadataSpec(machineID, guestConfig, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ type guestMetadataPayload struct {
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
MachineID string `json:"machine_id"`
|
MachineID string `json:"machine_id"`
|
||||||
Hostname string `json:"hostname"`
|
Hostname string `json:"hostname"`
|
||||||
|
ReadyNonce string `json:"ready_nonce,omitempty"`
|
||||||
AuthorizedKeys []string `json:"authorized_keys,omitempty"`
|
AuthorizedKeys []string `json:"authorized_keys,omitempty"`
|
||||||
TrustedUserCAKeys []string `json:"trusted_user_ca_keys,omitempty"`
|
TrustedUserCAKeys []string `json:"trusted_user_ca_keys,omitempty"`
|
||||||
LoginWebhook *contracthost.GuestLoginWebhook `json:"login_webhook,omitempty"`
|
LoginWebhook *contracthost.GuestLoginWebhook `json:"login_webhook,omitempty"`
|
||||||
|
|
@ -55,7 +56,7 @@ func guestHostname(machineID contracthost.MachineID, guestConfig *contracthost.G
|
||||||
return strings.TrimSpace(string(machineID))
|
return strings.TrimSpace(string(machineID))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) (*firecracker.MMDSSpec, error) {
|
func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig, readyNonce string) (*firecracker.MMDSSpec, error) {
|
||||||
name := guestHostname(machineID, guestConfig)
|
name := guestHostname(machineID, guestConfig)
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return nil, fmt.Errorf("machine id is required")
|
return nil, fmt.Errorf("machine id is required")
|
||||||
|
|
@ -67,6 +68,7 @@ func (d *Daemon) guestMetadataSpec(machineID contracthost.MachineID, guestConfig
|
||||||
Version: defaultMMDSPayloadVersion,
|
Version: defaultMMDSPayloadVersion,
|
||||||
MachineID: name,
|
MachineID: name,
|
||||||
Hostname: name,
|
Hostname: name,
|
||||||
|
ReadyNonce: strings.TrimSpace(readyNonce),
|
||||||
AuthorizedKeys: nil,
|
AuthorizedKeys: nil,
|
||||||
TrustedUserCAKeys: nil,
|
TrustedUserCAKeys: nil,
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
@ -13,20 +14,32 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
|
||||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultGuestPersonalizationVsockID = "microagent-personalizer"
|
defaultGuestPersonalizationVsockID = "microagent-personalizer"
|
||||||
defaultGuestPersonalizationVsockName = "microagent-personalizer.vsock"
|
defaultGuestPersonalizationVsockName = "microagent-personalizer.vsock"
|
||||||
defaultGuestPersonalizationVsockPort = uint32(1024)
|
defaultGuestPersonalizationVsockPort = uint32(1024)
|
||||||
defaultGuestPersonalizationTimeout = 2 * time.Second
|
defaultGuestPersonalizationTimeout = 15 * time.Second
|
||||||
|
guestPersonalizationRetryInterval = 100 * time.Millisecond
|
||||||
minGuestVsockCID = uint32(3)
|
minGuestVsockCID = uint32(3)
|
||||||
maxGuestVsockCID = uint32(1<<31 - 1)
|
maxGuestVsockCID = uint32(1<<31 - 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type guestPersonalizationResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
ReadyNonce string `json:"ready_nonce,omitempty"`
|
||||||
|
GuestSSHPublicKey string `json:"guest_ssh_public_key,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type guestReadyRequest struct {
|
||||||
|
ReadyNonce string `json:"ready_nonce,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
func guestVsockSpec(machineID contracthost.MachineID) *firecracker.VsockSpec {
|
func guestVsockSpec(machineID contracthost.MachineID) *firecracker.VsockSpec {
|
||||||
return &firecracker.VsockSpec{
|
return &firecracker.VsockSpec{
|
||||||
ID: defaultGuestPersonalizationVsockID,
|
ID: defaultGuestPersonalizationVsockID,
|
||||||
|
|
@ -41,53 +54,46 @@ func guestVsockCID(machineID contracthost.MachineID) uint32 {
|
||||||
return minGuestVsockCID + binary.BigEndian.Uint32(sum[:4])%space
|
return minGuestVsockCID + binary.BigEndian.Uint32(sum[:4])%space
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) personalizeGuestConfig(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) error {
|
func (d *Daemon) personalizeGuestConfig(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) (*guestReadyResult, error) {
|
||||||
if record == nil {
|
if record == nil {
|
||||||
return fmt.Errorf("machine record is required")
|
return nil, fmt.Errorf("machine record is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
personalizeCtx, cancel := context.WithTimeout(ctx, defaultGuestPersonalizationTimeout)
|
personalizeCtx, cancel := context.WithTimeout(ctx, defaultGuestPersonalizationTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
mmds, err := d.guestMetadataSpec(record.ID, record.GuestConfig)
|
response, err := sendGuestPersonalization(personalizeCtx, state, guestReadyRequest{
|
||||||
|
ReadyNonce: strings.TrimSpace(record.GuestReadyNonce),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("wait for guest ready over vsock: %w", err)
|
||||||
}
|
}
|
||||||
envelope, ok := mmds.Data.(guestMetadataEnvelope)
|
if !strings.EqualFold(strings.TrimSpace(response.Status), "ok") {
|
||||||
if !ok {
|
message := strings.TrimSpace(response.Error)
|
||||||
return fmt.Errorf("guest metadata payload has unexpected type %T", mmds.Data)
|
if message == "" {
|
||||||
|
message = fmt.Sprintf("unexpected guest personalization status %q", strings.TrimSpace(response.Status))
|
||||||
|
}
|
||||||
|
return nil, errors.New(message)
|
||||||
}
|
}
|
||||||
|
return &guestReadyResult{
|
||||||
if err := d.runtime.PutMMDS(personalizeCtx, state, mmds.Data); err != nil {
|
ReadyNonce: strings.TrimSpace(response.ReadyNonce),
|
||||||
return d.personalizeGuestConfigViaSSH(ctx, record, state, fmt.Errorf("reseed guest mmds: %w", err))
|
GuestSSHPublicKey: strings.TrimSpace(response.GuestSSHPublicKey),
|
||||||
}
|
}, nil
|
||||||
if err := sendGuestPersonalization(personalizeCtx, state, envelope.Latest.MetaData); err != nil {
|
|
||||||
return d.personalizeGuestConfigViaSSH(ctx, record, state, fmt.Errorf("apply guest config over vsock: %w", err))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) personalizeGuestConfigViaSSH(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState, primaryErr error) error {
|
func sendGuestPersonalization(ctx context.Context, state firecracker.MachineState, payload guestReadyRequest) (*guestPersonalizationResponse, error) {
|
||||||
fallbackErr := d.reconfigureGuestIdentity(ctx, state.RuntimeHost, record.ID, record.GuestConfig)
|
|
||||||
if fallbackErr == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("%w; ssh fallback failed: %v", primaryErr, fallbackErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendGuestPersonalization(ctx context.Context, state firecracker.MachineState, payload guestMetadataPayload) error {
|
|
||||||
payloadBytes, err := json.Marshal(payload)
|
payloadBytes, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal guest personalization payload: %w", err)
|
return nil, fmt.Errorf("marshal guest personalization payload: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
vsockPath, err := guestVsockHostPath(state)
|
vsockPath, err := guestVsockHostPath(state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
connection, err := (&net.Dialer{}).DialContext(ctx, "unix", vsockPath)
|
connection, err := dialGuestPersonalization(ctx, vsockPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = connection.Close()
|
_ = connection.Close()
|
||||||
|
|
@ -96,27 +102,28 @@ func sendGuestPersonalization(ctx context.Context, state firecracker.MachineStat
|
||||||
|
|
||||||
reader := bufio.NewReader(connection)
|
reader := bufio.NewReader(connection)
|
||||||
if _, err := fmt.Fprintf(connection, "CONNECT %d\n", defaultGuestPersonalizationVsockPort); err != nil {
|
if _, err := fmt.Fprintf(connection, "CONNECT %d\n", defaultGuestPersonalizationVsockPort); err != nil {
|
||||||
return fmt.Errorf("write vsock connect request: %w", err)
|
return nil, fmt.Errorf("write vsock connect request: %w", err)
|
||||||
}
|
}
|
||||||
response, err := reader.ReadString('\n')
|
response, err := reader.ReadString('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("read vsock connect response: %w", err)
|
return nil, fmt.Errorf("read vsock connect response: %w", err)
|
||||||
}
|
}
|
||||||
if !strings.HasPrefix(strings.TrimSpace(response), "OK ") {
|
if !strings.HasPrefix(strings.TrimSpace(response), "OK ") {
|
||||||
return fmt.Errorf("unexpected vsock connect response %q", strings.TrimSpace(response))
|
return nil, fmt.Errorf("unexpected vsock connect response %q", strings.TrimSpace(response))
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := connection.Write(append(payloadBytes, '\n')); err != nil {
|
if _, err := connection.Write(append(payloadBytes, '\n')); err != nil {
|
||||||
return fmt.Errorf("write guest personalization payload: %w", err)
|
return nil, fmt.Errorf("write guest personalization payload: %w", err)
|
||||||
}
|
}
|
||||||
response, err = reader.ReadString('\n')
|
response, err = reader.ReadString('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("read guest personalization response: %w", err)
|
return nil, fmt.Errorf("read guest personalization response: %w", err)
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(response) != "OK" {
|
var payloadResponse guestPersonalizationResponse
|
||||||
return fmt.Errorf("unexpected guest personalization response %q", strings.TrimSpace(response))
|
if err := json.Unmarshal([]byte(strings.TrimSpace(response)), &payloadResponse); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode guest personalization response %q: %w", strings.TrimSpace(response), err)
|
||||||
}
|
}
|
||||||
return nil
|
return &payloadResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func guestVsockHostPath(state firecracker.MachineState) (string, error) {
|
func guestVsockHostPath(state firecracker.MachineState) (string, error) {
|
||||||
|
|
@ -133,3 +140,25 @@ func setConnectionDeadline(ctx context.Context, connection net.Conn) {
|
||||||
}
|
}
|
||||||
_ = connection.SetDeadline(time.Now().Add(defaultGuestPersonalizationTimeout))
|
_ = connection.SetDeadline(time.Now().Add(defaultGuestPersonalizationTimeout))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func dialGuestPersonalization(ctx context.Context, vsockPath string) (net.Conn, error) {
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
for {
|
||||||
|
connection, err := dialer.DialContext(ctx, "unix", vsockPath)
|
||||||
|
if err == nil {
|
||||||
|
return connection, nil
|
||||||
|
}
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||||
|
return nil, fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(guestPersonalizationRetryInterval):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/store"
|
"github.com/getcompanion-ai/computer-host/internal/store"
|
||||||
|
|
@ -50,7 +52,11 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
||||||
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
|
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
|
||||||
}
|
}
|
||||||
if record.Phase == contracthost.MachinePhaseStarting {
|
if record.Phase == contracthost.MachinePhaseStarting {
|
||||||
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
|
reconciled, err := d.reconcileMachine(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &contracthost.GetMachineResponse{Machine: machineToContract(*reconciled)}, nil
|
||||||
}
|
}
|
||||||
if record.Phase != contracthost.MachinePhaseStopped {
|
if record.Phase != contracthost.MachinePhaseStopped {
|
||||||
return nil, fmt.Errorf("machine %q is not startable from phase %q", id, record.Phase)
|
return nil, fmt.Errorf("machine %q is not startable from phase %q", id, record.Phase)
|
||||||
|
|
@ -71,21 +77,38 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID)
|
var (
|
||||||
if err != nil {
|
systemVolume *model.VolumeRecord
|
||||||
return nil, err
|
artifact *model.ArtifactRecord
|
||||||
}
|
userVolumes []model.VolumeRecord
|
||||||
artifact, err := d.store.GetArtifact(ctx, record.Artifact)
|
readyNonce string
|
||||||
if err != nil {
|
)
|
||||||
return nil, err
|
group, groupCtx := errgroup.WithContext(ctx)
|
||||||
}
|
group.Go(func() error {
|
||||||
userVolumes, err := d.loadAttachableUserVolumes(ctx, id, record.UserVolumeIDs)
|
var err error
|
||||||
if err != nil {
|
systemVolume, err = d.store.GetVolume(groupCtx, record.SystemVolumeID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
group.Go(func() error {
|
||||||
|
var err error
|
||||||
|
artifact, err = d.store.GetArtifact(groupCtx, record.Artifact)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
group.Go(func() error {
|
||||||
|
var err error
|
||||||
|
userVolumes, err = d.loadAttachableUserVolumes(groupCtx, id, record.UserVolumeIDs)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
group.Go(func() error {
|
||||||
|
var err error
|
||||||
|
readyNonce, err = newGuestReadyNonce()
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err := group.Wait(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
repairDirtyFilesystem(systemVolume.Path)
|
repairDirtyFilesystem(systemVolume.Path)
|
||||||
|
spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path, record.GuestConfig, readyNonce)
|
||||||
spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path, record.GuestConfig)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -100,7 +123,7 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
||||||
record.RuntimeHost = state.RuntimeHost
|
record.RuntimeHost = state.RuntimeHost
|
||||||
record.TapDevice = state.TapName
|
record.TapDevice = state.TapName
|
||||||
record.Ports = defaultMachinePorts()
|
record.Ports = defaultMachinePorts()
|
||||||
record.GuestSSHPublicKey = ""
|
record.GuestReadyNonce = readyNonce
|
||||||
record.Phase = contracthost.MachinePhaseStarting
|
record.Phase = contracthost.MachinePhaseStarting
|
||||||
record.Error = ""
|
record.Error = ""
|
||||||
record.PID = state.PID
|
record.PID = state.PID
|
||||||
|
|
@ -112,6 +135,11 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
record, err = d.completeMachineStartup(ctx, record, *state)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
clearOperation = true
|
clearOperation = true
|
||||||
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
|
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
|
||||||
}
|
}
|
||||||
|
|
@ -376,44 +404,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if record.Phase == contracthost.MachinePhaseStarting {
|
if record.Phase == contracthost.MachinePhaseStarting {
|
||||||
if state.Phase != firecracker.PhaseRunning {
|
return d.completeMachineStartup(ctx, record, *state)
|
||||||
return d.failMachineStartup(ctx, record, state.Error)
|
|
||||||
}
|
|
||||||
ready, err := guestPortsReady(ctx, state.RuntimeHost, defaultMachinePorts())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if !ready {
|
|
||||||
return record, nil
|
|
||||||
}
|
|
||||||
if err := d.personalizeGuest(ctx, record, *state); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "warning: guest personalization for %q failed: %v\n", record.ID, err)
|
|
||||||
}
|
|
||||||
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "warning: read guest ssh public key for %q failed: %v\n", record.ID, err)
|
|
||||||
guestSSHPublicKey = record.GuestSSHPublicKey
|
|
||||||
}
|
|
||||||
record.RuntimeHost = state.RuntimeHost
|
|
||||||
record.TapDevice = state.TapName
|
|
||||||
record.Ports = defaultMachinePorts()
|
|
||||||
record.GuestSSHPublicKey = guestSSHPublicKey
|
|
||||||
record.Phase = contracthost.MachinePhaseRunning
|
|
||||||
record.Error = ""
|
|
||||||
record.PID = state.PID
|
|
||||||
record.SocketPath = state.SocketPath
|
|
||||||
record.StartedAt = state.StartedAt
|
|
||||||
if err := d.store.UpdateMachine(ctx, *record); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
|
||||||
return d.failMachineStartup(ctx, record, err.Error())
|
|
||||||
}
|
|
||||||
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
|
||||||
d.stopMachineRelays(record.ID)
|
|
||||||
return d.failMachineStartup(ctx, record, err.Error())
|
|
||||||
}
|
|
||||||
return record, nil
|
|
||||||
}
|
}
|
||||||
if state.Phase == firecracker.PhaseRunning {
|
if state.Phase == firecracker.PhaseRunning {
|
||||||
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||||
|
|
@ -450,7 +441,7 @@ func (d *Daemon) failMachineStartup(ctx context.Context, record *model.MachineRe
|
||||||
record.Phase = contracthost.MachinePhaseFailed
|
record.Phase = contracthost.MachinePhaseFailed
|
||||||
record.Error = strings.TrimSpace(failureReason)
|
record.Error = strings.TrimSpace(failureReason)
|
||||||
record.Ports = defaultMachinePorts()
|
record.Ports = defaultMachinePorts()
|
||||||
record.GuestSSHPublicKey = ""
|
record.GuestReadyNonce = ""
|
||||||
record.PID = 0
|
record.PID = 0
|
||||||
record.SocketPath = ""
|
record.SocketPath = ""
|
||||||
record.RuntimeHost = ""
|
record.RuntimeHost = ""
|
||||||
|
|
@ -511,6 +502,7 @@ func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRec
|
||||||
|
|
||||||
record.Phase = contracthost.MachinePhaseStopped
|
record.Phase = contracthost.MachinePhaseStopped
|
||||||
record.Error = ""
|
record.Error = ""
|
||||||
|
record.GuestReadyNonce = ""
|
||||||
record.PID = 0
|
record.PID = 0
|
||||||
record.SocketPath = ""
|
record.SocketPath = ""
|
||||||
record.RuntimeHost = ""
|
record.RuntimeHost = ""
|
||||||
|
|
|
||||||
|
|
@ -299,7 +299,7 @@ func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T)
|
||||||
assertOperationCount(t, baseStore, 1)
|
assertOperationCount(t, baseStore, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
|
func TestStartMachineTransitionsToRunningWithHandshake(t *testing.T) {
|
||||||
root := t.TempDir()
|
root := t.TempDir()
|
||||||
cfg := testConfig(root)
|
cfg := testConfig(root)
|
||||||
baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||||
|
|
@ -386,16 +386,16 @@ func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("StartMachine error = %v", err)
|
t.Fatalf("StartMachine error = %v", err)
|
||||||
}
|
}
|
||||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||||
t.Fatalf("response machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting)
|
t.Fatalf("response machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseRunning)
|
||||||
}
|
}
|
||||||
|
|
||||||
machine, err := baseStore.GetMachine(context.Background(), "vm-start")
|
machine, err := baseStore.GetMachine(context.Background(), "vm-start")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("get machine: %v", err)
|
t.Fatalf("get machine: %v", err)
|
||||||
}
|
}
|
||||||
if machine.Phase != contracthost.MachinePhaseStarting {
|
if machine.Phase != contracthost.MachinePhaseRunning {
|
||||||
t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseStarting)
|
t.Fatalf("machine phase = %q, want %q", machine.Phase, contracthost.MachinePhaseRunning)
|
||||||
}
|
}
|
||||||
if machine.RuntimeHost != "127.0.0.1" || machine.TapDevice != "fctap-start" {
|
if machine.RuntimeHost != "127.0.0.1" || machine.TapDevice != "fctap-start" {
|
||||||
t.Fatalf("machine runtime state mismatch, got runtime_host=%q tap=%q", machine.RuntimeHost, machine.TapDevice)
|
t.Fatalf("machine runtime state mismatch, got runtime_host=%q tap=%q", machine.RuntimeHost, machine.TapDevice)
|
||||||
|
|
@ -408,7 +408,7 @@ func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
|
func TestRestoreSnapshotTransitionsToRunningWithHandshake(t *testing.T) {
|
||||||
root := t.TempDir()
|
root := t.TempDir()
|
||||||
cfg := testConfig(root)
|
cfg := testConfig(root)
|
||||||
baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||||
|
|
@ -510,8 +510,8 @@ func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("RestoreSnapshot returned error: %v", err)
|
t.Fatalf("RestoreSnapshot returned error: %v", err)
|
||||||
}
|
}
|
||||||
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||||
t.Fatalf("restored machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting)
|
t.Fatalf("restored machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseRunning)
|
||||||
}
|
}
|
||||||
if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); err != nil {
|
if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); err != nil {
|
||||||
t.Fatalf("restored system volume record should exist: %v", err)
|
t.Fatalf("restored system volume record should exist: %v", err)
|
||||||
|
|
|
||||||
|
|
@ -245,10 +245,33 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer cleanupRestoreArtifacts()
|
defer cleanupRestoreArtifacts()
|
||||||
artifact, err := d.ensureArtifact(ctx, req.Artifact)
|
var (
|
||||||
if err != nil {
|
artifact *model.ArtifactRecord
|
||||||
|
guestHostKey *guestSSHHostKeyPair
|
||||||
|
readyNonce string
|
||||||
|
)
|
||||||
|
group, groupCtx := errgroup.WithContext(ctx)
|
||||||
|
group.Go(func() error {
|
||||||
|
var err error
|
||||||
|
artifact, err = d.ensureArtifact(groupCtx, req.Artifact)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ensure artifact for restore: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
group.Go(func() error {
|
||||||
|
var err error
|
||||||
|
guestHostKey, err = generateGuestSSHHostKeyPair(groupCtx)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
group.Go(func() error {
|
||||||
|
var err error
|
||||||
|
readyNonce, err = newGuestReadyNonce()
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err := group.Wait(); err != nil {
|
||||||
clearOperation = true
|
clearOperation = true
|
||||||
return nil, fmt.Errorf("ensure artifact for restore: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// COW-copy system disk from snapshot to new machine's disk dir.
|
// COW-copy system disk from snapshot to new machine's disk dir.
|
||||||
|
|
@ -280,6 +303,10 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
||||||
clearOperation = true
|
clearOperation = true
|
||||||
return nil, fmt.Errorf("inject guest config for restore: %w", err)
|
return nil, fmt.Errorf("inject guest config for restore: %w", err)
|
||||||
}
|
}
|
||||||
|
if err := injectGuestSSHHostKey(ctx, newSystemDiskPath, guestHostKey); err != nil {
|
||||||
|
clearOperation = true
|
||||||
|
return nil, fmt.Errorf("inject guest ssh host key for restore: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
type restoredUserVolume struct {
|
type restoredUserVolume struct {
|
||||||
ID contracthost.VolumeID
|
ID contracthost.VolumeID
|
||||||
|
|
@ -313,7 +340,7 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
||||||
Path: volume.Path,
|
Path: volume.Path,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig)
|
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, newSystemDiskPath, guestConfig, readyNonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
clearOperation = true
|
clearOperation = true
|
||||||
return nil, fmt.Errorf("build machine spec for restore: %w", err)
|
return nil, fmt.Errorf("build machine spec for restore: %w", err)
|
||||||
|
|
@ -376,7 +403,8 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
||||||
RuntimeHost: machineState.RuntimeHost,
|
RuntimeHost: machineState.RuntimeHost,
|
||||||
TapDevice: machineState.TapName,
|
TapDevice: machineState.TapName,
|
||||||
Ports: defaultMachinePorts(),
|
Ports: defaultMachinePorts(),
|
||||||
GuestSSHPublicKey: "",
|
GuestSSHPublicKey: strings.TrimSpace(guestHostKey.PublicKey),
|
||||||
|
GuestReadyNonce: readyNonce,
|
||||||
Phase: contracthost.MachinePhaseStarting,
|
Phase: contracthost.MachinePhaseStarting,
|
||||||
PID: machineState.PID,
|
PID: machineState.PID,
|
||||||
SocketPath: machineState.SocketPath,
|
SocketPath: machineState.SocketPath,
|
||||||
|
|
@ -393,10 +421,15 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
record, err := d.completeMachineStartup(ctx, &machineRecord, *machineState)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
removeMachineDiskDirOnFailure = false
|
removeMachineDiskDirOnFailure = false
|
||||||
clearOperation = true
|
clearOperation = true
|
||||||
return &contracthost.RestoreSnapshotResponse{
|
return &contracthost.RestoreSnapshotResponse{
|
||||||
Machine: machineToContract(machineRecord),
|
Machine: machineToContract(*record),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
77
internal/daemon/startup.go
Normal file
77
internal/daemon/startup.go
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
package daemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||||
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
|
)
|
||||||
|
|
||||||
|
type guestReadyResult struct {
|
||||||
|
ReadyNonce string
|
||||||
|
GuestSSHPublicKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGuestReadyNonce() (string, error) {
|
||||||
|
var bytes [16]byte
|
||||||
|
if _, err := rand.Read(bytes[:]); err != nil {
|
||||||
|
return "", fmt.Errorf("generate guest ready nonce: %w", err)
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes[:]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) completeMachineStartup(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) (*model.MachineRecord, error) {
|
||||||
|
if record == nil {
|
||||||
|
return nil, fmt.Errorf("machine record is required")
|
||||||
|
}
|
||||||
|
if state.Phase != firecracker.PhaseRunning {
|
||||||
|
failureReason := strings.TrimSpace(state.Error)
|
||||||
|
if failureReason == "" {
|
||||||
|
failureReason = "machine did not reach running phase"
|
||||||
|
}
|
||||||
|
return d.failMachineStartup(ctx, record, failureReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
ready, err := d.personalizeGuest(ctx, record, state)
|
||||||
|
if err != nil {
|
||||||
|
return d.failMachineStartup(ctx, record, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedNonce := strings.TrimSpace(record.GuestReadyNonce)
|
||||||
|
receivedNonce := strings.TrimSpace(ready.ReadyNonce)
|
||||||
|
if expectedNonce != "" && receivedNonce != expectedNonce {
|
||||||
|
return d.failMachineStartup(ctx, record, "guest ready nonce mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedGuestSSHPublicKey := strings.TrimSpace(record.GuestSSHPublicKey)
|
||||||
|
guestSSHPublicKey := strings.TrimSpace(ready.GuestSSHPublicKey)
|
||||||
|
if guestSSHPublicKey == "" {
|
||||||
|
if expectedGuestSSHPublicKey == "" {
|
||||||
|
return d.failMachineStartup(ctx, record, "guest ready response missing ssh host key")
|
||||||
|
}
|
||||||
|
guestSSHPublicKey = expectedGuestSSHPublicKey
|
||||||
|
}
|
||||||
|
if expectedGuestSSHPublicKey != "" && guestSSHPublicKey != expectedGuestSSHPublicKey {
|
||||||
|
return d.failMachineStartup(ctx, record, "guest ssh host key mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
record.RuntimeHost = state.RuntimeHost
|
||||||
|
record.TapDevice = state.TapName
|
||||||
|
record.Ports = defaultMachinePorts()
|
||||||
|
record.GuestSSHPublicKey = guestSSHPublicKey
|
||||||
|
record.GuestReadyNonce = ""
|
||||||
|
record.Phase = contracthost.MachinePhaseRunning
|
||||||
|
record.Error = ""
|
||||||
|
record.PID = state.PID
|
||||||
|
record.SocketPath = state.SocketPath
|
||||||
|
record.StartedAt = state.StartedAt
|
||||||
|
if err := d.store.UpdateMachine(ctx, *record); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
@ -8,15 +8,15 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type vmConfig struct {
|
type vmConfig struct {
|
||||||
BootSource vmBootSource `json:"boot-source"`
|
BootSource vmBootSource `json:"boot-source"`
|
||||||
Drives []vmDrive `json:"drives"`
|
Drives []vmDrive `json:"drives"`
|
||||||
MachineConfig vmMachineConfig `json:"machine-config"`
|
MachineConfig vmMachineConfig `json:"machine-config"`
|
||||||
NetworkInterfaces []vmNetworkIface `json:"network-interfaces"`
|
NetworkInterfaces []vmNetworkIface `json:"network-interfaces"`
|
||||||
Vsock *vmVsock `json:"vsock,omitempty"`
|
Vsock *vmVsock `json:"vsock,omitempty"`
|
||||||
Logger *vmLogger `json:"logger,omitempty"`
|
Logger *vmLogger `json:"logger,omitempty"`
|
||||||
MMDSConfig *vmMMDSConfig `json:"mmds-config,omitempty"`
|
MMDSConfig *vmMMDSConfig `json:"mmds-config,omitempty"`
|
||||||
Entropy *vmEntropy `json:"entropy,omitempty"`
|
Entropy *vmEntropy `json:"entropy,omitempty"`
|
||||||
Serial *vmSerial `json:"serial,omitempty"`
|
Serial *vmSerial `json:"serial,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type vmBootSource struct {
|
type vmBootSource struct {
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ type MachineRecord struct {
|
||||||
TapDevice string
|
TapDevice string
|
||||||
Ports []contracthost.MachinePort
|
Ports []contracthost.MachinePort
|
||||||
GuestSSHPublicKey string
|
GuestSSHPublicKey string
|
||||||
|
GuestReadyNonce string
|
||||||
Phase contracthost.MachinePhase
|
Phase contracthost.MachinePhase
|
||||||
Error string
|
Error string
|
||||||
PID int
|
PID int
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue