mirror of
https://github.com/getcompanion-ai/computer-host.git
synced 2026-04-15 05:02:05 +00:00
fix: address gateway review findings
This commit is contained in:
parent
59d3290bb9
commit
500354cd9b
14 changed files with 441 additions and 66 deletions
|
|
@ -3,21 +3,23 @@ package host
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
type Machine struct {
|
type Machine struct {
|
||||||
ID MachineID `json:"id"`
|
ID MachineID `json:"id"`
|
||||||
Artifact ArtifactRef `json:"artifact"`
|
Artifact ArtifactRef `json:"artifact"`
|
||||||
SystemVolumeID VolumeID `json:"system_volume_id,omitempty"`
|
SystemVolumeID VolumeID `json:"system_volume_id,omitempty"`
|
||||||
UserVolumeIDs []VolumeID `json:"user_volume_ids,omitempty"`
|
UserVolumeIDs []VolumeID `json:"user_volume_ids,omitempty"`
|
||||||
RuntimeHost string `json:"runtime_host,omitempty"`
|
RuntimeHost string `json:"runtime_host,omitempty"`
|
||||||
Ports []MachinePort `json:"ports,omitempty"`
|
Ports []MachinePort `json:"ports,omitempty"`
|
||||||
Phase MachinePhase `json:"phase"`
|
GuestSSHPublicKey string `json:"guest_ssh_host_public_key,omitempty"`
|
||||||
Error string `json:"error,omitempty"`
|
Phase MachinePhase `json:"phase"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
Error string `json:"error,omitempty"`
|
||||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GuestConfig struct {
|
type GuestConfig struct {
|
||||||
AuthorizedKeys []string `json:"authorized_keys,omitempty"`
|
AuthorizedKeys []string `json:"authorized_keys,omitempty"`
|
||||||
LoginWebhook *GuestLoginWebhook `json:"login_webhook,omitempty"`
|
TrustedUserCAKeys []string `json:"trusted_user_ca_keys,omitempty"`
|
||||||
|
LoginWebhook *GuestLoginWebhook `json:"login_webhook,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GuestLoginWebhook struct {
|
type GuestLoginWebhook struct {
|
||||||
|
|
|
||||||
|
|
@ -16,5 +16,6 @@ const (
|
||||||
type MachinePort struct {
|
type MachinePort struct {
|
||||||
Name MachinePortName `json:"name"`
|
Name MachinePortName `json:"name"`
|
||||||
Port uint16 `json:"port"`
|
Port uint16 `json:"port"`
|
||||||
|
HostPort uint16 `json:"host_port,omitempty"`
|
||||||
Protocol PortProtocol `json:"protocol"`
|
Protocol PortProtocol `json:"protocol"`
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ type Config struct {
|
||||||
EgressInterface string
|
EgressInterface string
|
||||||
FirecrackerBinaryPath string
|
FirecrackerBinaryPath string
|
||||||
JailerBinaryPath string
|
JailerBinaryPath string
|
||||||
|
GuestLoginCAPublicKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load loads and validates the firecracker-host daemon configuration from the environment.
|
// Load loads and validates the firecracker-host daemon configuration from the environment.
|
||||||
|
|
@ -43,6 +44,7 @@ func Load() (Config, error) {
|
||||||
EgressInterface: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_EGRESS_INTERFACE")),
|
EgressInterface: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_EGRESS_INTERFACE")),
|
||||||
FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")),
|
FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")),
|
||||||
JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")),
|
JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")),
|
||||||
|
GuestLoginCAPublicKey: strings.TrimSpace(os.Getenv("GUEST_LOGIN_CA_PUBLIC_KEY")),
|
||||||
}
|
}
|
||||||
if err := cfg.Validate(); err != nil {
|
if err := cfg.Validate(); err != nil {
|
||||||
return Config{}, err
|
return Config{}, err
|
||||||
|
|
|
||||||
|
|
@ -103,6 +103,11 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
||||||
_ = d.runtime.Delete(context.Background(), *state)
|
_ = d.runtime.Delete(context.Background(), *state)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
|
||||||
|
if err != nil {
|
||||||
|
_ = d.runtime.Delete(context.Background(), *state)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
systemVolumeRecord := model.VolumeRecord{
|
systemVolumeRecord := model.VolumeRecord{
|
||||||
|
|
@ -138,19 +143,44 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
||||||
}
|
}
|
||||||
|
|
||||||
record := model.MachineRecord{
|
record := model.MachineRecord{
|
||||||
ID: req.MachineID,
|
ID: req.MachineID,
|
||||||
Artifact: req.Artifact,
|
Artifact: req.Artifact,
|
||||||
SystemVolumeID: systemVolumeRecord.ID,
|
SystemVolumeID: systemVolumeRecord.ID,
|
||||||
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
|
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
|
||||||
RuntimeHost: state.RuntimeHost,
|
RuntimeHost: state.RuntimeHost,
|
||||||
TapDevice: state.TapName,
|
TapDevice: state.TapName,
|
||||||
Ports: ports,
|
Ports: ports,
|
||||||
Phase: contracthost.MachinePhaseRunning,
|
GuestSSHPublicKey: guestSSHPublicKey,
|
||||||
PID: state.PID,
|
Phase: contracthost.MachinePhaseRunning,
|
||||||
SocketPath: state.SocketPath,
|
PID: state.PID,
|
||||||
CreatedAt: now,
|
SocketPath: state.SocketPath,
|
||||||
StartedAt: state.StartedAt,
|
CreatedAt: now,
|
||||||
|
StartedAt: state.StartedAt,
|
||||||
}
|
}
|
||||||
|
d.relayAllocMu.Lock()
|
||||||
|
sshRelayPort, err := d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameSSH, record.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort)
|
||||||
|
var vncRelayPort uint16
|
||||||
|
if err == nil {
|
||||||
|
vncRelayPort, err = d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameVNC, record.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort)
|
||||||
|
}
|
||||||
|
d.relayAllocMu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
d.stopMachineRelays(record.ID)
|
||||||
|
for _, volume := range userVolumes {
|
||||||
|
volume.AttachedMachineID = nil
|
||||||
|
_ = d.store.UpdateVolume(context.Background(), volume)
|
||||||
|
}
|
||||||
|
_ = d.store.DeleteVolume(context.Background(), systemVolumeRecord.ID)
|
||||||
|
_ = d.runtime.Delete(context.Background(), *state)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
|
||||||
|
startedRelays := true
|
||||||
|
defer func() {
|
||||||
|
if startedRelays {
|
||||||
|
d.stopMachineRelays(record.ID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
if err := d.store.CreateMachine(ctx, record); err != nil {
|
if err := d.store.CreateMachine(ctx, record); err != nil {
|
||||||
for _, volume := range userVolumes {
|
for _, volume := range userVolumes {
|
||||||
volume.AttachedMachineID = nil
|
volume.AttachedMachineID = nil
|
||||||
|
|
@ -162,6 +192,7 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
||||||
}
|
}
|
||||||
|
|
||||||
removeSystemVolumeOnFailure = false
|
removeSystemVolumeOnFailure = false
|
||||||
|
startedRelays = false
|
||||||
clearOperation = true
|
clearOperation = true
|
||||||
return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil
|
return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -42,11 +42,15 @@ type Daemon struct {
|
||||||
runtime Runtime
|
runtime Runtime
|
||||||
|
|
||||||
reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID) error
|
reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID) error
|
||||||
|
readGuestSSHPublicKey func(context.Context, string) (string, error)
|
||||||
|
|
||||||
locksMu sync.Mutex
|
locksMu sync.Mutex
|
||||||
machineLocks map[contracthost.MachineID]*sync.Mutex
|
machineLocks map[contracthost.MachineID]*sync.Mutex
|
||||||
artifactLocks map[string]*sync.Mutex
|
artifactLocks map[string]*sync.Mutex
|
||||||
|
|
||||||
|
relayAllocMu sync.Mutex
|
||||||
|
machineRelaysMu sync.Mutex
|
||||||
|
machineRelayListeners map[string]net.Listener
|
||||||
publishedPortAllocMu sync.Mutex
|
publishedPortAllocMu sync.Mutex
|
||||||
publishedPortsMu sync.Mutex
|
publishedPortsMu sync.Mutex
|
||||||
publishedPortListeners map[contracthost.PublishedPortID]net.Listener
|
publishedPortListeners map[contracthost.PublishedPortID]net.Listener
|
||||||
|
|
@ -72,11 +76,14 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
|
||||||
store: store,
|
store: store,
|
||||||
runtime: runtime,
|
runtime: runtime,
|
||||||
reconfigureGuestIdentity: nil,
|
reconfigureGuestIdentity: nil,
|
||||||
|
readGuestSSHPublicKey: nil,
|
||||||
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
|
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
|
||||||
artifactLocks: make(map[string]*sync.Mutex),
|
artifactLocks: make(map[string]*sync.Mutex),
|
||||||
|
machineRelayListeners: make(map[string]net.Listener),
|
||||||
publishedPortListeners: make(map[contracthost.PublishedPortID]net.Listener),
|
publishedPortListeners: make(map[contracthost.PublishedPortID]net.Listener),
|
||||||
}
|
}
|
||||||
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
|
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
|
||||||
|
daemon.readGuestSSHPublicKey = readGuestSSHPublicKey
|
||||||
if err := daemon.ensureBackendSSHKeyPair(); err != nil {
|
if err := daemon.ensureBackendSSHKeyPair(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -98,6 +98,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
|
|
||||||
kernelPayload := []byte("kernel-image")
|
kernelPayload := []byte("kernel-image")
|
||||||
rootFSImagePath := filepath.Join(root, "guest-rootfs.ext4")
|
rootFSImagePath := filepath.Join(root, "guest-rootfs.ext4")
|
||||||
|
|
@ -261,6 +262,7 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID) error { return nil }
|
hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID) error { return nil }
|
||||||
|
|
||||||
artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
|
artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
|
||||||
|
|
@ -367,6 +369,7 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
var reconfiguredHost string
|
var reconfiguredHost string
|
||||||
var reconfiguredMachine contracthost.MachineID
|
var reconfiguredMachine contracthost.MachineID
|
||||||
hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID) error {
|
hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID) error {
|
||||||
|
|
@ -508,6 +511,7 @@ func TestDeleteMachineMissingIsNoOp(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
|
|
||||||
if err := hostDaemon.DeleteMachine(context.Background(), "missing"); err != nil {
|
if err := hostDaemon.DeleteMachine(context.Background(), "missing"); err != nil {
|
||||||
t.Fatalf("delete missing machine: %v", err)
|
t.Fatalf("delete missing machine: %v", err)
|
||||||
|
|
@ -533,6 +537,12 @@ func testConfig(root string) appconfig.Config {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func stubGuestSSHPublicKeyReader(hostDaemon *Daemon) {
|
||||||
|
hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) {
|
||||||
|
return "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3", nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func listenTestPort(t *testing.T, port int) net.Listener {
|
func listenTestPort(t *testing.T, port int) net.Listener {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -187,9 +187,13 @@ func isZeroChunk(chunk []byte) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultMachinePorts() []contracthost.MachinePort {
|
func defaultMachinePorts() []contracthost.MachinePort {
|
||||||
|
return buildMachinePorts(0, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMachinePorts(sshRelayPort, vncRelayPort uint16) []contracthost.MachinePort {
|
||||||
return []contracthost.MachinePort{
|
return []contracthost.MachinePort{
|
||||||
{Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, Protocol: contracthost.PortProtocolTCP},
|
{Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, HostPort: sshRelayPort, Protocol: contracthost.PortProtocolTCP},
|
||||||
{Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, Protocol: contracthost.PortProtocolTCP},
|
{Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, HostPort: vncRelayPort, Protocol: contracthost.PortProtocolTCP},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -247,7 +251,14 @@ func (d *Daemon) mergedGuestConfig(config *contracthost.GuestConfig) (*contracth
|
||||||
}
|
}
|
||||||
|
|
||||||
merged := &contracthost.GuestConfig{
|
merged := &contracthost.GuestConfig{
|
||||||
AuthorizedKeys: authorizedKeys,
|
AuthorizedKeys: authorizedKeys,
|
||||||
|
TrustedUserCAKeys: nil,
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(d.config.GuestLoginCAPublicKey) != "" {
|
||||||
|
merged.TrustedUserCAKeys = append(merged.TrustedUserCAKeys, d.config.GuestLoginCAPublicKey)
|
||||||
|
}
|
||||||
|
if config != nil {
|
||||||
|
merged.TrustedUserCAKeys = append(merged.TrustedUserCAKeys, config.TrustedUserCAKeys...)
|
||||||
}
|
}
|
||||||
if config != nil && config.LoginWebhook != nil {
|
if config != nil && config.LoginWebhook != nil {
|
||||||
loginWebhook := *config.LoginWebhook
|
loginWebhook := *config.LoginWebhook
|
||||||
|
|
@ -260,7 +271,7 @@ func hasGuestConfig(config *contracthost.GuestConfig) bool {
|
||||||
if config == nil {
|
if config == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return len(config.AuthorizedKeys) > 0 || config.LoginWebhook != nil
|
return len(config.AuthorizedKeys) > 0 || len(config.TrustedUserCAKeys) > 0 || config.LoginWebhook != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func injectGuestConfig(ctx context.Context, imagePath string, config *contracthost.GuestConfig) error {
|
func injectGuestConfig(ctx context.Context, imagePath string, config *contracthost.GuestConfig) error {
|
||||||
|
|
@ -286,6 +297,17 @@ func injectGuestConfig(ctx context.Context, imagePath string, config *contractho
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(config.TrustedUserCAKeys) > 0 {
|
||||||
|
trustedCAPath := filepath.Join(stagingDir, "trusted_user_ca_keys")
|
||||||
|
payload := []byte(strings.Join(config.TrustedUserCAKeys, "\n") + "\n")
|
||||||
|
if err := os.WriteFile(trustedCAPath, payload, 0o644); err != nil {
|
||||||
|
return fmt.Errorf("write trusted_user_ca_keys staging file: %w", err)
|
||||||
|
}
|
||||||
|
if err := replaceExt4File(ctx, imagePath, trustedCAPath, "/etc/microagent/trusted_user_ca_keys"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if config.LoginWebhook != nil {
|
if config.LoginWebhook != nil {
|
||||||
guestConfigPath := filepath.Join(stagingDir, "guest-config.json")
|
guestConfigPath := filepath.Join(stagingDir, "guest-config.json")
|
||||||
payload, err := json.Marshal(config)
|
payload, err := json.Marshal(config)
|
||||||
|
|
@ -363,16 +385,17 @@ func machineIDPtr(machineID contracthost.MachineID) *contracthost.MachineID {
|
||||||
|
|
||||||
func machineToContract(record model.MachineRecord) contracthost.Machine {
|
func machineToContract(record model.MachineRecord) contracthost.Machine {
|
||||||
return contracthost.Machine{
|
return contracthost.Machine{
|
||||||
ID: record.ID,
|
ID: record.ID,
|
||||||
Artifact: record.Artifact,
|
Artifact: record.Artifact,
|
||||||
SystemVolumeID: record.SystemVolumeID,
|
SystemVolumeID: record.SystemVolumeID,
|
||||||
UserVolumeIDs: append([]contracthost.VolumeID(nil), record.UserVolumeIDs...),
|
UserVolumeIDs: append([]contracthost.VolumeID(nil), record.UserVolumeIDs...),
|
||||||
RuntimeHost: record.RuntimeHost,
|
RuntimeHost: record.RuntimeHost,
|
||||||
Ports: append([]contracthost.MachinePort(nil), record.Ports...),
|
Ports: append([]contracthost.MachinePort(nil), record.Ports...),
|
||||||
Phase: record.Phase,
|
GuestSSHPublicKey: record.GuestSSHPublicKey,
|
||||||
Error: record.Error,
|
Phase: record.Phase,
|
||||||
CreatedAt: record.CreatedAt,
|
Error: record.Error,
|
||||||
StartedAt: record.StartedAt,
|
CreatedAt: record.CreatedAt,
|
||||||
|
StartedAt: record.StartedAt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -427,6 +450,11 @@ func validateGuestConfig(config *contracthost.GuestConfig) error {
|
||||||
return fmt.Errorf("guest_config.authorized_keys[%d] is required", i)
|
return fmt.Errorf("guest_config.authorized_keys[%d] is required", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for i, key := range config.TrustedUserCAKeys {
|
||||||
|
if strings.TrimSpace(key) == "" {
|
||||||
|
return fmt.Errorf("guest_config.trusted_user_ca_keys[%d] is required", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
if config.LoginWebhook != nil {
|
if config.LoginWebhook != nil {
|
||||||
if err := validateDownloadURL("guest_config.login_webhook.url", config.LoginWebhook.URL); err != nil {
|
if err := validateDownloadURL("guest_config.login_webhook.url", config.LoginWebhook.URL); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
51
internal/daemon/guest_hostkey.go
Normal file
51
internal/daemon/guest_hostkey.go
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
package daemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func readGuestSSHPublicKey(ctx context.Context, runtimeHost string) (string, error) {
|
||||||
|
host := strings.TrimSpace(runtimeHost)
|
||||||
|
if host == "" {
|
||||||
|
return "", fmt.Errorf("runtime host is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
targetAddr := net.JoinHostPort(host, strconv.Itoa(int(defaultSSHPort)))
|
||||||
|
netConn, err := (&net.Dialer{}).DialContext(probeCtx, "tcp", targetAddr)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("dial guest ssh for host key: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = netConn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var captured ssh.PublicKey
|
||||||
|
clientConfig := &ssh.ClientConfig{
|
||||||
|
User: "host-key-probe",
|
||||||
|
Auth: []ssh.AuthMethod{ssh.Password("invalid")},
|
||||||
|
HostKeyAlgorithms: []string{ssh.KeyAlgoED25519},
|
||||||
|
HostKeyCallback: func(_ string, _ net.Addr, key ssh.PublicKey) error {
|
||||||
|
captured = key
|
||||||
|
return fmt.Errorf("guest ssh host key captured")
|
||||||
|
},
|
||||||
|
Timeout: defaultGuestDialTimeout,
|
||||||
|
ClientVersion: "SSH-2.0-agentcomputer-firecracker-host",
|
||||||
|
}
|
||||||
|
_, _, _, err = ssh.NewClientConn(netConn, targetAddr, clientConfig)
|
||||||
|
if captured == nil {
|
||||||
|
if err == nil {
|
||||||
|
return "", fmt.Errorf("guest ssh host key probe returned without a host key")
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("handshake guest ssh for host key: %w", err)
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(string(ssh.MarshalAuthorizedKey(captured))), nil
|
||||||
|
}
|
||||||
|
|
@ -99,10 +99,16 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
||||||
_ = d.runtime.Delete(context.Background(), *state)
|
_ = d.runtime.Delete(context.Background(), *state)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
|
||||||
|
if err != nil {
|
||||||
|
_ = d.runtime.Delete(context.Background(), *state)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
record.RuntimeHost = state.RuntimeHost
|
record.RuntimeHost = state.RuntimeHost
|
||||||
record.TapDevice = state.TapName
|
record.TapDevice = state.TapName
|
||||||
record.Ports = ports
|
record.Ports = ports
|
||||||
|
record.GuestSSHPublicKey = guestSSHPublicKey
|
||||||
record.Phase = contracthost.MachinePhaseRunning
|
record.Phase = contracthost.MachinePhaseRunning
|
||||||
record.Error = ""
|
record.Error = ""
|
||||||
record.PID = state.PID
|
record.PID = state.PID
|
||||||
|
|
@ -112,7 +118,13 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
||||||
_ = d.runtime.Delete(context.Background(), *state)
|
_ = d.runtime.Delete(context.Background(), *state)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||||
|
d.stopMachineRelays(id)
|
||||||
|
_ = d.runtime.Delete(context.Background(), *state)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
||||||
|
d.stopMachineRelays(id)
|
||||||
d.stopPublishedPortsForMachine(id)
|
d.stopPublishedPortsForMachine(id)
|
||||||
_ = d.runtime.Delete(context.Background(), *state)
|
_ = d.runtime.Delete(context.Background(), *state)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -238,10 +250,14 @@ func (d *Daemon) Reconcile(ctx context.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if reconciled.Phase == contracthost.MachinePhaseRunning {
|
if reconciled.Phase == contracthost.MachinePhaseRunning {
|
||||||
|
if err := d.ensureMachineRelays(ctx, reconciled); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if err := d.ensurePublishedPortsForMachine(ctx, *reconciled); err != nil {
|
if err := d.ensurePublishedPortsForMachine(ctx, *reconciled); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
d.stopMachineRelays(reconciled.ID)
|
||||||
d.stopPublishedPortsForMachine(reconciled.ID)
|
d.stopPublishedPortsForMachine(reconciled.ID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -322,6 +338,9 @@ func (d *Daemon) reconcileStart(ctx context.Context, machineID contracthost.Mach
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if record.Phase == contracthost.MachinePhaseRunning {
|
if record.Phase == contracthost.MachinePhaseRunning {
|
||||||
|
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -375,12 +394,16 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if state.Phase == firecracker.PhaseRunning {
|
if state.Phase == firecracker.PhaseRunning {
|
||||||
|
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return record, nil
|
return record, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.runtime.Delete(ctx, *state); err != nil {
|
if err := d.runtime.Delete(ctx, *state); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
d.stopMachineRelays(record.ID)
|
||||||
d.stopPublishedPortsForMachine(record.ID)
|
d.stopPublishedPortsForMachine(record.ID)
|
||||||
record.Phase = contracthost.MachinePhaseFailed
|
record.Phase = contracthost.MachinePhaseFailed
|
||||||
record.Error = state.Error
|
record.Error = state.Error
|
||||||
|
|
@ -396,6 +419,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error {
|
func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error {
|
||||||
|
d.stopMachineRelays(record.ID)
|
||||||
d.stopPublishedPortsForMachine(record.ID)
|
d.stopPublishedPortsForMachine(record.ID)
|
||||||
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
|
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -426,6 +450,7 @@ func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineR
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error {
|
func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error {
|
||||||
|
d.stopMachineRelays(record.ID)
|
||||||
d.stopPublishedPortsForMachine(record.ID)
|
d.stopPublishedPortsForMachine(record.ID)
|
||||||
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
|
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
184
internal/daemon/machine_relays.go
Normal file
184
internal/daemon/machine_relays.go
Normal file
|
|
@ -0,0 +1,184 @@
|
||||||
|
package daemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
minMachineSSHRelayPort = uint16(40000)
|
||||||
|
maxMachineSSHRelayPort = uint16(44999)
|
||||||
|
minMachineVNCRelayPort = uint16(45000)
|
||||||
|
maxMachineVNCRelayPort = uint16(49999)
|
||||||
|
)
|
||||||
|
|
||||||
|
func machineRelayListenerKey(machineID contracthost.MachineID, name contracthost.MachinePortName) string {
|
||||||
|
return string(machineID) + ":" + string(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func machineRelayHostPort(record model.MachineRecord, name contracthost.MachinePortName) uint16 {
|
||||||
|
for _, port := range record.Ports {
|
||||||
|
if port.Name == name {
|
||||||
|
return port.HostPort
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func machineRelayGuestPort(record model.MachineRecord, name contracthost.MachinePortName) uint16 {
|
||||||
|
for _, port := range record.Ports {
|
||||||
|
if port.Name == name {
|
||||||
|
return port.Port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch name {
|
||||||
|
case contracthost.MachinePortNameVNC:
|
||||||
|
return defaultVNCPort
|
||||||
|
default:
|
||||||
|
return defaultSSHPort
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) usedMachineRelayPorts(ctx context.Context, machineID contracthost.MachineID, name contracthost.MachinePortName) (map[uint16]struct{}, error) {
|
||||||
|
records, err := d.store.ListMachines(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
used := make(map[uint16]struct{}, len(records))
|
||||||
|
for _, record := range records {
|
||||||
|
if record.ID == machineID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if port := machineRelayHostPort(record, name); port != 0 {
|
||||||
|
used[port] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return used, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) allocateMachineRelayProxy(
|
||||||
|
ctx context.Context,
|
||||||
|
current model.MachineRecord,
|
||||||
|
name contracthost.MachinePortName,
|
||||||
|
runtimeHost string,
|
||||||
|
guestPort uint16,
|
||||||
|
minPort uint16,
|
||||||
|
maxPort uint16,
|
||||||
|
) (uint16, error) {
|
||||||
|
existingPort := machineRelayHostPort(current, name)
|
||||||
|
if existingPort != 0 {
|
||||||
|
if err := d.startMachineRelayProxy(current.ID, name, existingPort, runtimeHost, guestPort); err == nil {
|
||||||
|
return existingPort, nil
|
||||||
|
} else if !isAddrInUseError(err) {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
used, err := d.usedMachineRelayPorts(ctx, current.ID, name)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if existingPort != 0 {
|
||||||
|
used[existingPort] = struct{}{}
|
||||||
|
}
|
||||||
|
for hostPort := minPort; hostPort <= maxPort; hostPort++ {
|
||||||
|
if _, exists := used[hostPort]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := d.startMachineRelayProxy(current.ID, name, hostPort, runtimeHost, guestPort); err != nil {
|
||||||
|
if isAddrInUseError(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return hostPort, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("no relay ports are available in range %d-%d", minPort, maxPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) ensureMachineRelays(ctx context.Context, record *model.MachineRecord) error {
|
||||||
|
if record == nil {
|
||||||
|
return fmt.Errorf("machine record is required")
|
||||||
|
}
|
||||||
|
if record.Phase != contracthost.MachinePhaseRunning || strings.TrimSpace(record.RuntimeHost) == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d.relayAllocMu.Lock()
|
||||||
|
sshRelayPort, err := d.allocateMachineRelayProxy(ctx, *record, contracthost.MachinePortNameSSH, record.RuntimeHost, machineRelayGuestPort(*record, contracthost.MachinePortNameSSH), minMachineSSHRelayPort, maxMachineSSHRelayPort)
|
||||||
|
var vncRelayPort uint16
|
||||||
|
if err == nil {
|
||||||
|
vncRelayPort, err = d.allocateMachineRelayProxy(ctx, *record, contracthost.MachinePortNameVNC, record.RuntimeHost, machineRelayGuestPort(*record, contracthost.MachinePortNameVNC), minMachineVNCRelayPort, maxMachineVNCRelayPort)
|
||||||
|
}
|
||||||
|
d.relayAllocMu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
d.stopMachineRelays(record.ID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
|
||||||
|
if err := d.store.UpdateMachine(ctx, *record); err != nil {
|
||||||
|
d.stopMachineRelays(record.ID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) startMachineRelayProxy(machineID contracthost.MachineID, name contracthost.MachinePortName, hostPort uint16, runtimeHost string, guestPort uint16) error {
|
||||||
|
targetHost := strings.TrimSpace(runtimeHost)
|
||||||
|
if targetHost == "" {
|
||||||
|
return fmt.Errorf("runtime host is required for machine relay %q", machineID)
|
||||||
|
}
|
||||||
|
|
||||||
|
key := machineRelayListenerKey(machineID, name)
|
||||||
|
|
||||||
|
d.machineRelaysMu.Lock()
|
||||||
|
if _, exists := d.machineRelayListeners[key]; exists {
|
||||||
|
d.machineRelaysMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
listener, err := net.Listen("tcp", ":"+strconv.Itoa(int(hostPort)))
|
||||||
|
if err != nil {
|
||||||
|
d.machineRelaysMu.Unlock()
|
||||||
|
return fmt.Errorf("listen on machine relay port %d: %w", hostPort, err)
|
||||||
|
}
|
||||||
|
d.machineRelayListeners[key] = listener
|
||||||
|
d.machineRelaysMu.Unlock()
|
||||||
|
|
||||||
|
targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(int(guestPort)))
|
||||||
|
go serveTCPProxy(listener, targetAddr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) stopMachineRelayProxy(machineID contracthost.MachineID, name contracthost.MachinePortName) {
|
||||||
|
key := machineRelayListenerKey(machineID, name)
|
||||||
|
|
||||||
|
d.machineRelaysMu.Lock()
|
||||||
|
listener, ok := d.machineRelayListeners[key]
|
||||||
|
if ok {
|
||||||
|
delete(d.machineRelayListeners, key)
|
||||||
|
}
|
||||||
|
d.machineRelaysMu.Unlock()
|
||||||
|
if ok {
|
||||||
|
_ = listener.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) stopMachineRelays(machineID contracthost.MachineID) {
|
||||||
|
d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameSSH)
|
||||||
|
d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameVNC)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isAddrInUseError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(strings.ToLower(err.Error()), "address already in use")
|
||||||
|
}
|
||||||
|
|
@ -177,11 +177,11 @@ func (d *Daemon) startPublishedPortProxy(port model.PublishedPortRecord, runtime
|
||||||
d.publishedPortsMu.Unlock()
|
d.publishedPortsMu.Unlock()
|
||||||
|
|
||||||
targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(int(port.Port)))
|
targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(int(port.Port)))
|
||||||
go d.servePublishedPortProxy(port.ID, listener, targetAddr)
|
go serveTCPProxy(listener, targetAddr)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) servePublishedPortProxy(portID contracthost.PublishedPortID, listener net.Listener, targetAddr string) {
|
func serveTCPProxy(listener net.Listener, targetAddr string) {
|
||||||
for {
|
for {
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -190,11 +190,11 @@ func (d *Daemon) servePublishedPortProxy(portID contracthost.PublishedPortID, li
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go proxyPublishedPortConnection(conn, targetAddr)
|
go proxyTCPConnection(conn, targetAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func proxyPublishedPortConnection(source net.Conn, targetAddr string) {
|
func proxyTCPConnection(source net.Conn, targetAddr string) {
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = source.Close()
|
_ = source.Close()
|
||||||
}()
|
}()
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,7 @@ func TestCreatePublishedPortSerializesHostPortAllocationAcrossMachines(t *testin
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
|
|
||||||
for _, machineID := range []contracthost.MachineID{"vm-1", "vm-2"} {
|
for _, machineID := range []contracthost.MachineID{"vm-1", "vm-2"} {
|
||||||
if err := baseStore.CreateMachine(context.Background(), model.MachineRecord{
|
if err := baseStore.CreateMachine(context.Background(), model.MachineRecord{
|
||||||
|
|
@ -170,6 +171,7 @@ func TestGetStorageReportHandlesSparseSnapshotPathsAndIncludesPublishedPortPool(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
|
|
||||||
response, err := hostDaemon.GetStorageReport(context.Background())
|
response, err := hostDaemon.GetStorageReport(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -209,6 +211,7 @@ func TestReconcileSnapshotPreservesArtifactsOnUnexpectedStoreError(t *testing.T)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
|
|
||||||
snapshotID := contracthost.SnapshotID("snap-1")
|
snapshotID := contracthost.SnapshotID("snap-1")
|
||||||
operation := model.OperationRecord{
|
operation := model.OperationRecord{
|
||||||
|
|
@ -249,6 +252,7 @@ func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
|
|
||||||
operation := model.OperationRecord{
|
operation := model.OperationRecord{
|
||||||
MachineID: "vm-1",
|
MachineID: "vm-1",
|
||||||
|
|
@ -291,6 +295,7 @@ func TestCreateSnapshotRejectsDuplicateSnapshotIDWithoutTouchingExistingArtifact
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
|
|
||||||
machineID := contracthost.MachineID("vm-1")
|
machineID := contracthost.MachineID("vm-1")
|
||||||
snapshotID := contracthost.SnapshotID("snap-1")
|
snapshotID := contracthost.SnapshotID("snap-1")
|
||||||
|
|
@ -346,6 +351,7 @@ func TestReconcileUsesReconciledMachineStateForPublishedPorts(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create daemon: %v", err)
|
t.Fatalf("create daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
hostDaemon.stopPublishedPortProxy("port-1")
|
hostDaemon.stopPublishedPortProxy("port-1")
|
||||||
|
|
|
||||||
|
|
@ -260,6 +260,13 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
||||||
clearOperation = true
|
clearOperation = true
|
||||||
return nil, fmt.Errorf("reconfigure restored guest identity: %w", err)
|
return nil, fmt.Errorf("reconfigure restored guest identity: %w", err)
|
||||||
}
|
}
|
||||||
|
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, machineState.RuntimeHost)
|
||||||
|
if err != nil {
|
||||||
|
_ = d.runtime.Delete(ctx, *machineState)
|
||||||
|
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
|
||||||
|
clearOperation = true
|
||||||
|
return nil, fmt.Errorf("read restored guest ssh host key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
systemVolumeID := d.systemVolumeID(req.MachineID)
|
systemVolumeID := d.systemVolumeID(req.MachineID)
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
@ -277,22 +284,42 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
||||||
}
|
}
|
||||||
|
|
||||||
machineRecord := model.MachineRecord{
|
machineRecord := model.MachineRecord{
|
||||||
ID: req.MachineID,
|
ID: req.MachineID,
|
||||||
Artifact: snap.Artifact,
|
Artifact: snap.Artifact,
|
||||||
SystemVolumeID: systemVolumeID,
|
SystemVolumeID: systemVolumeID,
|
||||||
RuntimeHost: machineState.RuntimeHost,
|
RuntimeHost: machineState.RuntimeHost,
|
||||||
TapDevice: machineState.TapName,
|
TapDevice: machineState.TapName,
|
||||||
Ports: defaultMachinePorts(),
|
Ports: defaultMachinePorts(),
|
||||||
Phase: contracthost.MachinePhaseRunning,
|
GuestSSHPublicKey: guestSSHPublicKey,
|
||||||
PID: machineState.PID,
|
Phase: contracthost.MachinePhaseRunning,
|
||||||
SocketPath: machineState.SocketPath,
|
PID: machineState.PID,
|
||||||
CreatedAt: now,
|
SocketPath: machineState.SocketPath,
|
||||||
StartedAt: machineState.StartedAt,
|
CreatedAt: now,
|
||||||
|
StartedAt: machineState.StartedAt,
|
||||||
}
|
}
|
||||||
|
d.relayAllocMu.Lock()
|
||||||
|
sshRelayPort, err := d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameSSH, machineRecord.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort)
|
||||||
|
var vncRelayPort uint16
|
||||||
|
if err == nil {
|
||||||
|
vncRelayPort, err = d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameVNC, machineRecord.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort)
|
||||||
|
}
|
||||||
|
d.relayAllocMu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
d.stopMachineRelays(machineRecord.ID)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
machineRecord.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
|
||||||
|
startedRelays := true
|
||||||
|
defer func() {
|
||||||
|
if startedRelays {
|
||||||
|
d.stopMachineRelays(machineRecord.ID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
if err := d.store.CreateMachine(ctx, machineRecord); err != nil {
|
if err := d.store.CreateMachine(ctx, machineRecord); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
startedRelays = false
|
||||||
clearOperation = true
|
clearOperation = true
|
||||||
return &contracthost.RestoreSnapshotResponse{
|
return &contracthost.RestoreSnapshotResponse{
|
||||||
Machine: machineToContract(machineRecord),
|
Machine: machineToContract(machineRecord),
|
||||||
|
|
|
||||||
|
|
@ -27,19 +27,20 @@ type ArtifactRecord struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type MachineRecord struct {
|
type MachineRecord struct {
|
||||||
ID contracthost.MachineID
|
ID contracthost.MachineID
|
||||||
Artifact contracthost.ArtifactRef
|
Artifact contracthost.ArtifactRef
|
||||||
SystemVolumeID contracthost.VolumeID
|
SystemVolumeID contracthost.VolumeID
|
||||||
UserVolumeIDs []contracthost.VolumeID
|
UserVolumeIDs []contracthost.VolumeID
|
||||||
RuntimeHost string
|
RuntimeHost string
|
||||||
TapDevice string
|
TapDevice string
|
||||||
Ports []contracthost.MachinePort
|
Ports []contracthost.MachinePort
|
||||||
Phase contracthost.MachinePhase
|
GuestSSHPublicKey string
|
||||||
Error string
|
Phase contracthost.MachinePhase
|
||||||
PID int
|
Error string
|
||||||
SocketPath string
|
PID int
|
||||||
CreatedAt time.Time
|
SocketPath string
|
||||||
StartedAt *time.Time
|
CreatedAt time.Time
|
||||||
|
StartedAt *time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type VolumeRecord struct {
|
type VolumeRecord struct {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue