mirror of
https://github.com/getcompanion-ai/computer-host.git
synced 2026-04-15 07:04:43 +00:00
feat: vsock mmds snapshot
This commit is contained in:
parent
39f8882c30
commit
07975fb459
13 changed files with 390 additions and 148 deletions
|
|
@ -7,10 +7,10 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/store"
|
"github.com/getcompanion-ai/computer-host/internal/store"
|
||||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachineRequest) (*contracthost.CreateMachineResponse, error) {
|
func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachineRequest) (*contracthost.CreateMachineResponse, error) {
|
||||||
|
|
@ -184,6 +184,7 @@ func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *mo
|
||||||
KernelArgs: defaultGuestKernelArgs,
|
KernelArgs: defaultGuestKernelArgs,
|
||||||
Drives: drives,
|
Drives: drives,
|
||||||
MMDS: mmds,
|
MMDS: mmds,
|
||||||
|
Vsock: guestVsockSpec(machineID),
|
||||||
}
|
}
|
||||||
if err := spec.Validate(); err != nil {
|
if err := spec.Validate(); err != nil {
|
||||||
return firecracker.MachineSpec{}, err
|
return firecracker.MachineSpec{}, err
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,11 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
appconfig "github.com/getcompanion-ai/computer-host/internal/config"
|
appconfig "github.com/getcompanion-ai/computer-host/internal/config"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||||
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/store"
|
"github.com/getcompanion-ai/computer-host/internal/store"
|
||||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
@ -22,8 +23,6 @@ const (
|
||||||
defaultVNCPort = uint16(6080)
|
defaultVNCPort = uint16(6080)
|
||||||
defaultCopyBufferSize = 1024 * 1024
|
defaultCopyBufferSize = 1024 * 1024
|
||||||
defaultGuestDialTimeout = 500 * time.Millisecond
|
defaultGuestDialTimeout = 500 * time.Millisecond
|
||||||
defaultGuestReadyPollInterval = 100 * time.Millisecond
|
|
||||||
defaultGuestReadyTimeout = 30 * time.Second
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Runtime interface {
|
type Runtime interface {
|
||||||
|
|
@ -34,6 +33,7 @@ type Runtime interface {
|
||||||
Resume(context.Context, firecracker.MachineState) error
|
Resume(context.Context, firecracker.MachineState) error
|
||||||
CreateSnapshot(context.Context, firecracker.MachineState, firecracker.SnapshotPaths) error
|
CreateSnapshot(context.Context, firecracker.MachineState, firecracker.SnapshotPaths) error
|
||||||
RestoreBoot(context.Context, firecracker.SnapshotLoadSpec, []firecracker.NetworkAllocation) (*firecracker.MachineState, error)
|
RestoreBoot(context.Context, firecracker.SnapshotLoadSpec, []firecracker.NetworkAllocation) (*firecracker.MachineState, error)
|
||||||
|
PutMMDS(context.Context, firecracker.MachineState, any) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Daemon struct {
|
type Daemon struct {
|
||||||
|
|
@ -44,6 +44,7 @@ type Daemon struct {
|
||||||
reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error
|
reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error
|
||||||
readGuestSSHPublicKey func(context.Context, string) (string, error)
|
readGuestSSHPublicKey func(context.Context, string) (string, error)
|
||||||
syncGuestFilesystem func(context.Context, string) error
|
syncGuestFilesystem func(context.Context, string) error
|
||||||
|
personalizeGuest func(context.Context, *model.MachineRecord, firecracker.MachineState) error
|
||||||
|
|
||||||
locksMu sync.Mutex
|
locksMu sync.Mutex
|
||||||
machineLocks map[contracthost.MachineID]*sync.Mutex
|
machineLocks map[contracthost.MachineID]*sync.Mutex
|
||||||
|
|
@ -78,6 +79,7 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
|
||||||
runtime: runtime,
|
runtime: runtime,
|
||||||
reconfigureGuestIdentity: nil,
|
reconfigureGuestIdentity: nil,
|
||||||
readGuestSSHPublicKey: nil,
|
readGuestSSHPublicKey: nil,
|
||||||
|
personalizeGuest: nil,
|
||||||
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
|
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
|
||||||
artifactLocks: make(map[string]*sync.Mutex),
|
artifactLocks: make(map[string]*sync.Mutex),
|
||||||
machineRelayListeners: make(map[string]net.Listener),
|
machineRelayListeners: make(map[string]net.Listener),
|
||||||
|
|
@ -86,6 +88,7 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
|
||||||
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
|
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
|
||||||
daemon.readGuestSSHPublicKey = readGuestSSHPublicKey
|
daemon.readGuestSSHPublicKey = readGuestSSHPublicKey
|
||||||
daemon.syncGuestFilesystem = daemon.syncGuestFilesystemOverSSH
|
daemon.syncGuestFilesystem = daemon.syncGuestFilesystemOverSSH
|
||||||
|
daemon.personalizeGuest = daemon.personalizeGuestConfig
|
||||||
if err := daemon.ensureBackendSSHKeyPair(); err != nil {
|
if err := daemon.ensureBackendSSHKeyPair(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,11 +13,11 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
appconfig "github.com/getcompanion-ai/computer-host/internal/config"
|
appconfig "github.com/getcompanion-ai/computer-host/internal/config"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/store"
|
"github.com/getcompanion-ai/computer-host/internal/store"
|
||||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type fakeRuntime struct {
|
type fakeRuntime struct {
|
||||||
|
|
@ -27,6 +27,7 @@ type fakeRuntime struct {
|
||||||
deleteCalls []firecracker.MachineState
|
deleteCalls []firecracker.MachineState
|
||||||
lastSpec firecracker.MachineSpec
|
lastSpec firecracker.MachineSpec
|
||||||
lastLoadSpec firecracker.SnapshotLoadSpec
|
lastLoadSpec firecracker.SnapshotLoadSpec
|
||||||
|
mmdsWrites []any
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fakeRuntime) Boot(_ context.Context, spec firecracker.MachineSpec, _ []firecracker.NetworkAllocation) (*firecracker.MachineState, error) {
|
func (f *fakeRuntime) Boot(_ context.Context, spec firecracker.MachineSpec, _ []firecracker.NetworkAllocation) (*firecracker.MachineState, error) {
|
||||||
|
|
@ -64,6 +65,11 @@ func (f *fakeRuntime) RestoreBoot(_ context.Context, spec firecracker.SnapshotLo
|
||||||
return &f.bootState, nil
|
return &f.bootState, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *fakeRuntime) PutMMDS(_ context.Context, _ firecracker.MachineState, data any) error {
|
||||||
|
f.mmdsWrites = append(f.mmdsWrites, data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
||||||
root := t.TempDir()
|
root := t.TempDir()
|
||||||
cfg := testConfig(root)
|
cfg := testConfig(root)
|
||||||
|
|
@ -173,6 +179,15 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
||||||
if runtime.lastSpec.MMDS == nil {
|
if runtime.lastSpec.MMDS == nil {
|
||||||
t.Fatalf("expected MMDS configuration on machine spec")
|
t.Fatalf("expected MMDS configuration on machine spec")
|
||||||
}
|
}
|
||||||
|
if runtime.lastSpec.Vsock == nil {
|
||||||
|
t.Fatalf("expected vsock configuration on machine spec")
|
||||||
|
}
|
||||||
|
if runtime.lastSpec.Vsock.ID != defaultGuestPersonalizationVsockID {
|
||||||
|
t.Fatalf("vsock id mismatch: got %q", runtime.lastSpec.Vsock.ID)
|
||||||
|
}
|
||||||
|
if runtime.lastSpec.Vsock.CID < minGuestVsockCID {
|
||||||
|
t.Fatalf("vsock cid mismatch: got %d", runtime.lastSpec.Vsock.CID)
|
||||||
|
}
|
||||||
if runtime.lastSpec.MMDS.Version != firecracker.MMDSVersionV2 {
|
if runtime.lastSpec.MMDS.Version != firecracker.MMDSVersionV2 {
|
||||||
t.Fatalf("mmds version mismatch: got %q", runtime.lastSpec.MMDS.Version)
|
t.Fatalf("mmds version mismatch: got %q", runtime.lastSpec.MMDS.Version)
|
||||||
}
|
}
|
||||||
|
|
@ -286,6 +301,70 @@ func TestStopMachineSyncsGuestFilesystemBeforeDelete(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReconcileStartingMachinePersonalizesBeforeRunning(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
cfg := testConfig(root)
|
||||||
|
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create file store: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sshListener := listenTestPort(t, int(defaultSSHPort))
|
||||||
|
defer func() { _ = sshListener.Close() }()
|
||||||
|
vncListener := listenTestPort(t, int(defaultVNCPort))
|
||||||
|
defer func() { _ = vncListener.Close() }()
|
||||||
|
|
||||||
|
startedAt := time.Unix(1700000100, 0).UTC()
|
||||||
|
runtime := &fakeRuntime{}
|
||||||
|
hostDaemon, err := New(cfg, fileStore, runtime)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create daemon: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { hostDaemon.stopMachineRelays("vm-starting") })
|
||||||
|
|
||||||
|
personalized := false
|
||||||
|
hostDaemon.personalizeGuest = func(_ context.Context, record *model.MachineRecord, state firecracker.MachineState) error {
|
||||||
|
personalized = true
|
||||||
|
if record.ID != "vm-starting" {
|
||||||
|
t.Fatalf("personalized machine mismatch: got %q", record.ID)
|
||||||
|
}
|
||||||
|
if state.RuntimeHost != "127.0.0.1" || state.PID != 4321 {
|
||||||
|
t.Fatalf("personalized state mismatch: %#v", state)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
|
|
||||||
|
if err := fileStore.CreateMachine(context.Background(), model.MachineRecord{
|
||||||
|
ID: "vm-starting",
|
||||||
|
SystemVolumeID: "vm-starting-system",
|
||||||
|
RuntimeHost: "127.0.0.1",
|
||||||
|
TapDevice: "fctap-starting",
|
||||||
|
Ports: defaultMachinePorts(),
|
||||||
|
Phase: contracthost.MachinePhaseStarting,
|
||||||
|
PID: 4321,
|
||||||
|
SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "vm-starting", "root", "run", "firecracker.sock"),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
StartedAt: &startedAt,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("create machine: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := hostDaemon.GetMachine(context.Background(), "vm-starting")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetMachine returned error: %v", err)
|
||||||
|
}
|
||||||
|
if !personalized {
|
||||||
|
t.Fatalf("guest personalization was not called")
|
||||||
|
}
|
||||||
|
if response.Machine.Phase != contracthost.MachinePhaseRunning {
|
||||||
|
t.Fatalf("machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseRunning)
|
||||||
|
}
|
||||||
|
if response.Machine.GuestSSHPublicKey == "" {
|
||||||
|
t.Fatalf("guest ssh public key should be recorded after convergence")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewEnsuresBackendSSHKeyPair(t *testing.T) {
|
func TestNewEnsuresBackendSSHKeyPair(t *testing.T) {
|
||||||
root := t.TempDir()
|
root := t.TempDir()
|
||||||
cfg := testConfig(root)
|
cfg := testConfig(root)
|
||||||
|
|
@ -406,6 +485,9 @@ func TestRestoreSnapshotFallsBackToLocalSnapshotNetwork(t *testing.T) {
|
||||||
if response.Machine.ID != "restored" {
|
if response.Machine.ID != "restored" {
|
||||||
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
|
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
|
||||||
}
|
}
|
||||||
|
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
||||||
|
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
|
||||||
|
}
|
||||||
if runtime.restoreCalls != 1 {
|
if runtime.restoreCalls != 1 {
|
||||||
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls)
|
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls)
|
||||||
}
|
}
|
||||||
|
|
@ -462,13 +544,8 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
var reconfiguredHost string
|
|
||||||
var reconfiguredMachine contracthost.MachineID
|
|
||||||
var reconfiguredConfig *contracthost.GuestConfig
|
|
||||||
hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) error {
|
hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) error {
|
||||||
reconfiguredHost = host
|
t.Fatalf("restore snapshot should not synchronously reconfigure guest identity, host=%q machine=%q guest_config=%#v", host, machineID, guestConfig)
|
||||||
reconfiguredMachine = machineID
|
|
||||||
reconfiguredConfig = cloneGuestConfig(guestConfig)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -509,6 +586,9 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
|
||||||
if response.Machine.ID != "restored" {
|
if response.Machine.ID != "restored" {
|
||||||
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
|
t.Fatalf("restored machine id mismatch: got %q", response.Machine.ID)
|
||||||
}
|
}
|
||||||
|
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
||||||
|
t.Fatalf("restored machine phase mismatch: got %q", response.Machine.Phase)
|
||||||
|
}
|
||||||
if runtime.restoreCalls != 1 {
|
if runtime.restoreCalls != 1 {
|
||||||
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls)
|
t.Fatalf("restore boot call count mismatch: got %d want 1", runtime.restoreCalls)
|
||||||
}
|
}
|
||||||
|
|
@ -527,20 +607,17 @@ func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
|
||||||
}), "kernel")) {
|
}), "kernel")) {
|
||||||
t.Fatalf("restore boot kernel path mismatch: got %q", runtime.lastLoadSpec.KernelImagePath)
|
t.Fatalf("restore boot kernel path mismatch: got %q", runtime.lastLoadSpec.KernelImagePath)
|
||||||
}
|
}
|
||||||
if reconfiguredHost != "127.0.0.1" || reconfiguredMachine != "restored" {
|
|
||||||
t.Fatalf("guest identity reconfigure mismatch: host=%q machine=%q", reconfiguredHost, reconfiguredMachine)
|
|
||||||
}
|
|
||||||
if reconfiguredConfig == nil || reconfiguredConfig.Hostname != "restored-shell" {
|
|
||||||
t.Fatalf("guest identity hostname mismatch: %#v", reconfiguredConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
machine, err := fileStore.GetMachine(context.Background(), "restored")
|
machine, err := fileStore.GetMachine(context.Background(), "restored")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("get restored machine: %v", err)
|
t.Fatalf("get restored machine: %v", err)
|
||||||
}
|
}
|
||||||
if machine.Phase != contracthost.MachinePhaseRunning {
|
if machine.Phase != contracthost.MachinePhaseStarting {
|
||||||
t.Fatalf("restored machine phase mismatch: got %q", machine.Phase)
|
t.Fatalf("restored machine phase mismatch: got %q", machine.Phase)
|
||||||
}
|
}
|
||||||
|
if machine.GuestConfig == nil || machine.GuestConfig.Hostname != "restored-shell" {
|
||||||
|
t.Fatalf("stored guest config mismatch: %#v", machine.GuestConfig)
|
||||||
|
}
|
||||||
if len(machine.UserVolumeIDs) != 1 {
|
if len(machine.UserVolumeIDs) != 1 {
|
||||||
t.Fatalf("restored machine user volumes mismatch: got %#v", machine.UserVolumeIDs)
|
t.Fatalf("restored machine user volumes mismatch: got %#v", machine.UserVolumeIDs)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package daemon
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
@ -19,10 +20,24 @@ func (d *Daemon) reconfigureGuestIdentityOverSSH(ctx context.Context, runtimeHos
|
||||||
if machineName == "" {
|
if machineName == "" {
|
||||||
return fmt.Errorf("machine id is required")
|
return fmt.Errorf("machine id is required")
|
||||||
}
|
}
|
||||||
|
mmds, err := d.guestMetadataSpec(machineID, guestConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
envelope, ok := mmds.Data.(guestMetadataEnvelope)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("guest metadata payload has unexpected type %T", mmds.Data)
|
||||||
|
}
|
||||||
|
payloadBytes, err := json.Marshal(envelope.Latest.MetaData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal guest metadata payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
privateKeyPath := d.backendSSHPrivateKeyPath()
|
privateKeyPath := d.backendSSHPrivateKeyPath()
|
||||||
remoteScript := fmt.Sprintf(`set -euo pipefail
|
remoteScript := fmt.Sprintf(`set -euo pipefail
|
||||||
machine_name=%s
|
payload=%s
|
||||||
|
install -d -m 0755 /etc/microagent
|
||||||
|
machine_name="$(printf '%%s' "$payload" | jq -r '.hostname // .machine_id // empty')"
|
||||||
printf '%%s\n' "$machine_name" >/etc/microagent/machine-name
|
printf '%%s\n' "$machine_name" >/etc/microagent/machine-name
|
||||||
printf '%%s\n' "$machine_name" >/etc/hostname
|
printf '%%s\n' "$machine_name" >/etc/hostname
|
||||||
cat >/etc/hosts <<EOF
|
cat >/etc/hosts <<EOF
|
||||||
|
|
@ -33,7 +48,25 @@ ff02::1 ip6-allnodes
|
||||||
ff02::2 ip6-allrouters
|
ff02::2 ip6-allrouters
|
||||||
EOF
|
EOF
|
||||||
hostname "$machine_name" >/dev/null 2>&1 || true
|
hostname "$machine_name" >/dev/null 2>&1 || true
|
||||||
`, strconv.Quote(machineName))
|
if printf '%%s' "$payload" | jq -e '.authorized_keys | length > 0' >/dev/null 2>&1; then
|
||||||
|
install -d -m 0700 -o node -g node /home/node/.ssh
|
||||||
|
printf '%%s' "$payload" | jq -r '.authorized_keys[]' >/home/node/.ssh/authorized_keys
|
||||||
|
chmod 0600 /home/node/.ssh/authorized_keys
|
||||||
|
chown node:node /home/node/.ssh/authorized_keys
|
||||||
|
printf '%%s' "$payload" | jq -r '.authorized_keys[]' >/etc/microagent/authorized_keys
|
||||||
|
chmod 0600 /etc/microagent/authorized_keys
|
||||||
|
else
|
||||||
|
rm -f /home/node/.ssh/authorized_keys /etc/microagent/authorized_keys
|
||||||
|
fi
|
||||||
|
if printf '%%s' "$payload" | jq -e '.trusted_user_ca_keys | length > 0' >/dev/null 2>&1; then
|
||||||
|
printf '%%s' "$payload" | jq -r '.trusted_user_ca_keys[]' >/etc/microagent/trusted_user_ca_keys
|
||||||
|
chmod 0644 /etc/microagent/trusted_user_ca_keys
|
||||||
|
else
|
||||||
|
rm -f /etc/microagent/trusted_user_ca_keys
|
||||||
|
fi
|
||||||
|
printf '%%s' "$payload" | jq '{authorized_keys, trusted_user_ca_keys, login_webhook}' >/etc/microagent/guest-config.json
|
||||||
|
chmod 0600 /etc/microagent/guest-config.json
|
||||||
|
`, strconv.Quote(string(payloadBytes)))
|
||||||
|
|
||||||
cmd := exec.CommandContext(
|
cmd := exec.CommandContext(
|
||||||
ctx,
|
ctx,
|
||||||
|
|
@ -43,6 +76,7 @@ hostname "$machine_name" >/dev/null 2>&1 || true
|
||||||
"-o", "UserKnownHostsFile=/dev/null",
|
"-o", "UserKnownHostsFile=/dev/null",
|
||||||
"-o", "IdentitiesOnly=yes",
|
"-o", "IdentitiesOnly=yes",
|
||||||
"-o", "BatchMode=yes",
|
"-o", "BatchMode=yes",
|
||||||
|
"-o", "ConnectTimeout=2",
|
||||||
"-p", strconv.Itoa(int(defaultSSHPort)),
|
"-p", strconv.Itoa(int(defaultSSHPort)),
|
||||||
"node@"+runtimeHost,
|
"node@"+runtimeHost,
|
||||||
"sudo bash -lc "+shellSingleQuote(remoteScript),
|
"sudo bash -lc "+shellSingleQuote(remoteScript),
|
||||||
|
|
@ -68,6 +102,7 @@ func (d *Daemon) syncGuestFilesystemOverSSH(ctx context.Context, runtimeHost str
|
||||||
"-o", "UserKnownHostsFile=/dev/null",
|
"-o", "UserKnownHostsFile=/dev/null",
|
||||||
"-o", "IdentitiesOnly=yes",
|
"-o", "IdentitiesOnly=yes",
|
||||||
"-o", "BatchMode=yes",
|
"-o", "BatchMode=yes",
|
||||||
|
"-o", "ConnectTimeout=2",
|
||||||
"-p", strconv.Itoa(int(defaultSSHPort)),
|
"-p", strconv.Itoa(int(defaultSSHPort)),
|
||||||
"node@"+runtimeHost,
|
"node@"+runtimeHost,
|
||||||
"sudo bash -lc "+shellSingleQuote("sync"),
|
"sudo bash -lc "+shellSingleQuote("sync"),
|
||||||
|
|
|
||||||
135
internal/daemon/guest_personalization.go
Normal file
135
internal/daemon/guest_personalization.go
Normal file
|
|
@ -0,0 +1,135 @@
|
||||||
|
package daemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
|
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||||
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Settings for the host-to-guest personalization channel.
const (
	// defaultGuestPersonalizationVsockID is the ID placed on every machine's
	// vsock device spec.
	defaultGuestPersonalizationVsockID = "microagent-personalizer"

	// defaultGuestPersonalizationVsockName is the socket file name: it is
	// used as the vsock spec's Path and as the last element of the host-side
	// path that the daemon dials.
	defaultGuestPersonalizationVsockName = "microagent-personalizer.vsock"

	// defaultGuestPersonalizationVsockPort is the port number sent in the
	// "CONNECT <port>" request written to the vsock socket.
	defaultGuestPersonalizationVsockPort = uint32(1024)

	// defaultGuestPersonalizationTimeout bounds the MMDS write plus vsock
	// exchange, and is the fallback connection deadline when the context
	// carries none.
	defaultGuestPersonalizationTimeout = 2 * time.Second

	// minGuestVsockCID and maxGuestVsockCID are the inclusive bounds of the
	// CIDs that guestVsockCID may produce.
	// NOTE(review): values below 3 are presumably the reserved AF_VSOCK
	// addresses (hypervisor/local/host) — confirm against vsock(7).
	minGuestVsockCID = uint32(3)
	maxGuestVsockCID = uint32(1<<31 - 1)
)
|
||||||
|
|
||||||
|
func guestVsockSpec(machineID contracthost.MachineID) *firecracker.VsockSpec {
|
||||||
|
return &firecracker.VsockSpec{
|
||||||
|
ID: defaultGuestPersonalizationVsockID,
|
||||||
|
CID: guestVsockCID(machineID),
|
||||||
|
Path: defaultGuestPersonalizationVsockName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func guestVsockCID(machineID contracthost.MachineID) uint32 {
|
||||||
|
sum := sha256.Sum256([]byte(machineID))
|
||||||
|
space := maxGuestVsockCID - minGuestVsockCID + 1
|
||||||
|
return minGuestVsockCID + binary.BigEndian.Uint32(sum[:4])%space
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) personalizeGuestConfig(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState) error {
|
||||||
|
if record == nil {
|
||||||
|
return fmt.Errorf("machine record is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
personalizeCtx, cancel := context.WithTimeout(ctx, defaultGuestPersonalizationTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
mmds, err := d.guestMetadataSpec(record.ID, record.GuestConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
envelope, ok := mmds.Data.(guestMetadataEnvelope)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("guest metadata payload has unexpected type %T", mmds.Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.runtime.PutMMDS(personalizeCtx, state, mmds.Data); err != nil {
|
||||||
|
return d.personalizeGuestConfigViaSSH(ctx, record, state, fmt.Errorf("reseed guest mmds: %w", err))
|
||||||
|
}
|
||||||
|
if err := sendGuestPersonalization(personalizeCtx, state, envelope.Latest.MetaData); err != nil {
|
||||||
|
return d.personalizeGuestConfigViaSSH(ctx, record, state, fmt.Errorf("apply guest config over vsock: %w", err))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) personalizeGuestConfigViaSSH(ctx context.Context, record *model.MachineRecord, state firecracker.MachineState, primaryErr error) error {
|
||||||
|
fallbackErr := d.reconfigureGuestIdentity(ctx, state.RuntimeHost, record.ID, record.GuestConfig)
|
||||||
|
if fallbackErr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%w; ssh fallback failed: %v", primaryErr, fallbackErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendGuestPersonalization(ctx context.Context, state firecracker.MachineState, payload guestMetadataPayload) error {
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal guest personalization payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
vsockPath, err := guestVsockHostPath(state)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
connection, err := (&net.Dialer{}).DialContext(ctx, "unix", vsockPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dial guest personalization vsock %q: %w", vsockPath, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = connection.Close()
|
||||||
|
}()
|
||||||
|
setConnectionDeadline(ctx, connection)
|
||||||
|
|
||||||
|
reader := bufio.NewReader(connection)
|
||||||
|
if _, err := fmt.Fprintf(connection, "CONNECT %d\n", defaultGuestPersonalizationVsockPort); err != nil {
|
||||||
|
return fmt.Errorf("write vsock connect request: %w", err)
|
||||||
|
}
|
||||||
|
response, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read vsock connect response: %w", err)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(strings.TrimSpace(response), "OK ") {
|
||||||
|
return fmt.Errorf("unexpected vsock connect response %q", strings.TrimSpace(response))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := connection.Write(append(payloadBytes, '\n')); err != nil {
|
||||||
|
return fmt.Errorf("write guest personalization payload: %w", err)
|
||||||
|
}
|
||||||
|
response, err = reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read guest personalization response: %w", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(response) != "OK" {
|
||||||
|
return fmt.Errorf("unexpected guest personalization response %q", strings.TrimSpace(response))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func guestVsockHostPath(state firecracker.MachineState) (string, error) {
|
||||||
|
if state.PID < 1 {
|
||||||
|
return "", fmt.Errorf("firecracker pid is required for guest vsock host path")
|
||||||
|
}
|
||||||
|
return filepath.Join("/proc", strconv.Itoa(state.PID), "root", "run", defaultGuestPersonalizationVsockName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setConnectionDeadline(ctx context.Context, connection net.Conn) {
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
_ = connection.SetDeadline(deadline)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = connection.SetDeadline(time.Now().Add(defaultGuestPersonalizationTimeout))
|
||||||
|
}
|
||||||
|
|
@ -9,10 +9,10 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/store"
|
"github.com/getcompanion-ai/computer-host/internal/store"
|
||||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (d *Daemon) GetMachine(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) {
|
func (d *Daemon) GetMachine(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) {
|
||||||
|
|
@ -387,6 +387,9 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
|
||||||
if !ready {
|
if !ready {
|
||||||
return record, nil
|
return record, nil
|
||||||
}
|
}
|
||||||
|
if err := d.personalizeGuest(ctx, record, *state); err != nil {
|
||||||
|
return d.failMachineStartup(ctx, record, err.Error())
|
||||||
|
}
|
||||||
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
|
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return d.failMachineStartup(ctx, record, err.Error())
|
return d.failMachineStartup(ctx, record, err.Error())
|
||||||
|
|
|
||||||
|
|
@ -6,51 +6,10 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
)
|
)
|
||||||
|
|
||||||
func waitForGuestReady(ctx context.Context, host string, ports []contracthost.MachinePort) error {
|
|
||||||
host = strings.TrimSpace(host)
|
|
||||||
if host == "" {
|
|
||||||
return fmt.Errorf("guest runtime host is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
waitContext, cancel := context.WithTimeout(ctx, defaultGuestReadyTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
for _, port := range ports {
|
|
||||||
if err := waitForGuestPort(waitContext, host, port); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitForGuestPort(ctx context.Context, host string, port contracthost.MachinePort) error {
|
|
||||||
address := net.JoinHostPort(host, strconv.Itoa(int(port.Port)))
|
|
||||||
ticker := time.NewTicker(defaultGuestReadyPollInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
var lastErr error
|
|
||||||
for {
|
|
||||||
probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout)
|
|
||||||
ready, err := guestPortReady(probeCtx, host, port)
|
|
||||||
cancel()
|
|
||||||
if err == nil && ready {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
lastErr = err
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return fmt.Errorf("wait for guest port %q on %s: %w (last_err=%v)", port.Name, address, ctx.Err(), lastErr)
|
|
||||||
case <-ticker.C:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func guestPortsReady(ctx context.Context, host string, ports []contracthost.MachinePort) (bool, error) {
|
func guestPortsReady(ctx context.Context, host string, ports []contracthost.MachinePort) (bool, error) {
|
||||||
host = strings.TrimSpace(host)
|
host = strings.TrimSpace(host)
|
||||||
if host == "" {
|
if host == "" {
|
||||||
|
|
|
||||||
|
|
@ -14,10 +14,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
hoststore "github.com/getcompanion-ai/computer-host/internal/store"
|
hoststore "github.com/getcompanion-ai/computer-host/internal/store"
|
||||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type blockingPublishedPortStore struct {
|
type blockingPublishedPortStore struct {
|
||||||
|
|
@ -408,7 +408,7 @@ func TestStartMachineTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *testing.T) {
|
func TestRestoreSnapshotTransitionsToStartingWithoutRelayAllocation(t *testing.T) {
|
||||||
root := t.TempDir()
|
root := t.TempDir()
|
||||||
cfg := testConfig(root)
|
cfg := testConfig(root)
|
||||||
baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
baseStore, err := hoststore.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||||
|
|
@ -421,15 +421,6 @@ func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *tes
|
||||||
extraMachines: exhaustedMachineRelayRecords(),
|
extraMachines: exhaustedMachineRelayRecords(),
|
||||||
}
|
}
|
||||||
|
|
||||||
sshListener := listenTestPort(t, int(defaultSSHPort))
|
|
||||||
defer func() {
|
|
||||||
_ = sshListener.Close()
|
|
||||||
}()
|
|
||||||
vncListener := listenTestPort(t, int(defaultVNCPort))
|
|
||||||
defer func() {
|
|
||||||
_ = vncListener.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
startedAt := time.Unix(1700000300, 0).UTC()
|
startedAt := time.Unix(1700000300, 0).UTC()
|
||||||
runtime := &fakeRuntime{
|
runtime := &fakeRuntime{
|
||||||
bootState: firecracker.MachineState{
|
bootState: firecracker.MachineState{
|
||||||
|
|
@ -448,7 +439,10 @@ func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *tes
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID, *contracthost.GuestConfig) error { return nil }
|
hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID, guestConfig *contracthost.GuestConfig) error {
|
||||||
|
t.Fatalf("restore snapshot should not synchronously reconfigure guest identity, host=%q machine=%q guest_config=%#v", host, machineID, guestConfig)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
|
artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
|
||||||
kernelPath := filepath.Join(root, "artifact-kernel")
|
kernelPath := filepath.Join(root, "artifact-kernel")
|
||||||
|
|
@ -509,7 +503,7 @@ func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *tes
|
||||||
})
|
})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
_, err = hostDaemon.RestoreSnapshot(context.Background(), "snap-exhausted", contracthost.RestoreSnapshotRequest{
|
response, err := hostDaemon.RestoreSnapshot(context.Background(), "snap-exhausted", contracthost.RestoreSnapshotRequest{
|
||||||
MachineID: "restored-exhausted",
|
MachineID: "restored-exhausted",
|
||||||
Artifact: contracthost.ArtifactRef{
|
Artifact: contracthost.ArtifactRef{
|
||||||
KernelImageURL: server.URL + "/kernel",
|
KernelImageURL: server.URL + "/kernel",
|
||||||
|
|
@ -526,18 +520,20 @@ func TestRestoreSnapshotDeletesSystemVolumeRecordWhenRelayAllocationFails(t *tes
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err == nil || !strings.Contains(err.Error(), "allocate relay ports for restored machine") {
|
if err != nil {
|
||||||
t.Fatalf("RestoreSnapshot error = %v, want relay allocation failure", err)
|
t.Fatalf("RestoreSnapshot returned error: %v", err)
|
||||||
}
|
}
|
||||||
|
if response.Machine.Phase != contracthost.MachinePhaseStarting {
|
||||||
if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); !errors.Is(err, hoststore.ErrNotFound) {
|
t.Fatalf("restored machine phase = %q, want %q", response.Machine.Phase, contracthost.MachinePhaseStarting)
|
||||||
t.Fatalf("restored system volume record should be deleted, get err = %v", err)
|
|
||||||
}
|
}
|
||||||
if _, err := os.Stat(hostDaemon.systemVolumePath("restored-exhausted")); !os.IsNotExist(err) {
|
if _, err := baseStore.GetVolume(context.Background(), "restored-exhausted-system"); err != nil {
|
||||||
t.Fatalf("restored system disk should be removed, stat err = %v", err)
|
t.Fatalf("restored system volume record should exist: %v", err)
|
||||||
}
|
}
|
||||||
if len(runtime.deleteCalls) != 1 {
|
if _, err := os.Stat(hostDaemon.systemVolumePath("restored-exhausted")); err != nil {
|
||||||
t.Fatalf("runtime delete calls = %d, want 1", len(runtime.deleteCalls))
|
t.Fatalf("restored system disk should exist: %v", err)
|
||||||
|
}
|
||||||
|
if len(runtime.deleteCalls) != 0 {
|
||||||
|
t.Fatalf("runtime delete calls = %d, want 0", len(runtime.deleteCalls))
|
||||||
}
|
}
|
||||||
assertOperationCount(t, baseStore, 0)
|
assertOperationCount(t, baseStore, 0)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,10 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/model"
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/store"
|
"github.com/getcompanion-ai/computer-host/internal/store"
|
||||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error) {
|
func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error) {
|
||||||
|
|
@ -332,6 +332,9 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
||||||
restoredDrivePaths[driveID] = volumePath
|
restoredDrivePaths[driveID] = volumePath
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do not force vsock_override on restore: Firecracker rejects it for old
|
||||||
|
// snapshots without a vsock device, and the jailed /run path already
|
||||||
|
// relocates safely for snapshots created with the new vsock-backed guest.
|
||||||
loadSpec := firecracker.SnapshotLoadSpec{
|
loadSpec := firecracker.SnapshotLoadSpec{
|
||||||
ID: firecracker.MachineID(req.MachineID),
|
ID: firecracker.MachineID(req.MachineID),
|
||||||
SnapshotPath: vmstateArtifact.LocalPath,
|
SnapshotPath: vmstateArtifact.LocalPath,
|
||||||
|
|
@ -349,27 +352,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
||||||
return nil, fmt.Errorf("restore boot: %w", err)
|
return nil, fmt.Errorf("restore boot: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for guest to become ready
|
|
||||||
if err := waitForGuestReady(ctx, machineState.RuntimeHost, defaultMachinePorts()); err != nil {
|
|
||||||
_ = d.runtime.Delete(ctx, *machineState)
|
|
||||||
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
|
|
||||||
clearOperation = true
|
|
||||||
return nil, fmt.Errorf("wait for restored guest ready: %w", err)
|
|
||||||
}
|
|
||||||
if err := d.reconfigureGuestIdentity(ctx, machineState.RuntimeHost, req.MachineID, guestConfig); err != nil {
|
|
||||||
_ = d.runtime.Delete(ctx, *machineState)
|
|
||||||
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
|
|
||||||
clearOperation = true
|
|
||||||
return nil, fmt.Errorf("reconfigure restored guest identity: %w", err)
|
|
||||||
}
|
|
||||||
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, machineState.RuntimeHost)
|
|
||||||
if err != nil {
|
|
||||||
_ = d.runtime.Delete(ctx, *machineState)
|
|
||||||
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
|
|
||||||
clearOperation = true
|
|
||||||
return nil, fmt.Errorf("read restored guest ssh host key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
systemVolumeID := d.systemVolumeID(req.MachineID)
|
systemVolumeID := d.systemVolumeID(req.MachineID)
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
|
@ -419,38 +401,13 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
||||||
RuntimeHost: machineState.RuntimeHost,
|
RuntimeHost: machineState.RuntimeHost,
|
||||||
TapDevice: machineState.TapName,
|
TapDevice: machineState.TapName,
|
||||||
Ports: defaultMachinePorts(),
|
Ports: defaultMachinePorts(),
|
||||||
GuestSSHPublicKey: guestSSHPublicKey,
|
GuestSSHPublicKey: "",
|
||||||
Phase: contracthost.MachinePhaseRunning,
|
Phase: contracthost.MachinePhaseStarting,
|
||||||
PID: machineState.PID,
|
PID: machineState.PID,
|
||||||
SocketPath: machineState.SocketPath,
|
SocketPath: machineState.SocketPath,
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
StartedAt: machineState.StartedAt,
|
StartedAt: machineState.StartedAt,
|
||||||
}
|
}
|
||||||
d.relayAllocMu.Lock()
|
|
||||||
sshRelayPort, err := d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameSSH, machineRecord.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort)
|
|
||||||
var vncRelayPort uint16
|
|
||||||
if err == nil {
|
|
||||||
vncRelayPort, err = d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameVNC, machineRecord.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort)
|
|
||||||
}
|
|
||||||
d.relayAllocMu.Unlock()
|
|
||||||
if err != nil {
|
|
||||||
d.stopMachineRelays(machineRecord.ID)
|
|
||||||
for _, restoredVolumeID := range restoredUserVolumeIDs {
|
|
||||||
_ = d.store.DeleteVolume(context.Background(), restoredVolumeID)
|
|
||||||
}
|
|
||||||
_ = d.store.DeleteVolume(context.Background(), systemVolumeID)
|
|
||||||
_ = d.runtime.Delete(ctx, *machineState)
|
|
||||||
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
|
|
||||||
clearOperation = true
|
|
||||||
return nil, fmt.Errorf("allocate relay ports for restored machine: %w", err)
|
|
||||||
}
|
|
||||||
machineRecord.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
|
|
||||||
startedRelays := true
|
|
||||||
defer func() {
|
|
||||||
if startedRelays {
|
|
||||||
d.stopMachineRelays(machineRecord.ID)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if err := d.store.CreateMachine(ctx, machineRecord); err != nil {
|
if err := d.store.CreateMachine(ctx, machineRecord); err != nil {
|
||||||
for _, restoredVolumeID := range restoredUserVolumeIDs {
|
for _, restoredVolumeID := range restoredUserVolumeIDs {
|
||||||
_ = d.store.DeleteVolume(context.Background(), restoredVolumeID)
|
_ = d.store.DeleteVolume(context.Background(), restoredVolumeID)
|
||||||
|
|
@ -462,7 +419,6 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
startedRelays = false
|
|
||||||
clearOperation = true
|
clearOperation = true
|
||||||
return &contracthost.RestoreSnapshotResponse{
|
return &contracthost.RestoreSnapshotResponse{
|
||||||
Machine: machineToContract(machineRecord),
|
Machine: machineToContract(machineRecord),
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ func TestPutSnapshotLoadIncludesNetworkOverrides(t *testing.T) {
|
||||||
HostDevName: "fctap7",
|
HostDevName: "fctap7",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
VsockOverride: &VsockOverride{UDSPath: "/run/microagent-personalizer.vsock"},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("put snapshot load: %v", err)
|
t.Fatalf("put snapshot load: %v", err)
|
||||||
|
|
@ -47,7 +48,7 @@ func TestPutSnapshotLoadIncludesNetworkOverrides(t *testing.T) {
|
||||||
t.Fatalf("request path mismatch: got %q want %q", gotPath, "/snapshot/load")
|
t.Fatalf("request path mismatch: got %q want %q", gotPath, "/snapshot/load")
|
||||||
}
|
}
|
||||||
|
|
||||||
want := "{\"snapshot_path\":\"vmstate.bin\",\"mem_backend\":{\"backend_type\":\"File\",\"backend_path\":\"memory.bin\"},\"resume_vm\":false,\"network_overrides\":[{\"iface_id\":\"net0\",\"host_dev_name\":\"fctap7\"}]}"
|
want := "{\"snapshot_path\":\"vmstate.bin\",\"mem_backend\":{\"backend_type\":\"File\",\"backend_path\":\"memory.bin\"},\"resume_vm\":false,\"network_overrides\":[{\"iface_id\":\"net0\",\"host_dev_name\":\"fctap7\"}],\"vsock_override\":{\"uds_path\":\"/run/microagent-personalizer.vsock\"}}"
|
||||||
if gotBody != want {
|
if gotBody != want {
|
||||||
t.Fatalf("request body mismatch:\n got: %s\nwant: %s", gotBody, want)
|
t.Fatalf("request body mismatch:\n got: %s\nwant: %s", gotBody, want)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -168,6 +168,74 @@ func TestConfigureMachineConfiguresMMDSBeforeStart(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfigureMachineConfiguresVsockBeforeStart(t *testing.T) {
|
||||||
|
var requests []capturedRequest
|
||||||
|
|
||||||
|
socketPath, shutdown := startUnixSocketServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read request body: %v", err)
|
||||||
|
}
|
||||||
|
requests = append(requests, capturedRequest{
|
||||||
|
Method: r.Method,
|
||||||
|
Path: r.URL.Path,
|
||||||
|
Body: string(body),
|
||||||
|
})
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
defer shutdown()
|
||||||
|
|
||||||
|
client := newAPIClient(socketPath)
|
||||||
|
spec := MachineSpec{
|
||||||
|
ID: "vm-3",
|
||||||
|
VCPUs: 1,
|
||||||
|
MemoryMiB: 512,
|
||||||
|
KernelImagePath: "/kernel",
|
||||||
|
RootFSPath: "/rootfs",
|
||||||
|
Vsock: &VsockSpec{
|
||||||
|
ID: "microagent-personalizer",
|
||||||
|
CID: 42,
|
||||||
|
Path: "/run/microagent-personalizer.vsock",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
paths := machinePaths{JailedSerialLogPath: "/logs/serial.log"}
|
||||||
|
network := NetworkAllocation{
|
||||||
|
InterfaceID: defaultInterfaceID,
|
||||||
|
TapName: "fctap0",
|
||||||
|
GuestMAC: "06:00:ac:10:00:02",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := configureMachine(context.Background(), client, paths, spec, network); err != nil {
|
||||||
|
t.Fatalf("configure machine: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotPaths := make([]string, 0, len(requests))
|
||||||
|
for _, request := range requests {
|
||||||
|
gotPaths = append(gotPaths, request.Path)
|
||||||
|
}
|
||||||
|
wantPaths := []string{
|
||||||
|
"/machine-config",
|
||||||
|
"/boot-source",
|
||||||
|
"/drives/root_drive",
|
||||||
|
"/network-interfaces/net0",
|
||||||
|
"/entropy",
|
||||||
|
"/serial",
|
||||||
|
"/vsock",
|
||||||
|
"/actions",
|
||||||
|
}
|
||||||
|
if len(gotPaths) != len(wantPaths) {
|
||||||
|
t.Fatalf("request count mismatch: got %d want %d (%v)", len(gotPaths), len(wantPaths), gotPaths)
|
||||||
|
}
|
||||||
|
for i := range wantPaths {
|
||||||
|
if gotPaths[i] != wantPaths[i] {
|
||||||
|
t.Fatalf("request %d mismatch: got %q want %q", i, gotPaths[i], wantPaths[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if requests[6].Body != "{\"guest_cid\":42,\"uds_path\":\"/run/microagent-personalizer.vsock\"}" {
|
||||||
|
t.Fatalf("vsock body mismatch: got %q", requests[6].Body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func startUnixSocketServer(t *testing.T, handler http.HandlerFunc) (string, func()) {
|
func startUnixSocketServer(t *testing.T, handler http.HandlerFunc) (string, func()) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -132,7 +132,7 @@ func stageMachineFiles(spec MachineSpec, paths machinePaths) (MachineSpec, error
|
||||||
|
|
||||||
if spec.Vsock != nil {
|
if spec.Vsock != nil {
|
||||||
vsock := *spec.Vsock
|
vsock := *spec.Vsock
|
||||||
vsock.Path = jailedVSockPath(spec)
|
vsock.Path = jailedVSockDevicePath(*spec.Vsock)
|
||||||
staged.Vsock = &vsock
|
staged.Vsock = &vsock
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -244,11 +244,8 @@ func waitForPIDFile(ctx context.Context, pidFilePath string) (int, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func jailedVSockPath(spec MachineSpec) string {
|
func jailedVSockDevicePath(spec VsockSpec) string {
|
||||||
if spec.Vsock == nil {
|
return path.Join(defaultVSockRunDir, filepath.Base(strings.TrimSpace(spec.Path)))
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return path.Join(defaultVSockRunDir, filepath.Base(strings.TrimSpace(spec.Vsock.Path)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func linkMachineFile(source string, target string) error {
|
func linkMachineFile(source string, target string) error {
|
||||||
|
|
|
||||||
|
|
@ -331,6 +331,11 @@ func (r *Runtime) RestoreBoot(ctx context.Context, loadSpec SnapshotLoadSpec, us
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var vsockOverride *VsockOverride
|
||||||
|
if loadSpec.Vsock != nil {
|
||||||
|
vsockOverride = &VsockOverride{UDSPath: jailedVSockDevicePath(*loadSpec.Vsock)}
|
||||||
|
}
|
||||||
|
|
||||||
// Load snapshot (replaces the full configure+start sequence)
|
// Load snapshot (replaces the full configure+start sequence)
|
||||||
if err := client.PutSnapshotLoad(ctx, SnapshotLoadParams{
|
if err := client.PutSnapshotLoad(ctx, SnapshotLoadParams{
|
||||||
SnapshotPath: chrootStatePath,
|
SnapshotPath: chrootStatePath,
|
||||||
|
|
@ -345,6 +350,7 @@ func (r *Runtime) RestoreBoot(ctx context.Context, loadSpec SnapshotLoadSpec, us
|
||||||
HostDevName: network.TapName,
|
HostDevName: network.TapName,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
VsockOverride: vsockOverride,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
cleanup(network, paths, command, firecrackerPID)
|
cleanup(network, paths, command, firecrackerPID)
|
||||||
return nil, fmt.Errorf("load snapshot: %w", err)
|
return nil, fmt.Errorf("load snapshot: %w", err)
|
||||||
|
|
@ -369,6 +375,11 @@ func (r *Runtime) RestoreBoot(ctx context.Context, loadSpec SnapshotLoadSpec, us
|
||||||
return &state, nil
|
return &state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Runtime) PutMMDS(ctx context.Context, state MachineState, data any) error {
|
||||||
|
client := newAPIClient(state.SocketPath)
|
||||||
|
return client.PutMMDS(ctx, data)
|
||||||
|
}
|
||||||
|
|
||||||
func processExists(pid int) bool {
|
func processExists(pid int) bool {
|
||||||
if pid < 1 {
|
if pid < 1 {
|
||||||
return false
|
return false
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue