From 6489e270cef42b978fb19569cf4c696599d046d9 Mon Sep 17 00:00:00 2001 From: Harivansh Rathi Date: Sun, 12 Apr 2026 22:47:51 +0000 Subject: [PATCH] feat: phase 1 --- contract/networking.go | 5 +- internal/daemon/daemon.go | 1 + internal/daemon/daemon_test.go | 6 +- internal/daemon/exec_relay.go | 48 ++++++++++++ internal/daemon/exec_relay_test.go | 85 ++++++++++++++++++++++ internal/daemon/files.go | 30 +++++++- internal/daemon/machine_relays.go | 15 ++-- internal/daemon/review_regressions_test.go | 2 +- internal/httpapi/handlers.go | 15 ++++ 9 files changed, 194 insertions(+), 13 deletions(-) create mode 100644 internal/daemon/exec_relay.go create mode 100644 internal/daemon/exec_relay_test.go diff --git a/contract/networking.go b/contract/networking.go index 3750b1a..1159a82 100644 --- a/contract/networking.go +++ b/contract/networking.go @@ -5,8 +5,9 @@ type MachinePortName string type PortProtocol string const ( - MachinePortNameSSH MachinePortName = "ssh" - MachinePortNameVNC MachinePortName = "vnc" + MachinePortNameSSH MachinePortName = "ssh" + MachinePortNameVNC MachinePortName = "vnc" + MachinePortNameExec MachinePortName = "exec" ) const ( diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 8e457d9..aeef596 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -23,6 +23,7 @@ const ( defaultGuestDiskSizeBytes = int64(10 * 1024 * 1024 * 1024) // 10 GB defaultSSHPort = uint16(2222) defaultVNCPort = uint16(6080) + defaultGuestdPort = uint16(49983) defaultCopyBufferSize = 1024 * 1024 defaultGuestDialTimeout = 500 * time.Millisecond defaultGuestStopTimeout = 10 * time.Second diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index a500ea8..e29d840 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -158,10 +158,10 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { if response.Machine.RuntimeHost != "127.0.0.1" { t.Fatalf("runtime host mismatch: got %q", response.Machine.RuntimeHost) } - if len(response.Machine.Ports) != 2 { - t.Fatalf("machine ports mismatch: got %d want 2", len(response.Machine.Ports)) + if len(response.Machine.Ports) != 3 { + t.Fatalf("machine ports mismatch: got %d want 3", len(response.Machine.Ports)) } - if response.Machine.Ports[0].Port != defaultSSHPort || response.Machine.Ports[1].Port != defaultVNCPort { + if response.Machine.Ports[0].Port != defaultSSHPort || response.Machine.Ports[1].Port != defaultVNCPort || response.Machine.Ports[2].Port != defaultGuestdPort { t.Fatalf("machine ports mismatch: got %#v", response.Machine.Ports) } if runtime.bootCalls != 1 { diff --git a/internal/daemon/exec_relay.go b/internal/daemon/exec_relay.go new file mode 100644 index 0000000..ddcd69f --- /dev/null +++ b/internal/daemon/exec_relay.go @@ -0,0 +1,48 @@ +package daemon + +import ( + "context" + "fmt" + "strings" + + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +func (d *Daemon) EnsureExecRelay(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) { + unlock := d.lockMachine(id) + defer unlock() + + record, err := d.store.GetMachine(ctx, id) + if err != nil { + return nil, err + } + if record.Phase != contracthost.MachinePhaseRunning { + return nil, fmt.Errorf("machine %q is not running", id) + } + if strings.TrimSpace(record.RuntimeHost) == "" { + return nil, fmt.Errorf("machine %q runtime host is unavailable", id) + } + + d.relayAllocMu.Lock() + execRelayPort, err := d.allocateMachineRelayProxy( + ctx, + *record, + contracthost.MachinePortNameExec, + record.RuntimeHost, + defaultGuestdPort, + minMachineExecRelayPort, + maxMachineExecRelayPort, + ) + d.relayAllocMu.Unlock() + if err != nil { + d.stopMachineRelayProxy(record.ID, contracthost.MachinePortNameExec) + return nil, err + } + + record.Ports = setMachineExecRelayPort(record.Ports, execRelayPort) + if err := d.store.UpdateMachine(ctx, *record); err != nil { + d.stopMachineRelayProxy(record.ID, contracthost.MachinePortNameExec) + return nil, err + } + return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil +} diff --git a/internal/daemon/exec_relay_test.go b/internal/daemon/exec_relay_test.go new file mode 100644 index 0000000..942083c --- /dev/null +++ b/internal/daemon/exec_relay_test.go @@ -0,0 +1,85 @@ +package daemon + +import ( + "context" + "net" + "testing" + + "github.com/getcompanion-ai/computer-host/internal/model" + "github.com/getcompanion-ai/computer-host/internal/store" + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +func TestEnsureExecRelayAllocatesRelayLazily(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + runtime := &fakeRuntime{} + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + + upstream, err := net.Listen("tcp", "127.0.0.1:49983") + if err != nil { + t.Fatalf("listen upstream: %v", err) + } + defer func() { _ = upstream.Close() }() + + record := model.MachineRecord{ + ID: "vm-exec", + Artifact: contracthost.ArtifactRef{KernelImageURL: "https://example.com/kernel", RootFSURL: "https://example.com/rootfs"}, + SystemVolumeID: "vm-exec-system", + RuntimeHost: "127.0.0.1", + Ports: defaultMachinePorts(), + Phase: contracthost.MachinePhaseRunning, + GuestConfig: &contracthost.GuestConfig{}, + } + if err := fileStore.CreateMachine(context.Background(), record); err != nil { + t.Fatalf("create machine record: %v", err) + } + + response, err := hostDaemon.EnsureExecRelay(context.Background(), "vm-exec") + if err != nil { + t.Fatalf("ensure exec relay: %v", err) + } + defer hostDaemon.stopMachineRelays("vm-exec") + + var execPort contracthost.MachinePort + found := false + for _, port := range response.Machine.Ports { + if port.Name == contracthost.MachinePortNameExec { + execPort = port + found = true + break + } + } + if !found { + t.Fatalf("exec port not found in machine ports: %#v", response.Machine.Ports) + } + if execPort.Port != defaultGuestdPort { + t.Fatalf("exec guest port = %d, want %d", execPort.Port, defaultGuestdPort) + } + if execPort.HostPort < minMachineExecRelayPort || execPort.HostPort > maxMachineExecRelayPort { + t.Fatalf("exec host port = %d, want range %d-%d", execPort.HostPort, minMachineExecRelayPort, maxMachineExecRelayPort) + } + + stored, err := fileStore.GetMachine(context.Background(), "vm-exec") + if err != nil { + t.Fatalf("get stored machine: %v", err) + } + hasStoredExecPort := false + for _, port := range stored.Ports { + if port.Name == contracthost.MachinePortNameExec && port.HostPort == execPort.HostPort { + hasStoredExecPort = true + break + } + } + if !hasStoredExecPort { + t.Fatalf("stored machine missing exec relay port: %#v", stored.Ports) + } +} diff --git a/internal/daemon/files.go b/internal/daemon/files.go index 1a5f741..953640b 100644 --- a/internal/daemon/files.go +++ b/internal/daemon/files.go @@ -300,16 +300,42 @@ func isZeroChunk(chunk []byte) bool { } func defaultMachinePorts() []contracthost.MachinePort { - return buildMachinePorts(0, 0) + return buildMachinePorts(0, 0, 0) } -func buildMachinePorts(sshRelayPort, vncRelayPort uint16) []contracthost.MachinePort { +func buildMachinePorts(sshRelayPort, vncRelayPort, execRelayPort uint16) []contracthost.MachinePort { return []contracthost.MachinePort{ {Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, HostPort: sshRelayPort, Protocol: contracthost.PortProtocolTCP}, {Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, HostPort: vncRelayPort, Protocol: contracthost.PortProtocolTCP}, + {Name: contracthost.MachinePortNameExec, Port: defaultGuestdPort, HostPort: execRelayPort, Protocol: contracthost.PortProtocolTCP}, } } +func setMachineExecRelayPort(ports []contracthost.MachinePort, relayPort uint16) []contracthost.MachinePort { + updated := make([]contracthost.MachinePort, 0, len(ports)) + replaced := false + for _, port := range ports { + if port.Name == contracthost.MachinePortNameExec { + port.HostPort = relayPort + if port.Port == 0 { + port.Port = defaultGuestdPort + } + port.Protocol = contracthost.PortProtocolTCP + replaced = true + } + updated = append(updated, port) + } + if !replaced { + updated = append(updated, contracthost.MachinePort{ + Name: contracthost.MachinePortNameExec, + Port: defaultGuestdPort, + HostPort: relayPort, + Protocol: contracthost.PortProtocolTCP, + }) + } + return updated +} + func (d *Daemon) ensureBackendSSHKeyPair() error { privateKeyPath := d.backendSSHPrivateKeyPath() publicKeyPath := d.backendSSHPublicKeyPath() diff --git a/internal/daemon/machine_relays.go b/internal/daemon/machine_relays.go index cb8801c..ffad01f 100644 --- a/internal/daemon/machine_relays.go +++ b/internal/daemon/machine_relays.go @@ -12,10 +12,12 @@ import ( ) const ( - minMachineSSHRelayPort = uint16(40000) - maxMachineSSHRelayPort = uint16(44999) - minMachineVNCRelayPort = uint16(45000) - maxMachineVNCRelayPort = uint16(49999) + minMachineSSHRelayPort = uint16(40000) + maxMachineSSHRelayPort = uint16(44999) + minMachineVNCRelayPort = uint16(45000) + maxMachineVNCRelayPort = uint16(49999) + minMachineExecRelayPort = uint16(50000) + maxMachineExecRelayPort = uint16(54999) ) func machineRelayListenerKey(machineID contracthost.MachineID, name contracthost.MachinePortName) string { @@ -40,6 +42,8 @@ func machineRelayGuestPort(record model.MachineRecord, name contracthost.Machine switch name { case contracthost.MachinePortNameVNC: return defaultVNCPort + case contracthost.MachinePortNameExec: + return defaultGuestdPort default: return defaultSSHPort } @@ -126,7 +130,7 @@ func (d *Daemon) ensureMachineRelays(ctx context.Context, record *model.MachineR return err } - record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort) + record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort, machineRelayHostPort(*record, contracthost.MachinePortNameExec)) if err := d.store.UpdateMachine(ctx, *record); err != nil { d.stopMachineRelays(record.ID) return err @@ -177,6 +181,7 @@ func (d *Daemon) stopMachineRelayProxy(machineID contracthost.MachineID, name co func (d *Daemon) stopMachineRelays(machineID contracthost.MachineID) { d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameSSH) d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameVNC) + d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameExec) } func isAddrInUseError(err error) bool { diff --git a/internal/daemon/review_regressions_test.go b/internal/daemon/review_regressions_test.go index 3fe8361..eb4eae0 100644 --- a/internal/daemon/review_regressions_test.go +++ b/internal/daemon/review_regressions_test.go @@ -915,7 +915,7 @@ func exhaustedMachineRelayRecords() []model.MachineRecord { for i := 0; i < count; i++ { machines = append(machines, model.MachineRecord{ ID: contracthost.MachineID(fmt.Sprintf("relay-exhausted-%d", i)), - Ports: buildMachinePorts(minMachineSSHRelayPort+uint16(i), minMachineVNCRelayPort+uint16(i)), + Ports: buildMachinePorts(minMachineSSHRelayPort+uint16(i), minMachineVNCRelayPort+uint16(i), 0), Phase: contracthost.MachinePhaseRunning, }) } diff --git a/internal/httpapi/handlers.go b/internal/httpapi/handlers.go index de4a308..a865cdb 100644 --- a/internal/httpapi/handlers.go +++ b/internal/httpapi/handlers.go @@ -15,6 +15,7 @@ type Service interface { GetMachine(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error) ListMachines(context.Context) (*contracthost.ListMachinesResponse, error) StartMachine(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error) + EnsureExecRelay(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error) StopMachine(context.Context, contracthost.MachineID) error DeleteMachine(context.Context, contracthost.MachineID) error Health(context.Context) (*contracthost.HealthResponse, error) @@ -166,6 +167,20 @@ func (h *Handler) handleMachine(w http.ResponseWriter, r *http.Request) { return } + if len(parts) == 2 && parts[1] == "exec-relay" { + if r.Method != http.MethodPost { + writeMethodNotAllowed(w) + return + } + response, err := h.service.EnsureExecRelay(r.Context(), machineID) + if err != nil { + writeError(w, statusForError(err), err) + return + } + writeJSON(w, http.StatusOK, response) + return + } + if len(parts) == 2 && parts[1] == "snapshots" { switch r.Method { case http.MethodGet: