feat: phase 1

This commit is contained in:
Harivansh Rathi 2026-04-12 22:47:51 +00:00
parent 4a9dc91ebf
commit 6489e270ce
9 changed files with 194 additions and 13 deletions

View file

@ -7,6 +7,7 @@ type PortProtocol string
const ( const (
MachinePortNameSSH MachinePortName = "ssh" MachinePortNameSSH MachinePortName = "ssh"
MachinePortNameVNC MachinePortName = "vnc" MachinePortNameVNC MachinePortName = "vnc"
MachinePortNameExec MachinePortName = "exec"
) )
const ( const (

View file

@ -23,6 +23,7 @@ const (
defaultGuestDiskSizeBytes = int64(10 * 1024 * 1024 * 1024) // 10 GB defaultGuestDiskSizeBytes = int64(10 * 1024 * 1024 * 1024) // 10 GB
defaultSSHPort = uint16(2222) defaultSSHPort = uint16(2222)
defaultVNCPort = uint16(6080) defaultVNCPort = uint16(6080)
defaultGuestdPort = uint16(49983)
defaultCopyBufferSize = 1024 * 1024 defaultCopyBufferSize = 1024 * 1024
defaultGuestDialTimeout = 500 * time.Millisecond defaultGuestDialTimeout = 500 * time.Millisecond
defaultGuestStopTimeout = 10 * time.Second defaultGuestStopTimeout = 10 * time.Second

View file

@ -158,10 +158,10 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
if response.Machine.RuntimeHost != "127.0.0.1" { if response.Machine.RuntimeHost != "127.0.0.1" {
t.Fatalf("runtime host mismatch: got %q", response.Machine.RuntimeHost) t.Fatalf("runtime host mismatch: got %q", response.Machine.RuntimeHost)
} }
if len(response.Machine.Ports) != 2 { if len(response.Machine.Ports) != 3 {
t.Fatalf("machine ports mismatch: got %d want 2", len(response.Machine.Ports)) t.Fatalf("machine ports mismatch: got %d want 3", len(response.Machine.Ports))
} }
if response.Machine.Ports[0].Port != defaultSSHPort || response.Machine.Ports[1].Port != defaultVNCPort { if response.Machine.Ports[0].Port != defaultSSHPort || response.Machine.Ports[1].Port != defaultVNCPort || response.Machine.Ports[2].Port != defaultGuestdPort {
t.Fatalf("machine ports mismatch: got %#v", response.Machine.Ports) t.Fatalf("machine ports mismatch: got %#v", response.Machine.Ports)
} }
if runtime.bootCalls != 1 { if runtime.bootCalls != 1 {

View file

@ -0,0 +1,48 @@
package daemon
import (
"context"
"fmt"
"strings"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func (d *Daemon) EnsureExecRelay(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) {
unlock := d.lockMachine(id)
defer unlock()
record, err := d.store.GetMachine(ctx, id)
if err != nil {
return nil, err
}
if record.Phase != contracthost.MachinePhaseRunning {
return nil, fmt.Errorf("machine %q is not running", id)
}
if strings.TrimSpace(record.RuntimeHost) == "" {
return nil, fmt.Errorf("machine %q runtime host is unavailable", id)
}
d.relayAllocMu.Lock()
execRelayPort, err := d.allocateMachineRelayProxy(
ctx,
*record,
contracthost.MachinePortNameExec,
record.RuntimeHost,
defaultGuestdPort,
minMachineExecRelayPort,
maxMachineExecRelayPort,
)
d.relayAllocMu.Unlock()
if err != nil {
d.stopMachineRelayProxy(record.ID, contracthost.MachinePortNameExec)
return nil, err
}
record.Ports = setMachineExecRelayPort(record.Ports, execRelayPort)
if err := d.store.UpdateMachine(ctx, *record); err != nil {
d.stopMachineRelayProxy(record.ID, contracthost.MachinePortNameExec)
return nil, err
}
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
}

View file

@ -0,0 +1,85 @@
package daemon
import (
"context"
"net"
"testing"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func TestEnsureExecRelayAllocatesRelayLazily(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
if err != nil {
t.Fatalf("create file store: %v", err)
}
runtime := &fakeRuntime{}
hostDaemon, err := New(cfg, fileStore, runtime)
if err != nil {
t.Fatalf("create daemon: %v", err)
}
upstream, err := net.Listen("tcp", "127.0.0.1:49983")
if err != nil {
t.Fatalf("listen upstream: %v", err)
}
defer func() { _ = upstream.Close() }()
record := model.MachineRecord{
ID: "vm-exec",
Artifact: contracthost.ArtifactRef{KernelImageURL: "https://example.com/kernel", RootFSURL: "https://example.com/rootfs"},
SystemVolumeID: "vm-exec-system",
RuntimeHost: "127.0.0.1",
Ports: defaultMachinePorts(),
Phase: contracthost.MachinePhaseRunning,
GuestConfig: &contracthost.GuestConfig{},
}
if err := fileStore.CreateMachine(context.Background(), record); err != nil {
t.Fatalf("create machine record: %v", err)
}
response, err := hostDaemon.EnsureExecRelay(context.Background(), "vm-exec")
if err != nil {
t.Fatalf("ensure exec relay: %v", err)
}
defer hostDaemon.stopMachineRelays("vm-exec")
var execPort contracthost.MachinePort
found := false
for _, port := range response.Machine.Ports {
if port.Name == contracthost.MachinePortNameExec {
execPort = port
found = true
break
}
}
if !found {
t.Fatalf("exec port not found in machine ports: %#v", response.Machine.Ports)
}
if execPort.Port != defaultGuestdPort {
t.Fatalf("exec guest port = %d, want %d", execPort.Port, defaultGuestdPort)
}
if execPort.HostPort < minMachineExecRelayPort || execPort.HostPort > maxMachineExecRelayPort {
t.Fatalf("exec host port = %d, want range %d-%d", execPort.HostPort, minMachineExecRelayPort, maxMachineExecRelayPort)
}
stored, err := fileStore.GetMachine(context.Background(), "vm-exec")
if err != nil {
t.Fatalf("get stored machine: %v", err)
}
hasStoredExecPort := false
for _, port := range stored.Ports {
if port.Name == contracthost.MachinePortNameExec && port.HostPort == execPort.HostPort {
hasStoredExecPort = true
break
}
}
if !hasStoredExecPort {
t.Fatalf("stored machine missing exec relay port: %#v", stored.Ports)
}
}

View file

@ -300,16 +300,42 @@ func isZeroChunk(chunk []byte) bool {
} }
func defaultMachinePorts() []contracthost.MachinePort { func defaultMachinePorts() []contracthost.MachinePort {
return buildMachinePorts(0, 0) return buildMachinePorts(0, 0, 0)
} }
func buildMachinePorts(sshRelayPort, vncRelayPort uint16) []contracthost.MachinePort { func buildMachinePorts(sshRelayPort, vncRelayPort, execRelayPort uint16) []contracthost.MachinePort {
return []contracthost.MachinePort{ return []contracthost.MachinePort{
{Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, HostPort: sshRelayPort, Protocol: contracthost.PortProtocolTCP}, {Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, HostPort: sshRelayPort, Protocol: contracthost.PortProtocolTCP},
{Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, HostPort: vncRelayPort, Protocol: contracthost.PortProtocolTCP}, {Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, HostPort: vncRelayPort, Protocol: contracthost.PortProtocolTCP},
{Name: contracthost.MachinePortNameExec, Port: defaultGuestdPort, HostPort: execRelayPort, Protocol: contracthost.PortProtocolTCP},
} }
} }
func setMachineExecRelayPort(ports []contracthost.MachinePort, relayPort uint16) []contracthost.MachinePort {
updated := make([]contracthost.MachinePort, 0, len(ports))
replaced := false
for _, port := range ports {
if port.Name == contracthost.MachinePortNameExec {
port.HostPort = relayPort
if port.Port == 0 {
port.Port = defaultGuestdPort
}
port.Protocol = contracthost.PortProtocolTCP
replaced = true
}
updated = append(updated, port)
}
if !replaced {
updated = append(updated, contracthost.MachinePort{
Name: contracthost.MachinePortNameExec,
Port: defaultGuestdPort,
HostPort: relayPort,
Protocol: contracthost.PortProtocolTCP,
})
}
return updated
}
func (d *Daemon) ensureBackendSSHKeyPair() error { func (d *Daemon) ensureBackendSSHKeyPair() error {
privateKeyPath := d.backendSSHPrivateKeyPath() privateKeyPath := d.backendSSHPrivateKeyPath()
publicKeyPath := d.backendSSHPublicKeyPath() publicKeyPath := d.backendSSHPublicKeyPath()

View file

@ -16,6 +16,8 @@ const (
maxMachineSSHRelayPort = uint16(44999) maxMachineSSHRelayPort = uint16(44999)
minMachineVNCRelayPort = uint16(45000) minMachineVNCRelayPort = uint16(45000)
maxMachineVNCRelayPort = uint16(49999) maxMachineVNCRelayPort = uint16(49999)
minMachineExecRelayPort = uint16(50000)
maxMachineExecRelayPort = uint16(54999)
) )
func machineRelayListenerKey(machineID contracthost.MachineID, name contracthost.MachinePortName) string { func machineRelayListenerKey(machineID contracthost.MachineID, name contracthost.MachinePortName) string {
@ -40,6 +42,8 @@ func machineRelayGuestPort(record model.MachineRecord, name contracthost.Machine
switch name { switch name {
case contracthost.MachinePortNameVNC: case contracthost.MachinePortNameVNC:
return defaultVNCPort return defaultVNCPort
case contracthost.MachinePortNameExec:
return defaultGuestdPort
default: default:
return defaultSSHPort return defaultSSHPort
} }
@ -126,7 +130,7 @@ func (d *Daemon) ensureMachineRelays(ctx context.Context, record *model.MachineR
return err return err
} }
record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort) record.Ports = buildMachinePorts(sshRelayPort, vncRelayPort, machineRelayHostPort(*record, contracthost.MachinePortNameExec))
if err := d.store.UpdateMachine(ctx, *record); err != nil { if err := d.store.UpdateMachine(ctx, *record); err != nil {
d.stopMachineRelays(record.ID) d.stopMachineRelays(record.ID)
return err return err
@ -177,6 +181,7 @@ func (d *Daemon) stopMachineRelayProxy(machineID contracthost.MachineID, name co
func (d *Daemon) stopMachineRelays(machineID contracthost.MachineID) { func (d *Daemon) stopMachineRelays(machineID contracthost.MachineID) {
d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameSSH) d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameSSH)
d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameVNC) d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameVNC)
d.stopMachineRelayProxy(machineID, contracthost.MachinePortNameExec)
} }
func isAddrInUseError(err error) bool { func isAddrInUseError(err error) bool {

View file

@ -915,7 +915,7 @@ func exhaustedMachineRelayRecords() []model.MachineRecord {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
machines = append(machines, model.MachineRecord{ machines = append(machines, model.MachineRecord{
ID: contracthost.MachineID(fmt.Sprintf("relay-exhausted-%d", i)), ID: contracthost.MachineID(fmt.Sprintf("relay-exhausted-%d", i)),
Ports: buildMachinePorts(minMachineSSHRelayPort+uint16(i), minMachineVNCRelayPort+uint16(i)), Ports: buildMachinePorts(minMachineSSHRelayPort+uint16(i), minMachineVNCRelayPort+uint16(i), 0),
Phase: contracthost.MachinePhaseRunning, Phase: contracthost.MachinePhaseRunning,
}) })
} }

View file

@ -15,6 +15,7 @@ type Service interface {
GetMachine(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error) GetMachine(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error)
ListMachines(context.Context) (*contracthost.ListMachinesResponse, error) ListMachines(context.Context) (*contracthost.ListMachinesResponse, error)
StartMachine(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error) StartMachine(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error)
EnsureExecRelay(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error)
StopMachine(context.Context, contracthost.MachineID) error StopMachine(context.Context, contracthost.MachineID) error
DeleteMachine(context.Context, contracthost.MachineID) error DeleteMachine(context.Context, contracthost.MachineID) error
Health(context.Context) (*contracthost.HealthResponse, error) Health(context.Context) (*contracthost.HealthResponse, error)
@ -166,6 +167,20 @@ func (h *Handler) handleMachine(w http.ResponseWriter, r *http.Request) {
return return
} }
if len(parts) == 2 && parts[1] == "exec-relay" {
if r.Method != http.MethodPost {
writeMethodNotAllowed(w)
return
}
response, err := h.service.EnsureExecRelay(r.Context(), machineID)
if err != nil {
writeError(w, statusForError(err), err)
return
}
writeJSON(w, http.StatusOK, response)
return
}
if len(parts) == 2 && parts[1] == "snapshots" { if len(parts) == 2 && parts[1] == "snapshots" {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet: