diff --git a/contract/published_ports.go b/contract/published_ports.go new file mode 100644 index 0000000..59b95a4 --- /dev/null +++ b/contract/published_ports.go @@ -0,0 +1,30 @@ +package host + +import "time" + +type PublishedPortID string + +type PublishedPort struct { + ID PublishedPortID `json:"id"` + MachineID MachineID `json:"machine_id"` + Name string `json:"name,omitempty"` + Port uint16 `json:"port"` + HostPort uint16 `json:"host_port"` + Protocol PortProtocol `json:"protocol"` + CreatedAt time.Time `json:"created_at"` +} + +type CreatePublishedPortRequest struct { + PublishedPortID PublishedPortID `json:"published_port_id"` + Name string `json:"name,omitempty"` + Port uint16 `json:"port"` + Protocol PortProtocol `json:"protocol"` +} + +type CreatePublishedPortResponse struct { + Port PublishedPort `json:"port"` +} + +type ListPublishedPortsResponse struct { + Ports []PublishedPort `json:"ports"` +} diff --git a/contract/snapshots.go b/contract/snapshots.go index 96b6d5c..ea1377a 100644 --- a/contract/snapshots.go +++ b/contract/snapshots.go @@ -10,6 +10,10 @@ type Snapshot struct { CreatedAt time.Time `json:"created_at"` } +type CreateSnapshotRequest struct { + SnapshotID SnapshotID `json:"snapshot_id"` +} + type CreateSnapshotResponse struct { Snapshot Snapshot `json:"snapshot"` } diff --git a/contract/storage.go b/contract/storage.go index 553077b..bdf1993 100644 --- a/contract/storage.go +++ b/contract/storage.go @@ -13,3 +13,43 @@ type Volume struct { AttachedMachineID *MachineID `json:"attached_machine_id,omitempty"` CreatedAt time.Time `json:"created_at"` } + +type StoragePool string + +const ( + StoragePoolArtifacts StoragePool = "artifacts" + StoragePoolMachineDisks StoragePool = "machine-disks" + StoragePoolPublishedPort StoragePool = "published-ports" + StoragePoolSnapshots StoragePool = "snapshots" + StoragePoolState StoragePool = "state" +) + +type StoragePoolUsage struct { + Pool StoragePool `json:"pool"` + Bytes int64 `json:"bytes"` +} + +type MachineStorageUsage struct { + MachineID MachineID `json:"machine_id"` + SystemBytes int64 `json:"system_bytes"` + UserBytes int64 `json:"user_bytes"` + RuntimeBytes int64 `json:"runtime_bytes"` +} + +type SnapshotStorageUsage struct { + SnapshotID SnapshotID `json:"snapshot_id"` + Bytes int64 `json:"bytes"` +} + +type StorageReport struct { + GeneratedAt time.Time `json:"generated_at"` + TotalBytes int64 `json:"total_bytes"` + Pools []StoragePoolUsage `json:"pools,omitempty"` + Machines []MachineStorageUsage `json:"machines,omitempty"` + Snapshots []SnapshotStorageUsage `json:"snapshots,omitempty"` + PublishedPorts int64 `json:"published_ports"` +} + +type GetStorageReportResponse struct { + Report StorageReport `json:"report"` +} diff --git a/internal/config/config.go b/internal/config/config.go index b5a2411..7d64a09 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -21,6 +21,7 @@ type Config struct { SnapshotsDir string RuntimeDir string SocketPath string + HTTPAddr string EgressInterface string FirecrackerBinaryPath string JailerBinaryPath string @@ -38,6 +39,7 @@ func Load() (Config, error) { SnapshotsDir: filepath.Join(rootDir, "snapshots"), RuntimeDir: filepath.Join(rootDir, "runtime"), SocketPath: filepath.Join(rootDir, defaultSocketName), + HTTPAddr: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_HTTP_ADDR")), EgressInterface: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_EGRESS_INTERFACE")), FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")), JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")), diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 46dd9e5..117af30 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -3,6 +3,7 @@ package daemon import ( "context" "fmt" + "net" "os" "sync" "time" @@ -45,6 +46,9 @@ type Daemon struct { locksMu sync.Mutex machineLocks map[contracthost.MachineID]*sync.Mutex artifactLocks map[string]*sync.Mutex + + publishedPortsMu sync.Mutex + publishedPortListeners map[contracthost.PublishedPortID]net.Listener } func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, error) { @@ -69,6 +73,7 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err reconfigureGuestIdentity: nil, machineLocks: make(map[contracthost.MachineID]*sync.Mutex), artifactLocks: make(map[string]*sync.Mutex), + publishedPortListeners: make(map[contracthost.PublishedPortID]net.Listener), } daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH if err := daemon.ensureBackendSSHKeyPair(); err != nil { diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index e347b69..3d33345 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -73,9 +73,13 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { } sshListener := listenTestPort(t, int(defaultSSHPort)) - defer sshListener.Close() + defer func() { + _ = sshListener.Close() + }() vncListener := listenTestPort(t, int(defaultVNCPort)) - defer vncListener.Close() + defer func() { + _ = vncListener.Close() + }() startedAt := time.Unix(1700000005, 0).UTC() runtime := &fakeRuntime{ @@ -339,9 +343,13 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) { } sshListener := listenTestPort(t, int(defaultSSHPort)) - defer sshListener.Close() + defer func() { + _ = sshListener.Close() + }() vncListener := listenTestPort(t, int(defaultVNCPort)) - defer vncListener.Close() + defer func() { + _ = vncListener.Close() + }() startedAt := time.Unix(1700000099, 0).UTC() runtime := &fakeRuntime{ diff --git a/internal/daemon/files.go b/internal/daemon/files.go index 472363b..b518f54 100644 --- a/internal/daemon/files.go +++ b/internal/daemon/files.go @@ -53,7 +53,9 @@ func cloneFile(source string, target string) error { if err != nil { return fmt.Errorf("open source file %q: %w", source, err) } - defer sourceFile.Close() + defer func() { + _ = sourceFile.Close() + }() sourceInfo, err := sourceFile.Stat() if err != nil { @@ -67,15 +69,15 @@ func cloneFile(source string, target string) error { } if _, err := writeSparseFile(targetFile, sourceFile); err != nil { - targetFile.Close() + _ = targetFile.Close() return fmt.Errorf("copy %q to %q: %w", source, tmpPath, err) } if err := targetFile.Truncate(sourceInfo.Size()); err != nil { - targetFile.Close() + _ = targetFile.Close() return fmt.Errorf("truncate target file %q: %w", tmpPath, err) } if err := targetFile.Sync(); err != nil { - targetFile.Close() + _ = targetFile.Close() return fmt.Errorf("sync target file %q: %w", tmpPath, err) } if err := targetFile.Close(); err != nil { @@ -108,7 +110,9 @@ func downloadFile(ctx context.Context, rawURL string, path string) error { if err != nil { return fmt.Errorf("download %q: %w", rawURL, err) } - defer response.Body.Close() + defer func() { + _ = response.Body.Close() + }() if response.StatusCode != http.StatusOK { return fmt.Errorf("download %q: status %d", rawURL, response.StatusCode) } @@ -121,15 +125,15 @@ func downloadFile(ctx context.Context, rawURL string, path string) error { size, err := writeSparseFile(file, response.Body) if err != nil { - file.Close() + _ = file.Close() return fmt.Errorf("write download target %q: %w", tmpPath, err) } if err := file.Truncate(size); err != nil { - file.Close() + _ = file.Close() return fmt.Errorf("truncate download target %q: %w", tmpPath, err) } if err := file.Sync(); err != nil { - file.Close() + _ = file.Close() return fmt.Errorf("sync download target %q: %w", tmpPath, err) } if err := file.Close(); err != nil { @@ -267,7 +271,9 @@ func injectGuestConfig(ctx context.Context, imagePath string, config *contractho if err != nil { return fmt.Errorf("create guest config staging dir: %w", err) } - defer os.RemoveAll(stagingDir) + defer func() { + _ = os.RemoveAll(stagingDir) + }() if len(config.AuthorizedKeys) > 0 { authorizedKeysPath := filepath.Join(stagingDir, "authorized_keys") @@ -306,7 +312,9 @@ func injectMachineIdentity(ctx context.Context, imagePath string, machineID cont if err != nil { return fmt.Errorf("create machine identity staging dir: %w", err) } - defer os.RemoveAll(stagingDir) + defer func() { + _ = os.RemoveAll(stagingDir) + }() identityFiles := map[string]string{ "/etc/microagent/machine-name": machineName + "\n", @@ -368,6 +376,18 @@ func machineToContract(record model.MachineRecord) contracthost.Machine { } } +func publishedPortToContract(record model.PublishedPortRecord) contracthost.PublishedPort { + return contracthost.PublishedPort{ + ID: record.ID, + MachineID: record.MachineID, + Name: record.Name, + Port: record.Port, + HostPort: record.HostPort, + Protocol: record.Protocol, + CreatedAt: record.CreatedAt, + } +} + func machineToRuntimeState(record model.MachineRecord) firecracker.MachineState { phase := firecracker.PhaseStopped switch record.Phase { @@ -426,6 +446,13 @@ func validateMachineID(machineID contracthost.MachineID) error { return nil } +func validateSnapshotID(snapshotID contracthost.SnapshotID) error { + if strings.TrimSpace(string(snapshotID)) == "" { + return fmt.Errorf("snapshot_id is required") + } + return nil +} + func validateDownloadURL(field string, raw string) error { value := strings.TrimSpace(raw) if value == "" { @@ -450,7 +477,7 @@ func syncDir(path string) error { return fmt.Errorf("open dir %q: %w", path, err) } if err := dir.Sync(); err != nil { - dir.Close() + _ = dir.Close() return fmt.Errorf("sync dir %q: %w", path, err) } if err := dir.Close(); err != nil { diff --git a/internal/daemon/files_test.go b/internal/daemon/files_test.go index fe197a6..f9458f8 100644 --- a/internal/daemon/files_test.go +++ b/internal/daemon/files_test.go @@ -19,15 +19,15 @@ func TestCloneFilePreservesSparseDiskUsage(t *testing.T) { t.Fatalf("open source file: %v", err) } if _, err := sourceFile.Write([]byte("head")); err != nil { - sourceFile.Close() + _ = sourceFile.Close() t.Fatalf("write source prefix: %v", err) } if _, err := sourceFile.Seek(32<<20, io.SeekStart); err != nil { - sourceFile.Close() + _ = sourceFile.Close() t.Fatalf("seek source hole: %v", err) } if _, err := sourceFile.Write([]byte("tail")); err != nil { - sourceFile.Close() + _ = sourceFile.Close() t.Fatalf("write source suffix: %v", err) } if err := sourceFile.Close(); err != nil { diff --git a/internal/daemon/lifecycle.go b/internal/daemon/lifecycle.go index 580d0d6..883f19a 100644 --- a/internal/daemon/lifecycle.go +++ b/internal/daemon/lifecycle.go @@ -39,6 +39,88 @@ func (d *Daemon) ListMachines(ctx context.Context) (*contracthost.ListMachinesRe return &contracthost.ListMachinesResponse{Machines: machines}, nil } +func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) { + unlock := d.lockMachine(id) + defer unlock() + + record, err := d.store.GetMachine(ctx, id) + if err != nil { + return nil, err + } + if record.Phase == contracthost.MachinePhaseRunning { + return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil + } + if record.Phase != contracthost.MachinePhaseStopped { + return nil, fmt.Errorf("machine %q is not startable from phase %q", id, record.Phase) + } + + if err := d.store.UpsertOperation(ctx, model.OperationRecord{ + MachineID: id, + Type: model.MachineOperationStart, + StartedAt: time.Now().UTC(), + }); err != nil { + return nil, err + } + + clearOperation := false + defer func() { + if clearOperation { + _ = d.store.DeleteOperation(context.Background(), id) + } + }() + + systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID) + if err != nil { + return nil, err + } + artifact, err := d.store.GetArtifact(ctx, record.Artifact) + if err != nil { + return nil, err + } + userVolumes, err := d.loadAttachableUserVolumes(ctx, id, record.UserVolumeIDs) + if err != nil { + return nil, err + } + spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path) + if err != nil { + return nil, err + } + usedNetworks, err := d.listRunningNetworks(ctx, id) + if err != nil { + return nil, err + } + state, err := d.runtime.Boot(ctx, spec, usedNetworks) + if err != nil { + return nil, err + } + ports := defaultMachinePorts() + if err := waitForGuestReady(ctx, state.RuntimeHost, ports); err != nil { + _ = d.runtime.Delete(context.Background(), *state) + return nil, err + } + + record.RuntimeHost = state.RuntimeHost + record.TapDevice = state.TapName + record.Ports = ports + record.Phase = contracthost.MachinePhaseRunning + record.Error = "" + record.PID = state.PID + record.SocketPath = state.SocketPath + record.StartedAt = state.StartedAt + if err := d.store.UpdateMachine(ctx, *record); err != nil { + _ = d.runtime.Delete(context.Background(), *state) + return nil, err + } + if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil { + d.stopPublishedPortsForMachine(id) + _ = d.runtime.Delete(context.Background(), *state) + return nil, err + } + + clearOperation = true + return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil +} + func (d *Daemon) StopMachine(ctx context.Context, id contracthost.MachineID) error { unlock := d.lockMachine(id) defer unlock() @@ -120,6 +202,10 @@ func (d *Daemon) Reconcile(ctx context.Context) error { if err := d.reconcileCreate(ctx, operation.MachineID); err != nil { return err } + case model.MachineOperationStart: + if err := d.reconcileStart(ctx, operation.MachineID); err != nil { + return err + } case model.MachineOperationStop: if err := d.reconcileStop(ctx, operation.MachineID); err != nil { return err @@ -149,6 +235,13 @@ func (d *Daemon) Reconcile(ctx context.Context) error { if _, err := d.reconcileMachine(ctx, record.ID); err != nil { return err } + if record.Phase == contracthost.MachinePhaseRunning { + if err := d.ensurePublishedPortsForMachine(ctx, record); err != nil { + return err + } + } else { + d.stopPublishedPortsForMachine(record.ID) + } } return nil } @@ -218,6 +311,26 @@ func (d *Daemon) reconcileStop(ctx context.Context, machineID contracthost.Machi return d.store.DeleteOperation(ctx, machineID) } +func (d *Daemon) reconcileStart(ctx context.Context, machineID contracthost.MachineID) error { + record, err := d.store.GetMachine(ctx, machineID) + if err == store.ErrNotFound { + return d.store.DeleteOperation(ctx, machineID) + } + if err != nil { + return err + } + if record.Phase == contracthost.MachinePhaseRunning { + if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil { + return err + } + return d.store.DeleteOperation(ctx, machineID) + } + if _, err := d.StartMachine(ctx, machineID); err != nil { + return err + } + return d.store.DeleteOperation(ctx, machineID) +} + func (d *Daemon) reconcileDelete(ctx context.Context, machineID contracthost.MachineID) error { record, err := d.store.GetMachine(ctx, machineID) if err == store.ErrNotFound { @@ -266,6 +379,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma if err := d.runtime.Delete(ctx, *state); err != nil { return nil, err } + d.stopPublishedPortsForMachine(record.ID) record.Phase = contracthost.MachinePhaseFailed record.Error = state.Error record.PID = 0 @@ -280,9 +394,15 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma } func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error { + d.stopPublishedPortsForMachine(record.ID) if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil { return err } + if ports, err := d.store.ListPublishedPorts(ctx, record.ID); err == nil { + for _, port := range ports { + _ = d.store.DeletePublishedPort(ctx, port.ID) + } + } if err := d.detachVolumesForMachine(ctx, record.ID); err != nil { return err } @@ -304,6 +424,7 @@ func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineR } func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error { + d.stopPublishedPortsForMachine(record.ID) if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil { return err } diff --git a/internal/daemon/published_ports.go b/internal/daemon/published_ports.go new file mode 100644 index 0000000..d55bec5 --- /dev/null +++ b/internal/daemon/published_ports.go @@ -0,0 +1,246 @@ +package daemon + +import ( + "context" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + "github.com/getcompanion-ai/computer-host/internal/model" + "github.com/getcompanion-ai/computer-host/internal/store" + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +const ( + minPublishedHostPort = uint16(20000) + maxPublishedHostPort = uint16(39999) +) + +func (d *Daemon) CreatePublishedPort(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreatePublishedPortRequest) (*contracthost.CreatePublishedPortResponse, error) { + if strings.TrimSpace(string(req.PublishedPortID)) == "" { + return nil, fmt.Errorf("published_port_id is required") + } + if req.Port == 0 { + return nil, fmt.Errorf("port must be greater than zero") + } + if req.Protocol == "" { + req.Protocol = contracthost.PortProtocolTCP + } + if req.Protocol != contracthost.PortProtocolTCP { + return nil, fmt.Errorf("unsupported protocol %q", req.Protocol) + } + + unlock := d.lockMachine(machineID) + defer unlock() + + record, err := d.store.GetMachine(ctx, machineID) + if err != nil { + return nil, err + } + if record.Phase != contracthost.MachinePhaseRunning { + return nil, fmt.Errorf("machine %q is not running", machineID) + } + if _, err := d.store.GetPublishedPort(ctx, req.PublishedPortID); err == nil { + return nil, fmt.Errorf("published port %q already exists", req.PublishedPortID) + } else if err != nil && err != store.ErrNotFound { + return nil, err + } + + hostPort, err := d.allocatePublishedHostPort(ctx) + if err != nil { + return nil, err + } + + published := model.PublishedPortRecord{ + ID: req.PublishedPortID, + MachineID: machineID, + Name: strings.TrimSpace(req.Name), + Port: req.Port, + HostPort: hostPort, + Protocol: req.Protocol, + CreatedAt: time.Now().UTC(), + } + if err := d.startPublishedPortProxy(published, record.RuntimeHost); err != nil { + return nil, err + } + storeCreated := false + defer func() { + if !storeCreated { + d.stopPublishedPortProxy(req.PublishedPortID) + } + }() + + if err := d.store.CreatePublishedPort(ctx, published); err != nil { + return nil, err + } + storeCreated = true + return &contracthost.CreatePublishedPortResponse{Port: publishedPortToContract(published)}, nil +} + +func (d *Daemon) ListPublishedPorts(ctx context.Context, machineID contracthost.MachineID) (*contracthost.ListPublishedPortsResponse, error) { + ports, err := d.store.ListPublishedPorts(ctx, machineID) + if err != nil { + return nil, err + } + response := &contracthost.ListPublishedPortsResponse{Ports: make([]contracthost.PublishedPort, 0, len(ports))} + for _, port := range ports { + response.Ports = append(response.Ports, publishedPortToContract(port)) + } + return response, nil +} + +func (d *Daemon) DeletePublishedPort(ctx context.Context, machineID contracthost.MachineID, portID contracthost.PublishedPortID) error { + unlock := d.lockMachine(machineID) + defer unlock() + + record, err := d.store.GetPublishedPort(ctx, portID) + if err != nil { + if err == store.ErrNotFound { + return nil + } + return err + } + if record.MachineID != machineID { + return fmt.Errorf("published port %q does not belong to machine %q", portID, machineID) + } + d.stopPublishedPortProxy(portID) + return d.store.DeletePublishedPort(ctx, portID) +} + +func (d *Daemon) ensurePublishedPortsForMachine(ctx context.Context, machine model.MachineRecord) error { + if machine.Phase != contracthost.MachinePhaseRunning || strings.TrimSpace(machine.RuntimeHost) == "" { + return nil + } + ports, err := d.store.ListPublishedPorts(ctx, machine.ID) + if err != nil { + return err + } + for _, port := range ports { + if err := d.startPublishedPortProxy(port, machine.RuntimeHost); err != nil { + return err + } + } + return nil +} + +func (d *Daemon) stopPublishedPortsForMachine(machineID contracthost.MachineID) { + ports, err := d.store.ListPublishedPorts(context.Background(), machineID) + if err != nil { + return + } + for _, port := range ports { + d.stopPublishedPortProxy(port.ID) + } +} + +func (d *Daemon) allocatePublishedHostPort(ctx context.Context) (uint16, error) { + ports, err := d.store.ListPublishedPorts(ctx, "") + if err != nil { + return 0, err + } + used := make(map[uint16]struct{}, len(ports)) + for _, port := range ports { + used[port.HostPort] = struct{}{} + } + for hostPort := minPublishedHostPort; hostPort <= maxPublishedHostPort; hostPort++ { + if _, exists := used[hostPort]; exists { + continue + } + return hostPort, nil + } + return 0, fmt.Errorf("no published host ports are available") +} + +func (d *Daemon) startPublishedPortProxy(port model.PublishedPortRecord, runtimeHost string) error { + targetHost := strings.TrimSpace(runtimeHost) + if targetHost == "" { + return fmt.Errorf("runtime host is required for published port %q", port.ID) + } + + d.publishedPortsMu.Lock() + if _, exists := d.publishedPortListeners[port.ID]; exists { + d.publishedPortsMu.Unlock() + return nil + } + listener, err := net.Listen("tcp", ":"+strconv.Itoa(int(port.HostPort))) + if err != nil { + d.publishedPortsMu.Unlock() + return fmt.Errorf("listen on host port %d: %w", port.HostPort, err) + } + d.publishedPortListeners[port.ID] = listener + d.publishedPortsMu.Unlock() + + targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(int(port.Port))) + go d.servePublishedPortProxy(port.ID, listener, targetAddr) + return nil +} + +func (d *Daemon) servePublishedPortProxy(portID contracthost.PublishedPortID, listener net.Listener, targetAddr string) { + for { + conn, err := listener.Accept() + if err != nil { + if isClosedNetworkError(err) { + return + } + continue + } + go proxyPublishedPortConnection(conn, targetAddr) + } +} + +func proxyPublishedPortConnection(source net.Conn, targetAddr string) { + defer func() { + _ = source.Close() + }() + + target, err := net.DialTimeout("tcp", targetAddr, 5*time.Second) + if err != nil { + return + } + defer func() { + _ = target.Close() + }() + + done := make(chan struct{}, 1) + go func() { + _, _ = io.Copy(target, source) + closeWrite(target) + done <- struct{}{} + }() + + _, _ = io.Copy(source, target) + closeWrite(source) + <-done +} + +func closeWrite(conn net.Conn) { + type closeWriter interface { + CloseWrite() error + } + if value, ok := conn.(closeWriter); ok { + _ = value.CloseWrite() + } +} + +func isClosedNetworkError(err error) bool { + if err == nil { + return false + } + message := strings.ToLower(err.Error()) + return strings.Contains(message, "use of closed network connection") +} + +func (d *Daemon) stopPublishedPortProxy(portID contracthost.PublishedPortID) { + d.publishedPortsMu.Lock() + listener, ok := d.publishedPortListeners[portID] + if ok { + delete(d.publishedPortListeners, portID) + } + d.publishedPortsMu.Unlock() + if ok { + _ = listener.Close() + } +} diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go index a9da445..d868291 100644 --- a/internal/daemon/snapshot.go +++ b/internal/daemon/snapshot.go @@ -2,8 +2,6 @@ package daemon import ( "context" - "crypto/rand" - "encoding/hex" "fmt" "io" "os" @@ -18,10 +16,14 @@ import ( contracthost "github.com/getcompanion-ai/computer-host/contract" ) -func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID) (*contracthost.CreateSnapshotResponse, error) { +func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error) { unlock := d.lockMachine(machineID) defer unlock() + if err := validateSnapshotID(req.SnapshotID); err != nil { + return nil, err + } + record, err := d.store.GetMachine(ctx, machineID) if err != nil { return nil, err @@ -39,7 +41,7 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach } } - snapshotID := contracthost.SnapshotID(generateID()) + snapshotID := req.SnapshotID if err := d.store.UpsertOperation(ctx, model.OperationRecord{ MachineID: machineID, @@ -356,14 +358,6 @@ func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecra return false } -func generateID() string { - b := make([]byte, 16) - if _, err := rand.Read(b); err != nil { - panic(fmt.Sprintf("generate id: %v", err)) - } - return hex.EncodeToString(b) -} - // moveFile copies src to dst then removes src. Works across filesystem boundaries // unlike os.Rename, which is needed when moving files out of /proc//root/. func moveFile(src, dst string) error { @@ -371,7 +365,9 @@ func moveFile(src, dst string) error { if err != nil { return err } - defer in.Close() + defer func() { + _ = in.Close() + }() out, err := os.Create(dst) if err != nil { @@ -379,7 +375,7 @@ func moveFile(src, dst string) error { } if _, err := io.Copy(out, in); err != nil { - out.Close() + _ = out.Close() _ = os.Remove(dst) return err } diff --git a/internal/daemon/storage_report.go b/internal/daemon/storage_report.go new file mode 100644 index 0000000..c1ad1b9 --- /dev/null +++ b/internal/daemon/storage_report.go @@ -0,0 +1,141 @@ +package daemon + +import ( + "context" + "fmt" + "io/fs" + "os" + "path/filepath" + "time" + + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +func (d *Daemon) GetStorageReport(ctx context.Context) (*contracthost.GetStorageReportResponse, error) { + volumes, err := d.store.ListVolumes(ctx) + if err != nil { + return nil, err + } + snapshots, err := d.store.ListSnapshots(ctx) + if err != nil { + return nil, err + } + publishedPorts, err := d.store.ListPublishedPorts(ctx, "") + if err != nil { + return nil, err + } + + pools := make([]contracthost.StoragePoolUsage, 0, 5) + totalBytes := int64(0) + addPool := func(pool contracthost.StoragePool, path string) error { + bytes, err := directorySize(path) + if err != nil { + return err + } + pools = append(pools, contracthost.StoragePoolUsage{Pool: pool, Bytes: bytes}) + totalBytes += bytes + return nil + } + + for _, pool := range []struct { + name contracthost.StoragePool + path string + }{ + {name: contracthost.StoragePoolArtifacts, path: d.config.ArtifactsDir}, + {name: contracthost.StoragePoolMachineDisks, path: d.config.MachineDisksDir}, + {name: contracthost.StoragePoolSnapshots, path: d.config.SnapshotsDir}, + {name: contracthost.StoragePoolState, path: filepath.Dir(d.config.StatePath)}, + } { + if err := addPool(pool.name, pool.path); err != nil { + return nil, err + } + } + + machineUsage := make([]contracthost.MachineStorageUsage, 0, len(volumes)) + for _, volume := range volumes { + if volume.AttachedMachineID == nil || volume.Kind != contracthost.VolumeKindSystem { + continue + } + bytes, err := fileSize(volume.Path) + if err != nil { + return nil, err + } + machineUsage = append(machineUsage, contracthost.MachineStorageUsage{ + MachineID: *volume.AttachedMachineID, + SystemBytes: bytes, + }) + } + + snapshotUsage := make([]contracthost.SnapshotStorageUsage, 0, len(snapshots)) + for _, snapshot := range snapshots { + bytes, err := fileSize(snapshot.MemFilePath) + if err != nil { + return nil, err + } + stateBytes, err := fileSize(snapshot.StateFilePath) + if err != nil { + return nil, err + } + bytes += stateBytes + for _, diskPath := range snapshot.DiskPaths { + diskBytes, err := fileSize(diskPath) + if err != nil { + return nil, err + } + bytes += diskBytes + } + snapshotUsage = append(snapshotUsage, contracthost.SnapshotStorageUsage{ + SnapshotID: snapshot.ID, + Bytes: bytes, + }) + } + + return &contracthost.GetStorageReportResponse{ + Report: contracthost.StorageReport{ + GeneratedAt: time.Now().UTC(), + TotalBytes: totalBytes, + Pools: pools, + Machines: machineUsage, + Snapshots: snapshotUsage, + PublishedPorts: int64(len(publishedPorts)), + }, + }, nil +} + +func directorySize(root string) (int64, error) { + if root == "" { + return 0, nil + } + if _, err := os.Stat(root); err != nil { + if os.IsNotExist(err) { + return 0, nil + } + return 0, fmt.Errorf("stat %q: %w", root, err) + } + var total int64 + if err := filepath.WalkDir(root, func(path string, entry fs.DirEntry, err error) error { + if err != nil { + return err + } + if entry.IsDir() { + return nil + } + info, err := entry.Info() + if err != nil { + return err + } + total += info.Size() + return nil + }); err != nil { + return 0, fmt.Errorf("walk %q: %w", root, err) + } + return total, nil +} + +func fileSize(path string) (int64, error) { + info, err := os.Stat(path) + if err != nil { + return 0, fmt.Errorf("stat %q: %w", path, err) + } + return info.Size(), nil +} diff --git a/internal/firecracker/api.go b/internal/firecracker/api.go index 14b7674..d3d96b1 100644 --- a/internal/firecracker/api.go +++ b/internal/firecracker/api.go @@ -231,7 +231,9 @@ func (c *apiClient) do(ctx context.Context, method string, endpoint string, inpu if err != nil { return fmt.Errorf("do %s %s via %q: %w", method, endpoint, c.socketPath, err) } - defer response.Body.Close() + defer func() { + _ = response.Body.Close() + }() if response.StatusCode != wantStatus { return decodeFirecrackerError(method, endpoint, response) diff --git a/internal/firecracker/launch.go b/internal/firecracker/launch.go index d682cd2..3b57c87 100644 --- a/internal/firecracker/launch.go +++ b/internal/firecracker/launch.go @@ -161,7 +161,6 @@ func waitForSocket(ctx context.Context, client *apiClient, socketPath string) er lastPingErr = err continue } - lastPingErr = nil return nil } } @@ -231,13 +230,6 @@ func waitForPIDFile(ctx context.Context, pidFilePath string) (int, error) { } } -func hostVSockPath(paths machinePaths, spec MachineSpec) string { - if spec.Vsock == nil { - return "" - } - return filepath.Join(paths.ChrootRootDir, defaultFirecrackerSocketDir, filepath.Base(strings.TrimSpace(spec.Vsock.Path))) -} - func jailedVSockPath(spec MachineSpec) string { if spec.Vsock == nil { return "" diff --git a/internal/firecracker/paths.go b/internal/firecracker/paths.go index 0f02261..65119cb 100644 --- a/internal/firecracker/paths.go +++ b/internal/firecracker/paths.go @@ -67,18 +67,3 @@ func buildMachinePaths(rootDir string, id MachineID, firecrackerBinaryPath strin func procSocketPath(pid int) string { return filepath.Join("/proc", strconv.Itoa(pid), "root", defaultFirecrackerSocketDir, defaultFirecrackerSocketName) } - -type snapshotPaths struct { - BaseDir string - MemFilePath string - StateFilePath string -} - -func buildSnapshotPaths(rootDir string, id string) snapshotPaths { - baseDir := filepath.Join(rootDir, "snapshots", id) - return snapshotPaths{ - BaseDir: baseDir, - MemFilePath: filepath.Join(baseDir, "memory.bin"), - StateFilePath: filepath.Join(baseDir, "vmstate.bin"), - } -} diff --git a/internal/httpapi/handlers.go b/internal/httpapi/handlers.go index fc6ba7b..49c71db 100644 --- a/internal/httpapi/handlers.go +++ b/internal/httpapi/handlers.go @@ -14,14 +14,19 @@ type Service interface { CreateMachine(context.Context, contracthost.CreateMachineRequest) (*contracthost.CreateMachineResponse, error) GetMachine(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error) ListMachines(context.Context) (*contracthost.ListMachinesResponse, error) + StartMachine(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error) StopMachine(context.Context, contracthost.MachineID) error DeleteMachine(context.Context, contracthost.MachineID) error Health(context.Context) (*contracthost.HealthResponse, error) - CreateSnapshot(context.Context, contracthost.MachineID) (*contracthost.CreateSnapshotResponse, error) + GetStorageReport(context.Context) (*contracthost.GetStorageReportResponse, error) + CreateSnapshot(context.Context, contracthost.MachineID, contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error) ListSnapshots(context.Context, contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error) GetSnapshot(context.Context, contracthost.SnapshotID) (*contracthost.GetSnapshotResponse, error) DeleteSnapshotByID(context.Context, contracthost.SnapshotID) error RestoreSnapshot(context.Context, contracthost.SnapshotID, contracthost.RestoreSnapshotRequest) (*contracthost.RestoreSnapshotResponse, error) + CreatePublishedPort(context.Context, contracthost.MachineID, contracthost.CreatePublishedPortRequest) (*contracthost.CreatePublishedPortResponse, error) + ListPublishedPorts(context.Context, contracthost.MachineID) (*contracthost.ListPublishedPortsResponse, error) + DeletePublishedPort(context.Context, contracthost.MachineID, contracthost.PublishedPortID) error } type Handler struct { @@ -38,6 +43,7 @@ func New(service Service) (*Handler, error) { func (h *Handler) Routes() http.Handler { mux := http.NewServeMux() mux.HandleFunc("/health", h.handleHealth) + mux.HandleFunc("/storage/report", h.handleStorageReport) mux.HandleFunc("/machines", h.handleMachines) mux.HandleFunc("/machines/", h.handleMachine) mux.HandleFunc("/snapshots/", h.handleSnapshot) @@ -57,6 +63,19 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, response) } +func (h *Handler) handleStorageReport(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeMethodNotAllowed(w) + return + } + response, err := h.service.GetStorageReport(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, err) + return + } + writeJSON(w, http.StatusOK, response) +} + func (h *Handler) handleMachines(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: @@ -126,6 +145,20 @@ func (h *Handler) handleMachine(w http.ResponseWriter, r *http.Request) { return } + if len(parts) == 2 && parts[1] == "start" { + if r.Method != http.MethodPost { + writeMethodNotAllowed(w) + return + } + response, err := h.service.StartMachine(r.Context(), machineID) + if err != nil { + writeError(w, statusForError(err), err) + return + } + writeJSON(w, http.StatusOK, response) + return + } + if len(parts) == 2 && parts[1] == "snapshots" { switch r.Method { case http.MethodGet: @@ -136,7 +169,12 @@ func (h *Handler) handleMachine(w http.ResponseWriter, r *http.Request) { } writeJSON(w, http.StatusOK, response) case http.MethodPost: - response, err := h.service.CreateSnapshot(r.Context(), machineID) + var request contracthost.CreateSnapshotRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + response, err := h.service.CreateSnapshot(r.Context(), machineID, request) if err != nil { writeError(w, statusForError(err), err) return @@ -148,6 +186,46 @@ func (h *Handler) handleMachine(w http.ResponseWriter, r *http.Request) { return } + if len(parts) == 2 && parts[1] == "published-ports" { + switch r.Method { + case http.MethodGet: + response, err := h.service.ListPublishedPorts(r.Context(), machineID) + if err != nil { + writeError(w, statusForError(err), err) + return + } + writeJSON(w, http.StatusOK, response) + case http.MethodPost: + var request contracthost.CreatePublishedPortRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + response, err := h.service.CreatePublishedPort(r.Context(), machineID, request) + if err != nil { + writeError(w, statusForError(err), err) + return + } + writeJSON(w, http.StatusCreated, response) + default: + writeMethodNotAllowed(w) + } + return + } + + if len(parts) == 3 && parts[1] == "published-ports" { + if r.Method != http.MethodDelete { + writeMethodNotAllowed(w) + return + } + if err := h.service.DeletePublishedPort(r.Context(), machineID, contracthost.PublishedPortID(parts[2])); err != nil { + writeError(w, statusForError(err), err) + return + } + w.WriteHeader(http.StatusNoContent) + return + } + writeError(w, http.StatusNotFound, fmt.Errorf("route not found")) } diff --git a/internal/model/types.go b/internal/model/types.go index a473af6..5845767 100644 --- a/internal/model/types.go +++ b/internal/model/types.go @@ -9,10 +9,12 @@ import ( type StoragePool string const ( - StoragePoolArtifacts StoragePool = "artifacts" - StoragePoolMachineDisks StoragePool = "machine-disks" - StoragePoolState StoragePool = "state" - StoragePoolUserVolumes StoragePool = "user-volumes" + StoragePoolArtifacts StoragePool = "artifacts" + StoragePoolMachineDisks StoragePool = "machine-disks" + StoragePoolPublishedPorts StoragePool = "published-ports" + StoragePoolSnapshots StoragePool = "snapshots" + StoragePoolState StoragePool = "state" + StoragePoolUserVolumes StoragePool = "user-volumes" ) type ArtifactRecord struct { @@ -54,6 +56,7 @@ type MachineOperation string const ( MachineOperationCreate MachineOperation = "create" + MachineOperationStart MachineOperation = "start" MachineOperationStop MachineOperation = "stop" MachineOperationDelete MachineOperation = "delete" MachineOperationSnapshot MachineOperation = "snapshot" @@ -72,6 +75,16 @@ type SnapshotRecord struct { CreatedAt time.Time } +type PublishedPortRecord struct { + ID contracthost.PublishedPortID + MachineID contracthost.MachineID + Name string + Port uint16 + HostPort uint16 + Protocol contracthost.PortProtocol + CreatedAt time.Time +} + type OperationRecord struct { MachineID contracthost.MachineID Type MachineOperation diff --git a/internal/store/file_store.go b/internal/store/file_store.go index aa00feb..bc21e14 100644 --- a/internal/store/file_store.go +++ b/internal/store/file_store.go @@ -23,10 +23,11 @@ type persistedOperations struct { } type persistedState struct { - Artifacts []model.ArtifactRecord `json:"artifacts"` - Machines []model.MachineRecord `json:"machines"` - Volumes []model.VolumeRecord `json:"volumes"` - Snapshots []model.SnapshotRecord `json:"snapshots"` + Artifacts []model.ArtifactRecord `json:"artifacts"` + Machines []model.MachineRecord `json:"machines"` + Volumes []model.VolumeRecord `json:"volumes"` + Snapshots []model.SnapshotRecord `json:"snapshots"` + PublishedPorts []model.PublishedPortRecord `json:"published_ports"` } func NewFileStore(statePath string, operationsPath string) (*FileStore, error) { @@ -327,6 +328,17 @@ func (s *FileStore) ListSnapshotsByMachine(_ context.Context, machineID contract return result, nil } +func (s *FileStore) ListSnapshots(_ context.Context) ([]model.SnapshotRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + state, err := s.readState() + if err != nil { + return nil, err + } + return append([]model.SnapshotRecord(nil), state.Snapshots...), nil +} + func (s *FileStore) DeleteSnapshot(_ context.Context, id contracthost.SnapshotID) error { s.mu.Lock() defer s.mu.Unlock() @@ -342,6 +354,74 @@ func (s *FileStore) DeleteSnapshot(_ context.Context, id contracthost.SnapshotID }) } +func (s *FileStore) CreatePublishedPort(_ context.Context, record model.PublishedPortRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateState(func(state *persistedState) error { + for _, port := range state.PublishedPorts { + if port.ID == record.ID { + return fmt.Errorf("store: published port %q already exists", record.ID) + } + if port.HostPort == record.HostPort { + return fmt.Errorf("store: host port %d already exists", record.HostPort) + } + } + state.PublishedPorts = append(state.PublishedPorts, record) + return nil + }) +} + +func (s *FileStore) GetPublishedPort(_ context.Context, id contracthost.PublishedPortID) (*model.PublishedPortRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + state, err := s.readState() + if err != nil { + return nil, err + } + for i := range state.PublishedPorts { + if state.PublishedPorts[i].ID == id { + record := state.PublishedPorts[i] + return &record, nil + } + } + return nil, ErrNotFound +} + +func (s *FileStore) ListPublishedPorts(_ context.Context, machineID contracthost.MachineID) ([]model.PublishedPortRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + state, err := s.readState() + if err != nil { + return nil, err + } + result := make([]model.PublishedPortRecord, 0, len(state.PublishedPorts)) + for _, port := range state.PublishedPorts { + if machineID != "" && port.MachineID != machineID { + continue + } + result = append(result, port) + } + return result, nil +} + +func (s *FileStore) DeletePublishedPort(_ context.Context, id contracthost.PublishedPortID) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateState(func(state *persistedState) error { + for i := range state.PublishedPorts { + if state.PublishedPorts[i].ID == id { + state.PublishedPorts = append(state.PublishedPorts[:i], state.PublishedPorts[i+1:]...) + return nil + } + } + return ErrNotFound + }) +} + func (s *FileStore) readOperations() (*persistedOperations, error) { var operations persistedOperations if err := readJSONFile(s.operationsPath, &operations); err != nil { @@ -422,11 +502,11 @@ func writeJSONFileAtomically(path string, value any) error { return fmt.Errorf("open temp store file %q: %w", tmpPath, err) } if _, err := file.Write(payload); err != nil { - file.Close() + _ = file.Close() return fmt.Errorf("write temp store file %q: %w", tmpPath, err) } if err := file.Sync(); err != nil { - file.Close() + _ = file.Close() return fmt.Errorf("sync temp store file %q: %w", tmpPath, err) } if err := file.Close(); err != nil { @@ -441,7 +521,7 @@ func writeJSONFileAtomically(path string, value any) error { return fmt.Errorf("open store dir for %q: %w", path, err) } if err := dir.Sync(); err != nil { - dir.Close() + _ = dir.Close() return fmt.Errorf("sync store dir for %q: %w", path, err) } if err := dir.Close(); err != nil { @@ -452,10 +532,11 @@ func writeJSONFileAtomically(path string, value any) error { func emptyPersistedState() persistedState { return persistedState{ - Artifacts: []model.ArtifactRecord{}, - Machines: []model.MachineRecord{}, - Volumes: []model.VolumeRecord{}, - Snapshots: []model.SnapshotRecord{}, + Artifacts: []model.ArtifactRecord{}, + Machines: []model.MachineRecord{}, + Volumes: []model.VolumeRecord{}, + Snapshots: []model.SnapshotRecord{}, + PublishedPorts: []model.PublishedPortRecord{}, } } @@ -476,6 +557,9 @@ func normalizeState(state *persistedState) { if state.Snapshots == nil { state.Snapshots = []model.SnapshotRecord{} } + if state.PublishedPorts == nil { + state.PublishedPorts = []model.PublishedPortRecord{} + } } func normalizeOperations(operations *persistedOperations) { diff --git a/internal/store/store.go b/internal/store/store.go index f8e5fdf..a4c55b1 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -29,6 +29,11 @@ type Store interface { DeleteOperation(context.Context, contracthost.MachineID) error CreateSnapshot(context.Context, model.SnapshotRecord) error GetSnapshot(context.Context, contracthost.SnapshotID) (*model.SnapshotRecord, error) + ListSnapshots(context.Context) ([]model.SnapshotRecord, error) ListSnapshotsByMachine(context.Context, contracthost.MachineID) ([]model.SnapshotRecord, error) DeleteSnapshot(context.Context, contracthost.SnapshotID) error + CreatePublishedPort(context.Context, model.PublishedPortRecord) error + GetPublishedPort(context.Context, contracthost.PublishedPortID) (*model.PublishedPortRecord, error) + ListPublishedPorts(context.Context, contracthost.MachineID) ([]model.PublishedPortRecord, error) + DeletePublishedPort(context.Context, contracthost.PublishedPortID) error } diff --git a/main.go b/main.go index 892b9fa..5eca7ea 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "fmt" "net" "net/http" @@ -10,6 +11,8 @@ import ( "path/filepath" "syscall" + "golang.org/x/sync/errgroup" + appconfig "github.com/getcompanion-ai/computer-host/internal/config" "github.com/getcompanion-ai/computer-host/internal/daemon" "github.com/getcompanion-ai/computer-host/internal/firecracker" @@ -56,19 +59,49 @@ func main() { exit(err) } - listener, err := net.Listen("unix", cfg.SocketPath) + unixListener, err := net.Listen("unix", cfg.SocketPath) if err != nil { exit(err) } - defer listener.Close() - - server := &http.Server{Handler: handler.Routes()} - go func() { - <-ctx.Done() - _ = server.Shutdown(context.Background()) + defer func() { + _ = unixListener.Close() }() - if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + servers := []*http.Server{{Handler: handler.Routes()}} + listeners := []net.Listener{unixListener} + if cfg.HTTPAddr != "" { + httpListener, err := net.Listen("tcp", cfg.HTTPAddr) + if err != nil { + exit(err) + } + defer func() { + _ = httpListener.Close() + }() + servers = append(servers, &http.Server{Handler: handler.Routes()}) + listeners = append(listeners, httpListener) + } + + group, groupCtx := errgroup.WithContext(ctx) + for i := range servers { + server := servers[i] + listener := listeners[i] + group.Go(func() error { + err := server.Serve(listener) + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err + }) + } + group.Go(func() error { + <-groupCtx.Done() + for _, server := range servers { + _ = server.Shutdown(context.Background()) + } + return nil + }) + + if err := group.Wait(); err != nil { exit(err) } }