mirror of
https://github.com/getcompanion-ai/computer-host.git
synced 2026-04-15 03:00:42 +00:00
fix: address gateway review findings
This commit is contained in:
parent
59d3290bb9
commit
500354cd9b
14 changed files with 441 additions and 66 deletions
|
|
@ -25,6 +25,7 @@ type Config struct {
|
|||
EgressInterface string
|
||||
FirecrackerBinaryPath string
|
||||
JailerBinaryPath string
|
||||
GuestLoginCAPublicKey string
|
||||
}
|
||||
|
||||
// Load loads and validates the firecracker-host daemon configuration from the environment.
|
||||
|
|
@ -43,6 +44,7 @@ func Load() (Config, error) {
|
|||
EgressInterface: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_EGRESS_INTERFACE")),
|
||||
FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")),
|
||||
JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")),
|
||||
GuestLoginCAPublicKey: strings.TrimSpace(os.Getenv("GUEST_LOGIN_CA_PUBLIC_KEY")),
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return Config{}, err
|
||||
|
|
|
|||
|
|
@ -103,6 +103,11 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
|
||||
if err != nil {
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
systemVolumeRecord := model.VolumeRecord{
|
||||
|
|
@ -138,19 +143,44 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
}
|
||||
|
||||
record := model.MachineRecord{
|
||||
ID: req.MachineID,
|
||||
Artifact: req.Artifact,
|
||||
SystemVolumeID: systemVolumeRecord.ID,
|
||||
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
|
||||
RuntimeHost: state.RuntimeHost,
|
||||
TapDevice: state.TapName,
|
||||
Ports: ports,
|
||||
Phase: contracthost.MachinePhaseRunning,
|
||||
PID: state.PID,
|
||||
SocketPath: state.SocketPath,
|
||||
CreatedAt: now,
|
||||
StartedAt: state.StartedAt,
|
||||
ID: req.MachineID,
|
||||
Artifact: req.Artifact,
|
||||
SystemVolumeID: systemVolumeRecord.ID,
|
||||
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
|
||||
RuntimeHost: state.RuntimeHost,
|
||||
TapDevice: state.TapName,
|
||||
Ports: ports,
|
||||
GuestSSHPublicKey: guestSSHPublicKey,
|
||||
Phase: contracthost.MachinePhaseRunning,
|
||||
PID: state.PID,
|
||||
SocketPath: state.SocketPath,
|
||||
CreatedAt: now,
|
||||
StartedAt: state.StartedAt,
|
||||
}
|
||||
d.relayAllocMu.Lock()
|
||||
sshRelayPort, err := d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameSSH, record.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort)
|
||||
var vncRelayPort uint16
|
||||
if err == nil {
|
||||
vncRelayPort, err = d.allocateMachineRelayProxy(ctx, record, contracthost.MachinePortNameVNC, record.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort)
|
||||
}
|
||||
d.relayAllocMu.Unlock()
|
||||
if err != nil {
|
||||
d.stopMachineRelays(record.ID)
|
||||
for _, volume := range userVolumes {
|
||||
volume.AttachedMachineID = nil
|
||||
_ = d.store.UpdateVolume(context.Background(), volume)
|
||||
}
|
||||
_ = d.store.DeleteVolume(context.Background(), systemVolumeRecord.ID)
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
|
||||
startedRelays := true
|
||||
defer func() {
|
||||
if startedRelays {
|
||||
d.stopMachineRelays(record.ID)
|
||||
}
|
||||
}()
|
||||
if err := d.store.CreateMachine(ctx, record); err != nil {
|
||||
for _, volume := range userVolumes {
|
||||
volume.AttachedMachineID = nil
|
||||
|
|
@ -162,6 +192,7 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
}
|
||||
|
||||
removeSystemVolumeOnFailure = false
|
||||
startedRelays = false
|
||||
clearOperation = true
|
||||
return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,11 +42,15 @@ type Daemon struct {
|
|||
runtime Runtime
|
||||
|
||||
reconfigureGuestIdentity func(context.Context, string, contracthost.MachineID) error
|
||||
readGuestSSHPublicKey func(context.Context, string) (string, error)
|
||||
|
||||
locksMu sync.Mutex
|
||||
machineLocks map[contracthost.MachineID]*sync.Mutex
|
||||
artifactLocks map[string]*sync.Mutex
|
||||
|
||||
relayAllocMu sync.Mutex
|
||||
machineRelaysMu sync.Mutex
|
||||
machineRelayListeners map[string]net.Listener
|
||||
publishedPortAllocMu sync.Mutex
|
||||
publishedPortsMu sync.Mutex
|
||||
publishedPortListeners map[contracthost.PublishedPortID]net.Listener
|
||||
|
|
@ -72,11 +76,14 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
|
|||
store: store,
|
||||
runtime: runtime,
|
||||
reconfigureGuestIdentity: nil,
|
||||
readGuestSSHPublicKey: nil,
|
||||
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
|
||||
artifactLocks: make(map[string]*sync.Mutex),
|
||||
machineRelayListeners: make(map[string]net.Listener),
|
||||
publishedPortListeners: make(map[contracthost.PublishedPortID]net.Listener),
|
||||
}
|
||||
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
|
||||
daemon.readGuestSSHPublicKey = readGuestSSHPublicKey
|
||||
if err := daemon.ensureBackendSSHKeyPair(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||
|
||||
kernelPayload := []byte("kernel-image")
|
||||
rootFSImagePath := filepath.Join(root, "guest-rootfs.ext4")
|
||||
|
|
@ -261,6 +262,7 @@ func TestRestoreSnapshotRejectsRunningSourceMachine(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||
hostDaemon.reconfigureGuestIdentity = func(context.Context, string, contracthost.MachineID) error { return nil }
|
||||
|
||||
artifactRef := contracthost.ArtifactRef{KernelImageURL: "kernel", RootFSURL: "rootfs"}
|
||||
|
|
@ -367,6 +369,7 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||
var reconfiguredHost string
|
||||
var reconfiguredMachine contracthost.MachineID
|
||||
hostDaemon.reconfigureGuestIdentity = func(_ context.Context, host string, machineID contracthost.MachineID) error {
|
||||
|
|
@ -508,6 +511,7 @@ func TestDeleteMachineMissingIsNoOp(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||
|
||||
if err := hostDaemon.DeleteMachine(context.Background(), "missing"); err != nil {
|
||||
t.Fatalf("delete missing machine: %v", err)
|
||||
|
|
@ -533,6 +537,12 @@ func testConfig(root string) appconfig.Config {
|
|||
}
|
||||
}
|
||||
|
||||
func stubGuestSSHPublicKeyReader(hostDaemon *Daemon) {
|
||||
hostDaemon.readGuestSSHPublicKey = func(context.Context, string) (string, error) {
|
||||
return "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO0j1AyW0mQm9a1G2rY0R4fP2G5+4Qx2V3FJ9P2mA6N3", nil
|
||||
}
|
||||
}
|
||||
|
||||
func listenTestPort(t *testing.T, port int) net.Listener {
|
||||
t.Helper()
|
||||
|
||||
|
|
|
|||
|
|
@ -187,9 +187,13 @@ func isZeroChunk(chunk []byte) bool {
|
|||
}
|
||||
|
||||
func defaultMachinePorts() []contracthost.MachinePort {
|
||||
return buildMachinePorts(0, 0)
|
||||
}
|
||||
|
||||
func buildMachinePorts(sshRelayPort, vncRelayPort uint16) []contracthost.MachinePort {
|
||||
return []contracthost.MachinePort{
|
||||
{Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, Protocol: contracthost.PortProtocolTCP},
|
||||
{Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, Protocol: contracthost.PortProtocolTCP},
|
||||
{Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, HostPort: sshRelayPort, Protocol: contracthost.PortProtocolTCP},
|
||||
{Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, HostPort: vncRelayPort, Protocol: contracthost.PortProtocolTCP},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -247,7 +251,14 @@ func (d *Daemon) mergedGuestConfig(config *contracthost.GuestConfig) (*contracth
|
|||
}
|
||||
|
||||
merged := &contracthost.GuestConfig{
|
||||
AuthorizedKeys: authorizedKeys,
|
||||
AuthorizedKeys: authorizedKeys,
|
||||
TrustedUserCAKeys: nil,
|
||||
}
|
||||
if strings.TrimSpace(d.config.GuestLoginCAPublicKey) != "" {
|
||||
merged.TrustedUserCAKeys = append(merged.TrustedUserCAKeys, d.config.GuestLoginCAPublicKey)
|
||||
}
|
||||
if config != nil {
|
||||
merged.TrustedUserCAKeys = append(merged.TrustedUserCAKeys, config.TrustedUserCAKeys...)
|
||||
}
|
||||
if config != nil && config.LoginWebhook != nil {
|
||||
loginWebhook := *config.LoginWebhook
|
||||
|
|
@ -260,7 +271,7 @@ func hasGuestConfig(config *contracthost.GuestConfig) bool {
|
|||
if config == nil {
|
||||
return false
|
||||
}
|
||||
return len(config.AuthorizedKeys) > 0 || config.LoginWebhook != nil
|
||||
return len(config.AuthorizedKeys) > 0 || len(config.TrustedUserCAKeys) > 0 || config.LoginWebhook != nil
|
||||
}
|
||||
|
||||
func injectGuestConfig(ctx context.Context, imagePath string, config *contracthost.GuestConfig) error {
|
||||
|
|
@ -286,6 +297,17 @@ func injectGuestConfig(ctx context.Context, imagePath string, config *contractho
|
|||
}
|
||||
}
|
||||
|
||||
if len(config.TrustedUserCAKeys) > 0 {
|
||||
trustedCAPath := filepath.Join(stagingDir, "trusted_user_ca_keys")
|
||||
payload := []byte(strings.Join(config.TrustedUserCAKeys, "\n") + "\n")
|
||||
if err := os.WriteFile(trustedCAPath, payload, 0o644); err != nil {
|
||||
return fmt.Errorf("write trusted_user_ca_keys staging file: %w", err)
|
||||
}
|
||||
if err := replaceExt4File(ctx, imagePath, trustedCAPath, "/etc/microagent/trusted_user_ca_keys"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if config.LoginWebhook != nil {
|
||||
guestConfigPath := filepath.Join(stagingDir, "guest-config.json")
|
||||
payload, err := json.Marshal(config)
|
||||
|
|
@ -363,16 +385,17 @@ func machineIDPtr(machineID contracthost.MachineID) *contracthost.MachineID {
|
|||
|
||||
func machineToContract(record model.MachineRecord) contracthost.Machine {
|
||||
return contracthost.Machine{
|
||||
ID: record.ID,
|
||||
Artifact: record.Artifact,
|
||||
SystemVolumeID: record.SystemVolumeID,
|
||||
UserVolumeIDs: append([]contracthost.VolumeID(nil), record.UserVolumeIDs...),
|
||||
RuntimeHost: record.RuntimeHost,
|
||||
Ports: append([]contracthost.MachinePort(nil), record.Ports...),
|
||||
Phase: record.Phase,
|
||||
Error: record.Error,
|
||||
CreatedAt: record.CreatedAt,
|
||||
StartedAt: record.StartedAt,
|
||||
ID: record.ID,
|
||||
Artifact: record.Artifact,
|
||||
SystemVolumeID: record.SystemVolumeID,
|
||||
UserVolumeIDs: append([]contracthost.VolumeID(nil), record.UserVolumeIDs...),
|
||||
RuntimeHost: record.RuntimeHost,
|
||||
Ports: append([]contracthost.MachinePort(nil), record.Ports...),
|
||||
GuestSSHPublicKey: record.GuestSSHPublicKey,
|
||||
Phase: record.Phase,
|
||||
Error: record.Error,
|
||||
CreatedAt: record.CreatedAt,
|
||||
StartedAt: record.StartedAt,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -427,6 +450,11 @@ func validateGuestConfig(config *contracthost.GuestConfig) error {
|
|||
return fmt.Errorf("guest_config.authorized_keys[%d] is required", i)
|
||||
}
|
||||
}
|
||||
for i, key := range config.TrustedUserCAKeys {
|
||||
if strings.TrimSpace(key) == "" {
|
||||
return fmt.Errorf("guest_config.trusted_user_ca_keys[%d] is required", i)
|
||||
}
|
||||
}
|
||||
if config.LoginWebhook != nil {
|
||||
if err := validateDownloadURL("guest_config.login_webhook.url", config.LoginWebhook.URL); err != nil {
|
||||
return err
|
||||
|
|
|
|||
51
internal/daemon/guest_hostkey.go
Normal file
51
internal/daemon/guest_hostkey.go
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func readGuestSSHPublicKey(ctx context.Context, runtimeHost string) (string, error) {
|
||||
host := strings.TrimSpace(runtimeHost)
|
||||
if host == "" {
|
||||
return "", fmt.Errorf("runtime host is required")
|
||||
}
|
||||
|
||||
probeCtx, cancel := context.WithTimeout(ctx, defaultGuestDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
targetAddr := net.JoinHostPort(host, strconv.Itoa(int(defaultSSHPort)))
|
||||
netConn, err := (&net.Dialer{}).DialContext(probeCtx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("dial guest ssh for host key: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = netConn.Close()
|
||||
}()
|
||||
|
||||
var captured ssh.PublicKey
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: "host-key-probe",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("invalid")},
|
||||
HostKeyAlgorithms: []string{ssh.KeyAlgoED25519},
|
||||
HostKeyCallback: func(_ string, _ net.Addr, key ssh.PublicKey) error {
|
||||
captured = key
|
||||
return fmt.Errorf("guest ssh host key captured")
|
||||
},
|
||||
Timeout: defaultGuestDialTimeout,
|
||||
ClientVersion: "SSH-2.0-agentcomputer-firecracker-host",
|
||||
}
|
||||
_, _, _, err = ssh.NewClientConn(netConn, targetAddr, clientConfig)
|
||||
if captured == nil {
|
||||
if err == nil {
|
||||
return "", fmt.Errorf("guest ssh host key probe returned without a host key")
|
||||
}
|
||||
return "", fmt.Errorf("handshake guest ssh for host key: %w", err)
|
||||
}
|
||||
return strings.TrimSpace(string(ssh.MarshalAuthorizedKey(captured))), nil
|
||||
}
|
||||
|
|
@ -99,10 +99,16 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
|||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, state.RuntimeHost)
|
||||
if err != nil {
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
record.RuntimeHost = state.RuntimeHost
|
||||
record.TapDevice = state.TapName
|
||||
record.Ports = ports
|
||||
record.GuestSSHPublicKey = guestSSHPublicKey
|
||||
record.Phase = contracthost.MachinePhaseRunning
|
||||
record.Error = ""
|
||||
record.PID = state.PID
|
||||
|
|
@ -112,7 +118,13 @@ func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*
|
|||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||
d.stopMachineRelays(id)
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
}
|
||||
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
||||
d.stopMachineRelays(id)
|
||||
d.stopPublishedPortsForMachine(id)
|
||||
_ = d.runtime.Delete(context.Background(), *state)
|
||||
return nil, err
|
||||
|
|
@ -238,10 +250,14 @@ func (d *Daemon) Reconcile(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
if reconciled.Phase == contracthost.MachinePhaseRunning {
|
||||
if err := d.ensureMachineRelays(ctx, reconciled); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.ensurePublishedPortsForMachine(ctx, *reconciled); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
d.stopMachineRelays(reconciled.ID)
|
||||
d.stopPublishedPortsForMachine(reconciled.ID)
|
||||
}
|
||||
}
|
||||
|
|
@ -322,6 +338,9 @@ func (d *Daemon) reconcileStart(ctx context.Context, machineID contracthost.Mach
|
|||
return err
|
||||
}
|
||||
if record.Phase == contracthost.MachinePhaseRunning {
|
||||
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -375,12 +394,16 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
|
|||
return nil, err
|
||||
}
|
||||
if state.Phase == firecracker.PhaseRunning {
|
||||
if err := d.ensureMachineRelays(ctx, record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
if err := d.runtime.Delete(ctx, *state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.stopMachineRelays(record.ID)
|
||||
d.stopPublishedPortsForMachine(record.ID)
|
||||
record.Phase = contracthost.MachinePhaseFailed
|
||||
record.Error = state.Error
|
||||
|
|
@ -396,6 +419,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
|
|||
}
|
||||
|
||||
func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error {
|
||||
d.stopMachineRelays(record.ID)
|
||||
d.stopPublishedPortsForMachine(record.ID)
|
||||
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
|
||||
return err
|
||||
|
|
@ -426,6 +450,7 @@ func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineR
|
|||
}
|
||||
|
||||
func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error {
|
||||
d.stopMachineRelays(record.ID)
|
||||
d.stopPublishedPortsForMachine(record.ID)
|
||||
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
|
||||
return err
|
||||
|
|
|
|||
184
internal/daemon/machine_relays.go
Normal file
184
internal/daemon/machine_relays.go
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
package daemon
|
||||
|
||||
import (
	"context"
	"errors"
	"fmt"
	"net"
	"strconv"
	"strings"
	"syscall"

	contracthost "github.com/getcompanion-ai/computer-host/contract"
	"github.com/getcompanion-ai/computer-host/internal/model"
)
|
||||
|
||||
// Host port ranges (inclusive) that are scanned when allocating relay
// listeners for a machine's SSH and VNC ports.
const (
	minMachineSSHRelayPort uint16 = 40000
	maxMachineSSHRelayPort uint16 = 44999
	minMachineVNCRelayPort uint16 = 45000
	maxMachineVNCRelayPort uint16 = 49999
)
|
||||
|
||||
func machineRelayListenerKey(machineID contracthost.MachineID, name contracthost.MachinePortName) string {
|
||||
return string(machineID) + ":" + string(name)
|
||||
}
|
||||
|
||||
func machineRelayHostPort(record model.MachineRecord, name contracthost.MachinePortName) uint16 {
|
||||
for _, port := range record.Ports {
|
||||
if port.Name == name {
|
||||
return port.HostPort
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func machineRelayGuestPort(record model.MachineRecord, name contracthost.MachinePortName) uint16 {
|
||||
for _, port := range record.Ports {
|
||||
if port.Name == name {
|
||||
return port.Port
|
||||
}
|
||||
}
|
||||
switch name {
|
||||
case contracthost.MachinePortNameVNC:
|
||||
return defaultVNCPort
|
||||
default:
|
||||
return defaultSSHPort
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) usedMachineRelayPorts(ctx context.Context, machineID contracthost.MachineID, name contracthost.MachinePortName) (map[uint16]struct{}, error) {
|
||||
records, err := d.store.ListMachines(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
used := make(map[uint16]struct{}, len(records))
|
||||
for _, record := range records {
|
||||
if record.ID == machineID {
|
||||
continue
|
||||
}
|
||||
if port := machineRelayHostPort(record, name); port != 0 {
|
||||
used[port] = struct{}{}
|
||||
}
|
||||
}
|
||||
return used, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) allocateMachineRelayProxy(
|
||||
ctx context.Context,
|
||||
current model.MachineRecord,
|
||||
name contracthost.MachinePortName,
|
||||
runtimeHost string,
|
||||
guestPort uint16,
|
||||
minPort uint16,
|
||||
maxPort uint16,
|
||||
) (uint16, error) {
|
||||
existingPort := machineRelayHostPort(current, name)
|
||||
if existingPort != 0 {
|
||||
if err := d.startMachineRelayProxy(current.ID, name, existingPort, runtimeHost, guestPort); err == nil {
|
||||
return existingPort, nil
|
||||
} else if !isAddrInUseError(err) {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
used, err := d.usedMachineRelayPorts(ctx, current.ID, name)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if existingPort != 0 {
|
||||
used[existingPort] = struct{}{}
|
||||
}
|
||||
for hostPort := minPort; hostPort <= maxPort; hostPort++ {
|
||||
if _, exists := used[hostPort]; exists {
|
||||
continue
|
||||
}
|
||||
if err := d.startMachineRelayProxy(current.ID, name, hostPort, runtimeHost, guestPort); err != nil {
|
||||
if isAddrInUseError(err) {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return hostPort, nil
|
||||
}
|
||||
return 0, fmt.Errorf("no relay ports are available in range %d-%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
func (d *Daemon) ensureMachineRelays(ctx context.Context, record *model.MachineRecord) error {
|
||||
if record == nil {
|
||||
return fmt.Errorf("machine record is required")
|
||||
}
|
||||
if record.Phase != contracthost.MachinePhaseRunning || strings.TrimSpace(record.RuntimeHost) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
d.relayAllocMu.Lock()
|
||||
sshRelayPort, err := d.allocateMachineRelayProxy(ctx, *record, contracthost.MachinePortNameSSH, record.RuntimeHost, machineRelayGuestPort(*record, contracthost.MachinePortNameSSH), minMachineSSHRelayPort, maxMachineSSHRelayPort)
|
||||
var vncRelayPort uint16
|
||||
if err == nil {
|
||||
vncRelayPort, err = d.allocateMachineRelayProxy(ctx, *record, contracthost.MachinePortNameVNC, record.RuntimeHost, machineRelayGuestPort(*record, contracthost.MachinePortNameVNC), minMachineVNCRelayPort, maxMachineVNCRelayPort)
|
||||
}
|
||||
d.relayAllocMu.Unlock()
|
||||
if err != nil {
|
||||
d.stopMachineRelays(record.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
|
||||
if err := d.store.UpdateMachine(ctx, *record); err != nil {
|
||||
d.stopMachineRelays(record.ID)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Daemon) startMachineRelayProxy(machineID contracthost.MachineID, name contracthost.MachinePortName, hostPort uint16, runtimeHost string, guestPort uint16) error {
|
||||
targetHost := strings.TrimSpace(runtimeHost)
|
||||
if targetHost == "" {
|
||||
return fmt.Errorf("runtime host is required for machine relay %q", machineID)
|
||||
}
|
||||
|
||||
key := machineRelayListenerKey(machineID, name)
|
||||
|
||||
d.machineRelaysMu.Lock()
|
||||
if _, exists := d.machineRelayListeners[key]; exists {
|
||||
d.machineRelaysMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
listener, err := net.Listen("tcp", ":"+strconv.Itoa(int(hostPort)))
|
||||
if err != nil {
|
||||
d.machineRelaysMu.Unlock()
|
||||
return fmt.Errorf("listen on machine relay port %d: %w", hostPort, err)
|
||||
}
|
||||
d.machineRelayListeners[key] = listener
|
||||
d.machineRelaysMu.Unlock()
|
||||
|
||||
targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(int(guestPort)))
|
||||
go serveTCPProxy(listener, targetAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Daemon) stopMachineRelayProxy(machineID contracthost.MachineID, name contracthost.MachinePortName) {
|
||||
key := machineRelayListenerKey(machineID, name)
|
||||
|
||||
d.machineRelaysMu.Lock()
|
||||
listener, ok := d.machineRelayListeners[key]
|
||||
if ok {
|
||||
delete(d.machineRelayListeners, key)
|
||||
}
|
||||
d.machineRelaysMu.Unlock()
|
||||
if ok {
|
||||
_ = listener.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) stopMachineRelays(machineID contracthost.MachineID) {
|
||||
d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameSSH)
|
||||
d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameVNC)
|
||||
}
|
||||
|
||||
// isAddrInUseError reports whether err indicates that a listen address is
// already bound.
//
// It first checks whether syscall.EADDRINUSE is anywhere in err's wrap chain,
// which does not depend on the error's wording. It then falls back to a
// case-insensitive search for "address already in use" in the message, so
// errors that carry only text are still recognized. A nil error yields false.
func isAddrInUseError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, syscall.EADDRINUSE) {
		return true
	}
	return strings.Contains(strings.ToLower(err.Error()), "address already in use")
}
|
||||
|
|
@ -177,11 +177,11 @@ func (d *Daemon) startPublishedPortProxy(port model.PublishedPortRecord, runtime
|
|||
d.publishedPortsMu.Unlock()
|
||||
|
||||
targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(int(port.Port)))
|
||||
go d.servePublishedPortProxy(port.ID, listener, targetAddr)
|
||||
go serveTCPProxy(listener, targetAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Daemon) servePublishedPortProxy(portID contracthost.PublishedPortID, listener net.Listener, targetAddr string) {
|
||||
func serveTCPProxy(listener net.Listener, targetAddr string) {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
|
|
@ -190,11 +190,11 @@ func (d *Daemon) servePublishedPortProxy(portID contracthost.PublishedPortID, li
|
|||
}
|
||||
continue
|
||||
}
|
||||
go proxyPublishedPortConnection(conn, targetAddr)
|
||||
go proxyTCPConnection(conn, targetAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func proxyPublishedPortConnection(source net.Conn, targetAddr string) {
|
||||
func proxyTCPConnection(source net.Conn, targetAddr string) {
|
||||
defer func() {
|
||||
_ = source.Close()
|
||||
}()
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ func TestCreatePublishedPortSerializesHostPortAllocationAcrossMachines(t *testin
|
|||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||
|
||||
for _, machineID := range []contracthost.MachineID{"vm-1", "vm-2"} {
|
||||
if err := baseStore.CreateMachine(context.Background(), model.MachineRecord{
|
||||
|
|
@ -170,6 +171,7 @@ func TestGetStorageReportHandlesSparseSnapshotPathsAndIncludesPublishedPortPool(
|
|||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||
|
||||
response, err := hostDaemon.GetStorageReport(context.Background())
|
||||
if err != nil {
|
||||
|
|
@ -209,6 +211,7 @@ func TestReconcileSnapshotPreservesArtifactsOnUnexpectedStoreError(t *testing.T)
|
|||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||
|
||||
snapshotID := contracthost.SnapshotID("snap-1")
|
||||
operation := model.OperationRecord{
|
||||
|
|
@ -249,6 +252,7 @@ func TestReconcileRestorePreservesArtifactsOnUnexpectedStoreError(t *testing.T)
|
|||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||
|
||||
operation := model.OperationRecord{
|
||||
MachineID: "vm-1",
|
||||
|
|
@ -291,6 +295,7 @@ func TestCreateSnapshotRejectsDuplicateSnapshotIDWithoutTouchingExistingArtifact
|
|||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||
|
||||
machineID := contracthost.MachineID("vm-1")
|
||||
snapshotID := contracthost.SnapshotID("snap-1")
|
||||
|
|
@ -346,6 +351,7 @@ func TestReconcileUsesReconciledMachineStateForPublishedPorts(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
stubGuestSSHPublicKeyReader(hostDaemon)
|
||||
|
||||
t.Cleanup(func() {
|
||||
hostDaemon.stopPublishedPortProxy("port-1")
|
||||
|
|
|
|||
|
|
@ -260,6 +260,13 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
|||
clearOperation = true
|
||||
return nil, fmt.Errorf("reconfigure restored guest identity: %w", err)
|
||||
}
|
||||
guestSSHPublicKey, err := d.readGuestSSHPublicKey(ctx, machineState.RuntimeHost)
|
||||
if err != nil {
|
||||
_ = d.runtime.Delete(ctx, *machineState)
|
||||
_ = os.RemoveAll(filepath.Dir(newSystemDiskPath))
|
||||
clearOperation = true
|
||||
return nil, fmt.Errorf("read restored guest ssh host key: %w", err)
|
||||
}
|
||||
|
||||
systemVolumeID := d.systemVolumeID(req.MachineID)
|
||||
now := time.Now().UTC()
|
||||
|
|
@ -277,22 +284,42 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
|
|||
}
|
||||
|
||||
machineRecord := model.MachineRecord{
|
||||
ID: req.MachineID,
|
||||
Artifact: snap.Artifact,
|
||||
SystemVolumeID: systemVolumeID,
|
||||
RuntimeHost: machineState.RuntimeHost,
|
||||
TapDevice: machineState.TapName,
|
||||
Ports: defaultMachinePorts(),
|
||||
Phase: contracthost.MachinePhaseRunning,
|
||||
PID: machineState.PID,
|
||||
SocketPath: machineState.SocketPath,
|
||||
CreatedAt: now,
|
||||
StartedAt: machineState.StartedAt,
|
||||
ID: req.MachineID,
|
||||
Artifact: snap.Artifact,
|
||||
SystemVolumeID: systemVolumeID,
|
||||
RuntimeHost: machineState.RuntimeHost,
|
||||
TapDevice: machineState.TapName,
|
||||
Ports: defaultMachinePorts(),
|
||||
GuestSSHPublicKey: guestSSHPublicKey,
|
||||
Phase: contracthost.MachinePhaseRunning,
|
||||
PID: machineState.PID,
|
||||
SocketPath: machineState.SocketPath,
|
||||
CreatedAt: now,
|
||||
StartedAt: machineState.StartedAt,
|
||||
}
|
||||
d.relayAllocMu.Lock()
|
||||
sshRelayPort, err := d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameSSH, machineRecord.RuntimeHost, defaultSSHPort, minMachineSSHRelayPort, maxMachineSSHRelayPort)
|
||||
var vncRelayPort uint16
|
||||
if err == nil {
|
||||
vncRelayPort, err = d.allocateMachineRelayProxy(ctx, machineRecord, contracthost.MachinePortNameVNC, machineRecord.RuntimeHost, defaultVNCPort, minMachineVNCRelayPort, maxMachineVNCRelayPort)
|
||||
}
|
||||
d.relayAllocMu.Unlock()
|
||||
if err != nil {
|
||||
d.stopMachineRelays(machineRecord.ID)
|
||||
return nil, err
|
||||
}
|
||||
machineRecord.Ports = buildMachinePorts(sshRelayPort, vncRelayPort)
|
||||
startedRelays := true
|
||||
defer func() {
|
||||
if startedRelays {
|
||||
d.stopMachineRelays(machineRecord.ID)
|
||||
}
|
||||
}()
|
||||
if err := d.store.CreateMachine(ctx, machineRecord); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startedRelays = false
|
||||
clearOperation = true
|
||||
return &contracthost.RestoreSnapshotResponse{
|
||||
Machine: machineToContract(machineRecord),
|
||||
|
|
|
|||
|
|
@ -27,19 +27,20 @@ type ArtifactRecord struct {
|
|||
}
|
||||
|
||||
type MachineRecord struct {
|
||||
ID contracthost.MachineID
|
||||
Artifact contracthost.ArtifactRef
|
||||
SystemVolumeID contracthost.VolumeID
|
||||
UserVolumeIDs []contracthost.VolumeID
|
||||
RuntimeHost string
|
||||
TapDevice string
|
||||
Ports []contracthost.MachinePort
|
||||
Phase contracthost.MachinePhase
|
||||
Error string
|
||||
PID int
|
||||
SocketPath string
|
||||
CreatedAt time.Time
|
||||
StartedAt *time.Time
|
||||
ID contracthost.MachineID
|
||||
Artifact contracthost.ArtifactRef
|
||||
SystemVolumeID contracthost.VolumeID
|
||||
UserVolumeIDs []contracthost.VolumeID
|
||||
RuntimeHost string
|
||||
TapDevice string
|
||||
Ports []contracthost.MachinePort
|
||||
GuestSSHPublicKey string
|
||||
Phase contracthost.MachinePhase
|
||||
Error string
|
||||
PID int
|
||||
SocketPath string
|
||||
CreatedAt time.Time
|
||||
StartedAt *time.Time
|
||||
}
|
||||
|
||||
type VolumeRecord struct {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue