mirror of
https://github.com/getcompanion-ai/computer-host.git
synced 2026-04-15 03:00:42 +00:00
feat(contracts): add published ports, snapshot request, and storage report types
This commit is contained in:
parent
501ae2abd5
commit
26b5d2966d
20 changed files with 893 additions and 81 deletions
30
contract/published_ports.go
Normal file
30
contract/published_ports.go
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
package host
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type PublishedPortID string
|
||||||
|
|
||||||
|
type PublishedPort struct {
|
||||||
|
ID PublishedPortID `json:"id"`
|
||||||
|
MachineID MachineID `json:"machine_id"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Port uint16 `json:"port"`
|
||||||
|
HostPort uint16 `json:"host_port"`
|
||||||
|
Protocol PortProtocol `json:"protocol"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreatePublishedPortRequest struct {
|
||||||
|
PublishedPortID PublishedPortID `json:"published_port_id"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Port uint16 `json:"port"`
|
||||||
|
Protocol PortProtocol `json:"protocol"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreatePublishedPortResponse struct {
|
||||||
|
Port PublishedPort `json:"port"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListPublishedPortsResponse struct {
|
||||||
|
Ports []PublishedPort `json:"ports"`
|
||||||
|
}
|
||||||
|
|
@ -10,6 +10,10 @@ type Snapshot struct {
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CreateSnapshotRequest struct {
|
||||||
|
SnapshotID SnapshotID `json:"snapshot_id"`
|
||||||
|
}
|
||||||
|
|
||||||
type CreateSnapshotResponse struct {
|
type CreateSnapshotResponse struct {
|
||||||
Snapshot Snapshot `json:"snapshot"`
|
Snapshot Snapshot `json:"snapshot"`
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,3 +13,43 @@ type Volume struct {
|
||||||
AttachedMachineID *MachineID `json:"attached_machine_id,omitempty"`
|
AttachedMachineID *MachineID `json:"attached_machine_id,omitempty"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StoragePool string
|
||||||
|
|
||||||
|
const (
|
||||||
|
StoragePoolArtifacts StoragePool = "artifacts"
|
||||||
|
StoragePoolMachineDisks StoragePool = "machine-disks"
|
||||||
|
StoragePoolPublishedPort StoragePool = "published-ports"
|
||||||
|
StoragePoolSnapshots StoragePool = "snapshots"
|
||||||
|
StoragePoolState StoragePool = "state"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StoragePoolUsage struct {
|
||||||
|
Pool StoragePool `json:"pool"`
|
||||||
|
Bytes int64 `json:"bytes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MachineStorageUsage struct {
|
||||||
|
MachineID MachineID `json:"machine_id"`
|
||||||
|
SystemBytes int64 `json:"system_bytes"`
|
||||||
|
UserBytes int64 `json:"user_bytes"`
|
||||||
|
RuntimeBytes int64 `json:"runtime_bytes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SnapshotStorageUsage struct {
|
||||||
|
SnapshotID SnapshotID `json:"snapshot_id"`
|
||||||
|
Bytes int64 `json:"bytes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type StorageReport struct {
|
||||||
|
GeneratedAt time.Time `json:"generated_at"`
|
||||||
|
TotalBytes int64 `json:"total_bytes"`
|
||||||
|
Pools []StoragePoolUsage `json:"pools,omitempty"`
|
||||||
|
Machines []MachineStorageUsage `json:"machines,omitempty"`
|
||||||
|
Snapshots []SnapshotStorageUsage `json:"snapshots,omitempty"`
|
||||||
|
PublishedPorts int64 `json:"published_ports"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetStorageReportResponse struct {
|
||||||
|
Report StorageReport `json:"report"`
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ type Config struct {
|
||||||
SnapshotsDir string
|
SnapshotsDir string
|
||||||
RuntimeDir string
|
RuntimeDir string
|
||||||
SocketPath string
|
SocketPath string
|
||||||
|
HTTPAddr string
|
||||||
EgressInterface string
|
EgressInterface string
|
||||||
FirecrackerBinaryPath string
|
FirecrackerBinaryPath string
|
||||||
JailerBinaryPath string
|
JailerBinaryPath string
|
||||||
|
|
@ -38,6 +39,7 @@ func Load() (Config, error) {
|
||||||
SnapshotsDir: filepath.Join(rootDir, "snapshots"),
|
SnapshotsDir: filepath.Join(rootDir, "snapshots"),
|
||||||
RuntimeDir: filepath.Join(rootDir, "runtime"),
|
RuntimeDir: filepath.Join(rootDir, "runtime"),
|
||||||
SocketPath: filepath.Join(rootDir, defaultSocketName),
|
SocketPath: filepath.Join(rootDir, defaultSocketName),
|
||||||
|
HTTPAddr: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_HTTP_ADDR")),
|
||||||
EgressInterface: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_EGRESS_INTERFACE")),
|
EgressInterface: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_EGRESS_INTERFACE")),
|
||||||
FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")),
|
FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")),
|
||||||
JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")),
|
JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")),
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package daemon
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -45,6 +46,9 @@ type Daemon struct {
|
||||||
locksMu sync.Mutex
|
locksMu sync.Mutex
|
||||||
machineLocks map[contracthost.MachineID]*sync.Mutex
|
machineLocks map[contracthost.MachineID]*sync.Mutex
|
||||||
artifactLocks map[string]*sync.Mutex
|
artifactLocks map[string]*sync.Mutex
|
||||||
|
|
||||||
|
publishedPortsMu sync.Mutex
|
||||||
|
publishedPortListeners map[contracthost.PublishedPortID]net.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, error) {
|
func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, error) {
|
||||||
|
|
@ -69,6 +73,7 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
|
||||||
reconfigureGuestIdentity: nil,
|
reconfigureGuestIdentity: nil,
|
||||||
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
|
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
|
||||||
artifactLocks: make(map[string]*sync.Mutex),
|
artifactLocks: make(map[string]*sync.Mutex),
|
||||||
|
publishedPortListeners: make(map[contracthost.PublishedPortID]net.Listener),
|
||||||
}
|
}
|
||||||
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
|
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
|
||||||
if err := daemon.ensureBackendSSHKeyPair(); err != nil {
|
if err := daemon.ensureBackendSSHKeyPair(); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -73,9 +73,13 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
sshListener := listenTestPort(t, int(defaultSSHPort))
|
sshListener := listenTestPort(t, int(defaultSSHPort))
|
||||||
defer sshListener.Close()
|
defer func() {
|
||||||
|
_ = sshListener.Close()
|
||||||
|
}()
|
||||||
vncListener := listenTestPort(t, int(defaultVNCPort))
|
vncListener := listenTestPort(t, int(defaultVNCPort))
|
||||||
defer vncListener.Close()
|
defer func() {
|
||||||
|
_ = vncListener.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
startedAt := time.Unix(1700000005, 0).UTC()
|
startedAt := time.Unix(1700000005, 0).UTC()
|
||||||
runtime := &fakeRuntime{
|
runtime := &fakeRuntime{
|
||||||
|
|
@ -339,9 +343,13 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
sshListener := listenTestPort(t, int(defaultSSHPort))
|
sshListener := listenTestPort(t, int(defaultSSHPort))
|
||||||
defer sshListener.Close()
|
defer func() {
|
||||||
|
_ = sshListener.Close()
|
||||||
|
}()
|
||||||
vncListener := listenTestPort(t, int(defaultVNCPort))
|
vncListener := listenTestPort(t, int(defaultVNCPort))
|
||||||
defer vncListener.Close()
|
defer func() {
|
||||||
|
_ = vncListener.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
startedAt := time.Unix(1700000099, 0).UTC()
|
startedAt := time.Unix(1700000099, 0).UTC()
|
||||||
runtime := &fakeRuntime{
|
runtime := &fakeRuntime{
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,9 @@ func cloneFile(source string, target string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("open source file %q: %w", source, err)
|
return fmt.Errorf("open source file %q: %w", source, err)
|
||||||
}
|
}
|
||||||
defer sourceFile.Close()
|
defer func() {
|
||||||
|
_ = sourceFile.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
sourceInfo, err := sourceFile.Stat()
|
sourceInfo, err := sourceFile.Stat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -67,15 +69,15 @@ func cloneFile(source string, target string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := writeSparseFile(targetFile, sourceFile); err != nil {
|
if _, err := writeSparseFile(targetFile, sourceFile); err != nil {
|
||||||
targetFile.Close()
|
_ = targetFile.Close()
|
||||||
return fmt.Errorf("copy %q to %q: %w", source, tmpPath, err)
|
return fmt.Errorf("copy %q to %q: %w", source, tmpPath, err)
|
||||||
}
|
}
|
||||||
if err := targetFile.Truncate(sourceInfo.Size()); err != nil {
|
if err := targetFile.Truncate(sourceInfo.Size()); err != nil {
|
||||||
targetFile.Close()
|
_ = targetFile.Close()
|
||||||
return fmt.Errorf("truncate target file %q: %w", tmpPath, err)
|
return fmt.Errorf("truncate target file %q: %w", tmpPath, err)
|
||||||
}
|
}
|
||||||
if err := targetFile.Sync(); err != nil {
|
if err := targetFile.Sync(); err != nil {
|
||||||
targetFile.Close()
|
_ = targetFile.Close()
|
||||||
return fmt.Errorf("sync target file %q: %w", tmpPath, err)
|
return fmt.Errorf("sync target file %q: %w", tmpPath, err)
|
||||||
}
|
}
|
||||||
if err := targetFile.Close(); err != nil {
|
if err := targetFile.Close(); err != nil {
|
||||||
|
|
@ -108,7 +110,9 @@ func downloadFile(ctx context.Context, rawURL string, path string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("download %q: %w", rawURL, err)
|
return fmt.Errorf("download %q: %w", rawURL, err)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer func() {
|
||||||
|
_ = response.Body.Close()
|
||||||
|
}()
|
||||||
if response.StatusCode != http.StatusOK {
|
if response.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("download %q: status %d", rawURL, response.StatusCode)
|
return fmt.Errorf("download %q: status %d", rawURL, response.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
@ -121,15 +125,15 @@ func downloadFile(ctx context.Context, rawURL string, path string) error {
|
||||||
|
|
||||||
size, err := writeSparseFile(file, response.Body)
|
size, err := writeSparseFile(file, response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
file.Close()
|
_ = file.Close()
|
||||||
return fmt.Errorf("write download target %q: %w", tmpPath, err)
|
return fmt.Errorf("write download target %q: %w", tmpPath, err)
|
||||||
}
|
}
|
||||||
if err := file.Truncate(size); err != nil {
|
if err := file.Truncate(size); err != nil {
|
||||||
file.Close()
|
_ = file.Close()
|
||||||
return fmt.Errorf("truncate download target %q: %w", tmpPath, err)
|
return fmt.Errorf("truncate download target %q: %w", tmpPath, err)
|
||||||
}
|
}
|
||||||
if err := file.Sync(); err != nil {
|
if err := file.Sync(); err != nil {
|
||||||
file.Close()
|
_ = file.Close()
|
||||||
return fmt.Errorf("sync download target %q: %w", tmpPath, err)
|
return fmt.Errorf("sync download target %q: %w", tmpPath, err)
|
||||||
}
|
}
|
||||||
if err := file.Close(); err != nil {
|
if err := file.Close(); err != nil {
|
||||||
|
|
@ -267,7 +271,9 @@ func injectGuestConfig(ctx context.Context, imagePath string, config *contractho
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create guest config staging dir: %w", err)
|
return fmt.Errorf("create guest config staging dir: %w", err)
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(stagingDir)
|
defer func() {
|
||||||
|
_ = os.RemoveAll(stagingDir)
|
||||||
|
}()
|
||||||
|
|
||||||
if len(config.AuthorizedKeys) > 0 {
|
if len(config.AuthorizedKeys) > 0 {
|
||||||
authorizedKeysPath := filepath.Join(stagingDir, "authorized_keys")
|
authorizedKeysPath := filepath.Join(stagingDir, "authorized_keys")
|
||||||
|
|
@ -306,7 +312,9 @@ func injectMachineIdentity(ctx context.Context, imagePath string, machineID cont
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create machine identity staging dir: %w", err)
|
return fmt.Errorf("create machine identity staging dir: %w", err)
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(stagingDir)
|
defer func() {
|
||||||
|
_ = os.RemoveAll(stagingDir)
|
||||||
|
}()
|
||||||
|
|
||||||
identityFiles := map[string]string{
|
identityFiles := map[string]string{
|
||||||
"/etc/microagent/machine-name": machineName + "\n",
|
"/etc/microagent/machine-name": machineName + "\n",
|
||||||
|
|
@ -368,6 +376,18 @@ func machineToContract(record model.MachineRecord) contracthost.Machine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func publishedPortToContract(record model.PublishedPortRecord) contracthost.PublishedPort {
|
||||||
|
return contracthost.PublishedPort{
|
||||||
|
ID: record.ID,
|
||||||
|
MachineID: record.MachineID,
|
||||||
|
Name: record.Name,
|
||||||
|
Port: record.Port,
|
||||||
|
HostPort: record.HostPort,
|
||||||
|
Protocol: record.Protocol,
|
||||||
|
CreatedAt: record.CreatedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func machineToRuntimeState(record model.MachineRecord) firecracker.MachineState {
|
func machineToRuntimeState(record model.MachineRecord) firecracker.MachineState {
|
||||||
phase := firecracker.PhaseStopped
|
phase := firecracker.PhaseStopped
|
||||||
switch record.Phase {
|
switch record.Phase {
|
||||||
|
|
@ -426,6 +446,13 @@ func validateMachineID(machineID contracthost.MachineID) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateSnapshotID(snapshotID contracthost.SnapshotID) error {
|
||||||
|
if strings.TrimSpace(string(snapshotID)) == "" {
|
||||||
|
return fmt.Errorf("snapshot_id is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func validateDownloadURL(field string, raw string) error {
|
func validateDownloadURL(field string, raw string) error {
|
||||||
value := strings.TrimSpace(raw)
|
value := strings.TrimSpace(raw)
|
||||||
if value == "" {
|
if value == "" {
|
||||||
|
|
@ -450,7 +477,7 @@ func syncDir(path string) error {
|
||||||
return fmt.Errorf("open dir %q: %w", path, err)
|
return fmt.Errorf("open dir %q: %w", path, err)
|
||||||
}
|
}
|
||||||
if err := dir.Sync(); err != nil {
|
if err := dir.Sync(); err != nil {
|
||||||
dir.Close()
|
_ = dir.Close()
|
||||||
return fmt.Errorf("sync dir %q: %w", path, err)
|
return fmt.Errorf("sync dir %q: %w", path, err)
|
||||||
}
|
}
|
||||||
if err := dir.Close(); err != nil {
|
if err := dir.Close(); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -19,15 +19,15 @@ func TestCloneFilePreservesSparseDiskUsage(t *testing.T) {
|
||||||
t.Fatalf("open source file: %v", err)
|
t.Fatalf("open source file: %v", err)
|
||||||
}
|
}
|
||||||
if _, err := sourceFile.Write([]byte("head")); err != nil {
|
if _, err := sourceFile.Write([]byte("head")); err != nil {
|
||||||
sourceFile.Close()
|
_ = sourceFile.Close()
|
||||||
t.Fatalf("write source prefix: %v", err)
|
t.Fatalf("write source prefix: %v", err)
|
||||||
}
|
}
|
||||||
if _, err := sourceFile.Seek(32<<20, io.SeekStart); err != nil {
|
if _, err := sourceFile.Seek(32<<20, io.SeekStart); err != nil {
|
||||||
sourceFile.Close()
|
_ = sourceFile.Close()
|
||||||
t.Fatalf("seek source hole: %v", err)
|
t.Fatalf("seek source hole: %v", err)
|
||||||
}
|
}
|
||||||
if _, err := sourceFile.Write([]byte("tail")); err != nil {
|
if _, err := sourceFile.Write([]byte("tail")); err != nil {
|
||||||
sourceFile.Close()
|
_ = sourceFile.Close()
|
||||||
t.Fatalf("write source suffix: %v", err)
|
t.Fatalf("write source suffix: %v", err)
|
||||||
}
|
}
|
||||||
if err := sourceFile.Close(); err != nil {
|
if err := sourceFile.Close(); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,88 @@ func (d *Daemon) ListMachines(ctx context.Context) (*contracthost.ListMachinesRe
|
||||||
return &contracthost.ListMachinesResponse{Machines: machines}, nil
|
return &contracthost.ListMachinesResponse{Machines: machines}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) {
|
||||||
|
unlock := d.lockMachine(id)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
record, err := d.store.GetMachine(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if record.Phase == contracthost.MachinePhaseRunning {
|
||||||
|
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
|
||||||
|
}
|
||||||
|
if record.Phase != contracthost.MachinePhaseStopped {
|
||||||
|
return nil, fmt.Errorf("machine %q is not startable from phase %q", id, record.Phase)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.store.UpsertOperation(ctx, model.OperationRecord{
|
||||||
|
MachineID: id,
|
||||||
|
Type: model.MachineOperationStart,
|
||||||
|
StartedAt: time.Now().UTC(),
|
||||||
|
}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
clearOperation := false
|
||||||
|
defer func() {
|
||||||
|
if clearOperation {
|
||||||
|
_ = d.store.DeleteOperation(context.Background(), id)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
artifact, err := d.store.GetArtifact(ctx, record.Artifact)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userVolumes, err := d.loadAttachableUserVolumes(ctx, id, record.UserVolumeIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
usedNetworks, err := d.listRunningNetworks(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
state, err := d.runtime.Boot(ctx, spec, usedNetworks)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ports := defaultMachinePorts()
|
||||||
|
if err := waitForGuestReady(ctx, state.RuntimeHost, ports); err != nil {
|
||||||
|
_ = d.runtime.Delete(context.Background(), *state)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
record.RuntimeHost = state.RuntimeHost
|
||||||
|
record.TapDevice = state.TapName
|
||||||
|
record.Ports = ports
|
||||||
|
record.Phase = contracthost.MachinePhaseRunning
|
||||||
|
record.Error = ""
|
||||||
|
record.PID = state.PID
|
||||||
|
record.SocketPath = state.SocketPath
|
||||||
|
record.StartedAt = state.StartedAt
|
||||||
|
if err := d.store.UpdateMachine(ctx, *record); err != nil {
|
||||||
|
_ = d.runtime.Delete(context.Background(), *state)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
||||||
|
d.stopPublishedPortsForMachine(id)
|
||||||
|
_ = d.runtime.Delete(context.Background(), *state)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
clearOperation = true
|
||||||
|
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Daemon) StopMachine(ctx context.Context, id contracthost.MachineID) error {
|
func (d *Daemon) StopMachine(ctx context.Context, id contracthost.MachineID) error {
|
||||||
unlock := d.lockMachine(id)
|
unlock := d.lockMachine(id)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
@ -120,6 +202,10 @@ func (d *Daemon) Reconcile(ctx context.Context) error {
|
||||||
if err := d.reconcileCreate(ctx, operation.MachineID); err != nil {
|
if err := d.reconcileCreate(ctx, operation.MachineID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
case model.MachineOperationStart:
|
||||||
|
if err := d.reconcileStart(ctx, operation.MachineID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
case model.MachineOperationStop:
|
case model.MachineOperationStop:
|
||||||
if err := d.reconcileStop(ctx, operation.MachineID); err != nil {
|
if err := d.reconcileStop(ctx, operation.MachineID); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -149,6 +235,13 @@ func (d *Daemon) Reconcile(ctx context.Context) error {
|
||||||
if _, err := d.reconcileMachine(ctx, record.ID); err != nil {
|
if _, err := d.reconcileMachine(ctx, record.ID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if record.Phase == contracthost.MachinePhaseRunning {
|
||||||
|
if err := d.ensurePublishedPortsForMachine(ctx, record); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
d.stopPublishedPortsForMachine(record.ID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -218,6 +311,26 @@ func (d *Daemon) reconcileStop(ctx context.Context, machineID contracthost.Machi
|
||||||
return d.store.DeleteOperation(ctx, machineID)
|
return d.store.DeleteOperation(ctx, machineID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) reconcileStart(ctx context.Context, machineID contracthost.MachineID) error {
|
||||||
|
record, err := d.store.GetMachine(ctx, machineID)
|
||||||
|
if err == store.ErrNotFound {
|
||||||
|
return d.store.DeleteOperation(ctx, machineID)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if record.Phase == contracthost.MachinePhaseRunning {
|
||||||
|
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return d.store.DeleteOperation(ctx, machineID)
|
||||||
|
}
|
||||||
|
if _, err := d.StartMachine(ctx, machineID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return d.store.DeleteOperation(ctx, machineID)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Daemon) reconcileDelete(ctx context.Context, machineID contracthost.MachineID) error {
|
func (d *Daemon) reconcileDelete(ctx context.Context, machineID contracthost.MachineID) error {
|
||||||
record, err := d.store.GetMachine(ctx, machineID)
|
record, err := d.store.GetMachine(ctx, machineID)
|
||||||
if err == store.ErrNotFound {
|
if err == store.ErrNotFound {
|
||||||
|
|
@ -266,6 +379,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
|
||||||
if err := d.runtime.Delete(ctx, *state); err != nil {
|
if err := d.runtime.Delete(ctx, *state); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
d.stopPublishedPortsForMachine(record.ID)
|
||||||
record.Phase = contracthost.MachinePhaseFailed
|
record.Phase = contracthost.MachinePhaseFailed
|
||||||
record.Error = state.Error
|
record.Error = state.Error
|
||||||
record.PID = 0
|
record.PID = 0
|
||||||
|
|
@ -280,9 +394,15 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error {
|
func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error {
|
||||||
|
d.stopPublishedPortsForMachine(record.ID)
|
||||||
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
|
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if ports, err := d.store.ListPublishedPorts(ctx, record.ID); err == nil {
|
||||||
|
for _, port := range ports {
|
||||||
|
_ = d.store.DeletePublishedPort(ctx, port.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
if err := d.detachVolumesForMachine(ctx, record.ID); err != nil {
|
if err := d.detachVolumesForMachine(ctx, record.ID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -304,6 +424,7 @@ func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineR
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error {
|
func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error {
|
||||||
|
d.stopPublishedPortsForMachine(record.ID)
|
||||||
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
|
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
246
internal/daemon/published_ports.go
Normal file
246
internal/daemon/published_ports.go
Normal file
|
|
@ -0,0 +1,246 @@
|
||||||
|
package daemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
||||||
|
"github.com/getcompanion-ai/computer-host/internal/store"
|
||||||
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
minPublishedHostPort = uint16(20000)
|
||||||
|
maxPublishedHostPort = uint16(39999)
|
||||||
|
)
|
||||||
|
|
||||||
|
func (d *Daemon) CreatePublishedPort(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreatePublishedPortRequest) (*contracthost.CreatePublishedPortResponse, error) {
|
||||||
|
if strings.TrimSpace(string(req.PublishedPortID)) == "" {
|
||||||
|
return nil, fmt.Errorf("published_port_id is required")
|
||||||
|
}
|
||||||
|
if req.Port == 0 {
|
||||||
|
return nil, fmt.Errorf("port must be greater than zero")
|
||||||
|
}
|
||||||
|
if req.Protocol == "" {
|
||||||
|
req.Protocol = contracthost.PortProtocolTCP
|
||||||
|
}
|
||||||
|
if req.Protocol != contracthost.PortProtocolTCP {
|
||||||
|
return nil, fmt.Errorf("unsupported protocol %q", req.Protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
unlock := d.lockMachine(machineID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
record, err := d.store.GetMachine(ctx, machineID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if record.Phase != contracthost.MachinePhaseRunning {
|
||||||
|
return nil, fmt.Errorf("machine %q is not running", machineID)
|
||||||
|
}
|
||||||
|
if _, err := d.store.GetPublishedPort(ctx, req.PublishedPortID); err == nil {
|
||||||
|
return nil, fmt.Errorf("published port %q already exists", req.PublishedPortID)
|
||||||
|
} else if err != nil && err != store.ErrNotFound {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hostPort, err := d.allocatePublishedHostPort(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
published := model.PublishedPortRecord{
|
||||||
|
ID: req.PublishedPortID,
|
||||||
|
MachineID: machineID,
|
||||||
|
Name: strings.TrimSpace(req.Name),
|
||||||
|
Port: req.Port,
|
||||||
|
HostPort: hostPort,
|
||||||
|
Protocol: req.Protocol,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
if err := d.startPublishedPortProxy(published, record.RuntimeHost); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
storeCreated := false
|
||||||
|
defer func() {
|
||||||
|
if !storeCreated {
|
||||||
|
d.stopPublishedPortProxy(req.PublishedPortID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := d.store.CreatePublishedPort(ctx, published); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
storeCreated = true
|
||||||
|
return &contracthost.CreatePublishedPortResponse{Port: publishedPortToContract(published)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) ListPublishedPorts(ctx context.Context, machineID contracthost.MachineID) (*contracthost.ListPublishedPortsResponse, error) {
|
||||||
|
ports, err := d.store.ListPublishedPorts(ctx, machineID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
response := &contracthost.ListPublishedPortsResponse{Ports: make([]contracthost.PublishedPort, 0, len(ports))}
|
||||||
|
for _, port := range ports {
|
||||||
|
response.Ports = append(response.Ports, publishedPortToContract(port))
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) DeletePublishedPort(ctx context.Context, machineID contracthost.MachineID, portID contracthost.PublishedPortID) error {
|
||||||
|
unlock := d.lockMachine(machineID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
record, err := d.store.GetPublishedPort(ctx, portID)
|
||||||
|
if err != nil {
|
||||||
|
if err == store.ErrNotFound {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if record.MachineID != machineID {
|
||||||
|
return fmt.Errorf("published port %q does not belong to machine %q", portID, machineID)
|
||||||
|
}
|
||||||
|
d.stopPublishedPortProxy(portID)
|
||||||
|
return d.store.DeletePublishedPort(ctx, portID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) ensurePublishedPortsForMachine(ctx context.Context, machine model.MachineRecord) error {
|
||||||
|
if machine.Phase != contracthost.MachinePhaseRunning || strings.TrimSpace(machine.RuntimeHost) == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ports, err := d.store.ListPublishedPorts(ctx, machine.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, port := range ports {
|
||||||
|
if err := d.startPublishedPortProxy(port, machine.RuntimeHost); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) stopPublishedPortsForMachine(machineID contracthost.MachineID) {
|
||||||
|
ports, err := d.store.ListPublishedPorts(context.Background(), machineID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, port := range ports {
|
||||||
|
d.stopPublishedPortProxy(port.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) allocatePublishedHostPort(ctx context.Context) (uint16, error) {
|
||||||
|
ports, err := d.store.ListPublishedPorts(ctx, "")
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
used := make(map[uint16]struct{}, len(ports))
|
||||||
|
for _, port := range ports {
|
||||||
|
used[port.HostPort] = struct{}{}
|
||||||
|
}
|
||||||
|
for hostPort := minPublishedHostPort; hostPort <= maxPublishedHostPort; hostPort++ {
|
||||||
|
if _, exists := used[hostPort]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return hostPort, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("no published host ports are available")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) startPublishedPortProxy(port model.PublishedPortRecord, runtimeHost string) error {
|
||||||
|
targetHost := strings.TrimSpace(runtimeHost)
|
||||||
|
if targetHost == "" {
|
||||||
|
return fmt.Errorf("runtime host is required for published port %q", port.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.publishedPortsMu.Lock()
|
||||||
|
if _, exists := d.publishedPortListeners[port.ID]; exists {
|
||||||
|
d.publishedPortsMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
listener, err := net.Listen("tcp", ":"+strconv.Itoa(int(port.HostPort)))
|
||||||
|
if err != nil {
|
||||||
|
d.publishedPortsMu.Unlock()
|
||||||
|
return fmt.Errorf("listen on host port %d: %w", port.HostPort, err)
|
||||||
|
}
|
||||||
|
d.publishedPortListeners[port.ID] = listener
|
||||||
|
d.publishedPortsMu.Unlock()
|
||||||
|
|
||||||
|
targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(int(port.Port)))
|
||||||
|
go d.servePublishedPortProxy(port.ID, listener, targetAddr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) servePublishedPortProxy(portID contracthost.PublishedPortID, listener net.Listener, targetAddr string) {
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if isClosedNetworkError(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go proxyPublishedPortConnection(conn, targetAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func proxyPublishedPortConnection(source net.Conn, targetAddr string) {
|
||||||
|
defer func() {
|
||||||
|
_ = source.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
target, err := net.DialTimeout("tcp", targetAddr, 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = target.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
done := make(chan struct{}, 1)
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(target, source)
|
||||||
|
closeWrite(target)
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, _ = io.Copy(source, target)
|
||||||
|
closeWrite(source)
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeWrite(conn net.Conn) {
|
||||||
|
type closeWriter interface {
|
||||||
|
CloseWrite() error
|
||||||
|
}
|
||||||
|
if value, ok := conn.(closeWriter); ok {
|
||||||
|
_ = value.CloseWrite()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isClosedNetworkError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
message := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(message, "use of closed network connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) stopPublishedPortProxy(portID contracthost.PublishedPortID) {
|
||||||
|
d.publishedPortsMu.Lock()
|
||||||
|
listener, ok := d.publishedPortListeners[portID]
|
||||||
|
if ok {
|
||||||
|
delete(d.publishedPortListeners, portID)
|
||||||
|
}
|
||||||
|
d.publishedPortsMu.Unlock()
|
||||||
|
if ok {
|
||||||
|
_ = listener.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -2,8 +2,6 @@ package daemon
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
|
@ -18,10 +16,14 @@ import (
|
||||||
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID) (*contracthost.CreateSnapshotResponse, error) {
|
func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error) {
|
||||||
unlock := d.lockMachine(machineID)
|
unlock := d.lockMachine(machineID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
|
if err := validateSnapshotID(req.SnapshotID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
record, err := d.store.GetMachine(ctx, machineID)
|
record, err := d.store.GetMachine(ctx, machineID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -39,7 +41,7 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
snapshotID := contracthost.SnapshotID(generateID())
|
snapshotID := req.SnapshotID
|
||||||
|
|
||||||
if err := d.store.UpsertOperation(ctx, model.OperationRecord{
|
if err := d.store.UpsertOperation(ctx, model.OperationRecord{
|
||||||
MachineID: machineID,
|
MachineID: machineID,
|
||||||
|
|
@ -356,14 +358,6 @@ func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecra
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateID() string {
|
|
||||||
b := make([]byte, 16)
|
|
||||||
if _, err := rand.Read(b); err != nil {
|
|
||||||
panic(fmt.Sprintf("generate id: %v", err))
|
|
||||||
}
|
|
||||||
return hex.EncodeToString(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// moveFile copies src to dst then removes src. Works across filesystem boundaries
|
// moveFile copies src to dst then removes src. Works across filesystem boundaries
|
||||||
// unlike os.Rename, which is needed when moving files out of /proc/<pid>/root/.
|
// unlike os.Rename, which is needed when moving files out of /proc/<pid>/root/.
|
||||||
func moveFile(src, dst string) error {
|
func moveFile(src, dst string) error {
|
||||||
|
|
@ -371,7 +365,9 @@ func moveFile(src, dst string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer in.Close()
|
defer func() {
|
||||||
|
_ = in.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
out, err := os.Create(dst)
|
out, err := os.Create(dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -379,7 +375,7 @@ func moveFile(src, dst string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := io.Copy(out, in); err != nil {
|
if _, err := io.Copy(out, in); err != nil {
|
||||||
out.Close()
|
_ = out.Close()
|
||||||
_ = os.Remove(dst)
|
_ = os.Remove(dst)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
141
internal/daemon/storage_report.go
Normal file
141
internal/daemon/storage_report.go
Normal file
|
|
@ -0,0 +1,141 @@
|
||||||
|
package daemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (d *Daemon) GetStorageReport(ctx context.Context) (*contracthost.GetStorageReportResponse, error) {
|
||||||
|
volumes, err := d.store.ListVolumes(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
snapshots, err := d.store.ListSnapshots(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
publishedPorts, err := d.store.ListPublishedPorts(ctx, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pools := make([]contracthost.StoragePoolUsage, 0, 5)
|
||||||
|
totalBytes := int64(0)
|
||||||
|
addPool := func(pool contracthost.StoragePool, path string) error {
|
||||||
|
bytes, err := directorySize(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pools = append(pools, contracthost.StoragePoolUsage{Pool: pool, Bytes: bytes})
|
||||||
|
totalBytes += bytes
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pool := range []struct {
|
||||||
|
name contracthost.StoragePool
|
||||||
|
path string
|
||||||
|
}{
|
||||||
|
{name: contracthost.StoragePoolArtifacts, path: d.config.ArtifactsDir},
|
||||||
|
{name: contracthost.StoragePoolMachineDisks, path: d.config.MachineDisksDir},
|
||||||
|
{name: contracthost.StoragePoolSnapshots, path: d.config.SnapshotsDir},
|
||||||
|
{name: contracthost.StoragePoolState, path: filepath.Dir(d.config.StatePath)},
|
||||||
|
} {
|
||||||
|
if err := addPool(pool.name, pool.path); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
machineUsage := make([]contracthost.MachineStorageUsage, 0, len(volumes))
|
||||||
|
for _, volume := range volumes {
|
||||||
|
if volume.AttachedMachineID == nil || volume.Kind != contracthost.VolumeKindSystem {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
bytes, err := fileSize(volume.Path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
machineUsage = append(machineUsage, contracthost.MachineStorageUsage{
|
||||||
|
MachineID: *volume.AttachedMachineID,
|
||||||
|
SystemBytes: bytes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
snapshotUsage := make([]contracthost.SnapshotStorageUsage, 0, len(snapshots))
|
||||||
|
for _, snapshot := range snapshots {
|
||||||
|
bytes, err := fileSize(snapshot.MemFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateBytes, err := fileSize(snapshot.StateFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bytes += stateBytes
|
||||||
|
for _, diskPath := range snapshot.DiskPaths {
|
||||||
|
diskBytes, err := fileSize(diskPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bytes += diskBytes
|
||||||
|
}
|
||||||
|
snapshotUsage = append(snapshotUsage, contracthost.SnapshotStorageUsage{
|
||||||
|
SnapshotID: snapshot.ID,
|
||||||
|
Bytes: bytes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &contracthost.GetStorageReportResponse{
|
||||||
|
Report: contracthost.StorageReport{
|
||||||
|
GeneratedAt: time.Now().UTC(),
|
||||||
|
TotalBytes: totalBytes,
|
||||||
|
Pools: pools,
|
||||||
|
Machines: machineUsage,
|
||||||
|
Snapshots: snapshotUsage,
|
||||||
|
PublishedPorts: int64(len(publishedPorts)),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func directorySize(root string) (int64, error) {
|
||||||
|
if root == "" {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(root); err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("stat %q: %w", root, err)
|
||||||
|
}
|
||||||
|
var total int64
|
||||||
|
if err := filepath.WalkDir(root, func(path string, entry fs.DirEntry, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if entry.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
info, err := entry.Info()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
total += info.Size()
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
return 0, fmt.Errorf("walk %q: %w", root, err)
|
||||||
|
}
|
||||||
|
return total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileSize(path string) (int64, error) {
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("stat %q: %w", path, err)
|
||||||
|
}
|
||||||
|
return info.Size(), nil
|
||||||
|
}
|
||||||
|
|
@ -231,7 +231,9 @@ func (c *apiClient) do(ctx context.Context, method string, endpoint string, inpu
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("do %s %s via %q: %w", method, endpoint, c.socketPath, err)
|
return fmt.Errorf("do %s %s via %q: %w", method, endpoint, c.socketPath, err)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer func() {
|
||||||
|
_ = response.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
if response.StatusCode != wantStatus {
|
if response.StatusCode != wantStatus {
|
||||||
return decodeFirecrackerError(method, endpoint, response)
|
return decodeFirecrackerError(method, endpoint, response)
|
||||||
|
|
|
||||||
|
|
@ -161,7 +161,6 @@ func waitForSocket(ctx context.Context, client *apiClient, socketPath string) er
|
||||||
lastPingErr = err
|
lastPingErr = err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
lastPingErr = nil
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -231,13 +230,6 @@ func waitForPIDFile(ctx context.Context, pidFilePath string) (int, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func hostVSockPath(paths machinePaths, spec MachineSpec) string {
|
|
||||||
if spec.Vsock == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return filepath.Join(paths.ChrootRootDir, defaultFirecrackerSocketDir, filepath.Base(strings.TrimSpace(spec.Vsock.Path)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func jailedVSockPath(spec MachineSpec) string {
|
func jailedVSockPath(spec MachineSpec) string {
|
||||||
if spec.Vsock == nil {
|
if spec.Vsock == nil {
|
||||||
return ""
|
return ""
|
||||||
|
|
|
||||||
|
|
@ -67,18 +67,3 @@ func buildMachinePaths(rootDir string, id MachineID, firecrackerBinaryPath strin
|
||||||
func procSocketPath(pid int) string {
|
func procSocketPath(pid int) string {
|
||||||
return filepath.Join("/proc", strconv.Itoa(pid), "root", defaultFirecrackerSocketDir, defaultFirecrackerSocketName)
|
return filepath.Join("/proc", strconv.Itoa(pid), "root", defaultFirecrackerSocketDir, defaultFirecrackerSocketName)
|
||||||
}
|
}
|
||||||
|
|
||||||
type snapshotPaths struct {
|
|
||||||
BaseDir string
|
|
||||||
MemFilePath string
|
|
||||||
StateFilePath string
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildSnapshotPaths(rootDir string, id string) snapshotPaths {
|
|
||||||
baseDir := filepath.Join(rootDir, "snapshots", id)
|
|
||||||
return snapshotPaths{
|
|
||||||
BaseDir: baseDir,
|
|
||||||
MemFilePath: filepath.Join(baseDir, "memory.bin"),
|
|
||||||
StateFilePath: filepath.Join(baseDir, "vmstate.bin"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -14,14 +14,19 @@ type Service interface {
|
||||||
CreateMachine(context.Context, contracthost.CreateMachineRequest) (*contracthost.CreateMachineResponse, error)
|
CreateMachine(context.Context, contracthost.CreateMachineRequest) (*contracthost.CreateMachineResponse, error)
|
||||||
GetMachine(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error)
|
GetMachine(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error)
|
||||||
ListMachines(context.Context) (*contracthost.ListMachinesResponse, error)
|
ListMachines(context.Context) (*contracthost.ListMachinesResponse, error)
|
||||||
|
StartMachine(context.Context, contracthost.MachineID) (*contracthost.GetMachineResponse, error)
|
||||||
StopMachine(context.Context, contracthost.MachineID) error
|
StopMachine(context.Context, contracthost.MachineID) error
|
||||||
DeleteMachine(context.Context, contracthost.MachineID) error
|
DeleteMachine(context.Context, contracthost.MachineID) error
|
||||||
Health(context.Context) (*contracthost.HealthResponse, error)
|
Health(context.Context) (*contracthost.HealthResponse, error)
|
||||||
CreateSnapshot(context.Context, contracthost.MachineID) (*contracthost.CreateSnapshotResponse, error)
|
GetStorageReport(context.Context) (*contracthost.GetStorageReportResponse, error)
|
||||||
|
CreateSnapshot(context.Context, contracthost.MachineID, contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error)
|
||||||
ListSnapshots(context.Context, contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error)
|
ListSnapshots(context.Context, contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error)
|
||||||
GetSnapshot(context.Context, contracthost.SnapshotID) (*contracthost.GetSnapshotResponse, error)
|
GetSnapshot(context.Context, contracthost.SnapshotID) (*contracthost.GetSnapshotResponse, error)
|
||||||
DeleteSnapshotByID(context.Context, contracthost.SnapshotID) error
|
DeleteSnapshotByID(context.Context, contracthost.SnapshotID) error
|
||||||
RestoreSnapshot(context.Context, contracthost.SnapshotID, contracthost.RestoreSnapshotRequest) (*contracthost.RestoreSnapshotResponse, error)
|
RestoreSnapshot(context.Context, contracthost.SnapshotID, contracthost.RestoreSnapshotRequest) (*contracthost.RestoreSnapshotResponse, error)
|
||||||
|
CreatePublishedPort(context.Context, contracthost.MachineID, contracthost.CreatePublishedPortRequest) (*contracthost.CreatePublishedPortResponse, error)
|
||||||
|
ListPublishedPorts(context.Context, contracthost.MachineID) (*contracthost.ListPublishedPortsResponse, error)
|
||||||
|
DeletePublishedPort(context.Context, contracthost.MachineID, contracthost.PublishedPortID) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
|
|
@ -38,6 +43,7 @@ func New(service Service) (*Handler, error) {
|
||||||
func (h *Handler) Routes() http.Handler {
|
func (h *Handler) Routes() http.Handler {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/health", h.handleHealth)
|
mux.HandleFunc("/health", h.handleHealth)
|
||||||
|
mux.HandleFunc("/storage/report", h.handleStorageReport)
|
||||||
mux.HandleFunc("/machines", h.handleMachines)
|
mux.HandleFunc("/machines", h.handleMachines)
|
||||||
mux.HandleFunc("/machines/", h.handleMachine)
|
mux.HandleFunc("/machines/", h.handleMachine)
|
||||||
mux.HandleFunc("/snapshots/", h.handleSnapshot)
|
mux.HandleFunc("/snapshots/", h.handleSnapshot)
|
||||||
|
|
@ -57,6 +63,19 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||||
writeJSON(w, http.StatusOK, response)
|
writeJSON(w, http.StatusOK, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleStorageReport(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
writeMethodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response, err := h.service.GetStorageReport(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, response)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) handleMachines(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) handleMachines(w http.ResponseWriter, r *http.Request) {
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
|
|
@ -126,6 +145,20 @@ func (h *Handler) handleMachine(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(parts) == 2 && parts[1] == "start" {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
writeMethodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response, err := h.service.StartMachine(r.Context(), machineID)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, statusForError(err), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, response)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if len(parts) == 2 && parts[1] == "snapshots" {
|
if len(parts) == 2 && parts[1] == "snapshots" {
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
|
|
@ -136,7 +169,12 @@ func (h *Handler) handleMachine(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
writeJSON(w, http.StatusOK, response)
|
writeJSON(w, http.StatusOK, response)
|
||||||
case http.MethodPost:
|
case http.MethodPost:
|
||||||
response, err := h.service.CreateSnapshot(r.Context(), machineID)
|
var request contracthost.CreateSnapshotRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response, err := h.service.CreateSnapshot(r.Context(), machineID, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeError(w, statusForError(err), err)
|
writeError(w, statusForError(err), err)
|
||||||
return
|
return
|
||||||
|
|
@ -148,6 +186,46 @@ func (h *Handler) handleMachine(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(parts) == 2 && parts[1] == "published-ports" {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
response, err := h.service.ListPublishedPorts(r.Context(), machineID)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, statusForError(err), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, response)
|
||||||
|
case http.MethodPost:
|
||||||
|
var request contracthost.CreatePublishedPortRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response, err := h.service.CreatePublishedPort(r.Context(), machineID, request)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, statusForError(err), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusCreated, response)
|
||||||
|
default:
|
||||||
|
writeMethodNotAllowed(w)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) == 3 && parts[1] == "published-ports" {
|
||||||
|
if r.Method != http.MethodDelete {
|
||||||
|
writeMethodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.service.DeletePublishedPort(r.Context(), machineID, contracthost.PublishedPortID(parts[2])); err != nil {
|
||||||
|
writeError(w, statusForError(err), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
writeError(w, http.StatusNotFound, fmt.Errorf("route not found"))
|
writeError(w, http.StatusNotFound, fmt.Errorf("route not found"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,12 @@ import (
|
||||||
type StoragePool string
|
type StoragePool string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
StoragePoolArtifacts StoragePool = "artifacts"
|
StoragePoolArtifacts StoragePool = "artifacts"
|
||||||
StoragePoolMachineDisks StoragePool = "machine-disks"
|
StoragePoolMachineDisks StoragePool = "machine-disks"
|
||||||
StoragePoolState StoragePool = "state"
|
StoragePoolPublishedPorts StoragePool = "published-ports"
|
||||||
StoragePoolUserVolumes StoragePool = "user-volumes"
|
StoragePoolSnapshots StoragePool = "snapshots"
|
||||||
|
StoragePoolState StoragePool = "state"
|
||||||
|
StoragePoolUserVolumes StoragePool = "user-volumes"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ArtifactRecord struct {
|
type ArtifactRecord struct {
|
||||||
|
|
@ -54,6 +56,7 @@ type MachineOperation string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MachineOperationCreate MachineOperation = "create"
|
MachineOperationCreate MachineOperation = "create"
|
||||||
|
MachineOperationStart MachineOperation = "start"
|
||||||
MachineOperationStop MachineOperation = "stop"
|
MachineOperationStop MachineOperation = "stop"
|
||||||
MachineOperationDelete MachineOperation = "delete"
|
MachineOperationDelete MachineOperation = "delete"
|
||||||
MachineOperationSnapshot MachineOperation = "snapshot"
|
MachineOperationSnapshot MachineOperation = "snapshot"
|
||||||
|
|
@ -72,6 +75,16 @@ type SnapshotRecord struct {
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PublishedPortRecord struct {
|
||||||
|
ID contracthost.PublishedPortID
|
||||||
|
MachineID contracthost.MachineID
|
||||||
|
Name string
|
||||||
|
Port uint16
|
||||||
|
HostPort uint16
|
||||||
|
Protocol contracthost.PortProtocol
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
type OperationRecord struct {
|
type OperationRecord struct {
|
||||||
MachineID contracthost.MachineID
|
MachineID contracthost.MachineID
|
||||||
Type MachineOperation
|
Type MachineOperation
|
||||||
|
|
|
||||||
|
|
@ -23,10 +23,11 @@ type persistedOperations struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type persistedState struct {
|
type persistedState struct {
|
||||||
Artifacts []model.ArtifactRecord `json:"artifacts"`
|
Artifacts []model.ArtifactRecord `json:"artifacts"`
|
||||||
Machines []model.MachineRecord `json:"machines"`
|
Machines []model.MachineRecord `json:"machines"`
|
||||||
Volumes []model.VolumeRecord `json:"volumes"`
|
Volumes []model.VolumeRecord `json:"volumes"`
|
||||||
Snapshots []model.SnapshotRecord `json:"snapshots"`
|
Snapshots []model.SnapshotRecord `json:"snapshots"`
|
||||||
|
PublishedPorts []model.PublishedPortRecord `json:"published_ports"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFileStore(statePath string, operationsPath string) (*FileStore, error) {
|
func NewFileStore(statePath string, operationsPath string) (*FileStore, error) {
|
||||||
|
|
@ -327,6 +328,17 @@ func (s *FileStore) ListSnapshotsByMachine(_ context.Context, machineID contract
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) ListSnapshots(_ context.Context) ([]model.SnapshotRecord, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
state, err := s.readState()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return append([]model.SnapshotRecord(nil), state.Snapshots...), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *FileStore) DeleteSnapshot(_ context.Context, id contracthost.SnapshotID) error {
|
func (s *FileStore) DeleteSnapshot(_ context.Context, id contracthost.SnapshotID) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
@ -342,6 +354,74 @@ func (s *FileStore) DeleteSnapshot(_ context.Context, id contracthost.SnapshotID
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) CreatePublishedPort(_ context.Context, record model.PublishedPortRecord) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
return s.updateState(func(state *persistedState) error {
|
||||||
|
for _, port := range state.PublishedPorts {
|
||||||
|
if port.ID == record.ID {
|
||||||
|
return fmt.Errorf("store: published port %q already exists", record.ID)
|
||||||
|
}
|
||||||
|
if port.HostPort == record.HostPort {
|
||||||
|
return fmt.Errorf("store: host port %d already exists", record.HostPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.PublishedPorts = append(state.PublishedPorts, record)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetPublishedPort(_ context.Context, id contracthost.PublishedPortID) (*model.PublishedPortRecord, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
state, err := s.readState()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for i := range state.PublishedPorts {
|
||||||
|
if state.PublishedPorts[i].ID == id {
|
||||||
|
record := state.PublishedPorts[i]
|
||||||
|
return &record, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) ListPublishedPorts(_ context.Context, machineID contracthost.MachineID) ([]model.PublishedPortRecord, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
state, err := s.readState()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result := make([]model.PublishedPortRecord, 0, len(state.PublishedPorts))
|
||||||
|
for _, port := range state.PublishedPorts {
|
||||||
|
if machineID != "" && port.MachineID != machineID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, port)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) DeletePublishedPort(_ context.Context, id contracthost.PublishedPortID) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
return s.updateState(func(state *persistedState) error {
|
||||||
|
for i := range state.PublishedPorts {
|
||||||
|
if state.PublishedPorts[i].ID == id {
|
||||||
|
state.PublishedPorts = append(state.PublishedPorts[:i], state.PublishedPorts[i+1:]...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ErrNotFound
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *FileStore) readOperations() (*persistedOperations, error) {
|
func (s *FileStore) readOperations() (*persistedOperations, error) {
|
||||||
var operations persistedOperations
|
var operations persistedOperations
|
||||||
if err := readJSONFile(s.operationsPath, &operations); err != nil {
|
if err := readJSONFile(s.operationsPath, &operations); err != nil {
|
||||||
|
|
@ -422,11 +502,11 @@ func writeJSONFileAtomically(path string, value any) error {
|
||||||
return fmt.Errorf("open temp store file %q: %w", tmpPath, err)
|
return fmt.Errorf("open temp store file %q: %w", tmpPath, err)
|
||||||
}
|
}
|
||||||
if _, err := file.Write(payload); err != nil {
|
if _, err := file.Write(payload); err != nil {
|
||||||
file.Close()
|
_ = file.Close()
|
||||||
return fmt.Errorf("write temp store file %q: %w", tmpPath, err)
|
return fmt.Errorf("write temp store file %q: %w", tmpPath, err)
|
||||||
}
|
}
|
||||||
if err := file.Sync(); err != nil {
|
if err := file.Sync(); err != nil {
|
||||||
file.Close()
|
_ = file.Close()
|
||||||
return fmt.Errorf("sync temp store file %q: %w", tmpPath, err)
|
return fmt.Errorf("sync temp store file %q: %w", tmpPath, err)
|
||||||
}
|
}
|
||||||
if err := file.Close(); err != nil {
|
if err := file.Close(); err != nil {
|
||||||
|
|
@ -441,7 +521,7 @@ func writeJSONFileAtomically(path string, value any) error {
|
||||||
return fmt.Errorf("open store dir for %q: %w", path, err)
|
return fmt.Errorf("open store dir for %q: %w", path, err)
|
||||||
}
|
}
|
||||||
if err := dir.Sync(); err != nil {
|
if err := dir.Sync(); err != nil {
|
||||||
dir.Close()
|
_ = dir.Close()
|
||||||
return fmt.Errorf("sync store dir for %q: %w", path, err)
|
return fmt.Errorf("sync store dir for %q: %w", path, err)
|
||||||
}
|
}
|
||||||
if err := dir.Close(); err != nil {
|
if err := dir.Close(); err != nil {
|
||||||
|
|
@ -452,10 +532,11 @@ func writeJSONFileAtomically(path string, value any) error {
|
||||||
|
|
||||||
func emptyPersistedState() persistedState {
|
func emptyPersistedState() persistedState {
|
||||||
return persistedState{
|
return persistedState{
|
||||||
Artifacts: []model.ArtifactRecord{},
|
Artifacts: []model.ArtifactRecord{},
|
||||||
Machines: []model.MachineRecord{},
|
Machines: []model.MachineRecord{},
|
||||||
Volumes: []model.VolumeRecord{},
|
Volumes: []model.VolumeRecord{},
|
||||||
Snapshots: []model.SnapshotRecord{},
|
Snapshots: []model.SnapshotRecord{},
|
||||||
|
PublishedPorts: []model.PublishedPortRecord{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -476,6 +557,9 @@ func normalizeState(state *persistedState) {
|
||||||
if state.Snapshots == nil {
|
if state.Snapshots == nil {
|
||||||
state.Snapshots = []model.SnapshotRecord{}
|
state.Snapshots = []model.SnapshotRecord{}
|
||||||
}
|
}
|
||||||
|
if state.PublishedPorts == nil {
|
||||||
|
state.PublishedPorts = []model.PublishedPortRecord{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeOperations(operations *persistedOperations) {
|
func normalizeOperations(operations *persistedOperations) {
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,11 @@ type Store interface {
|
||||||
DeleteOperation(context.Context, contracthost.MachineID) error
|
DeleteOperation(context.Context, contracthost.MachineID) error
|
||||||
CreateSnapshot(context.Context, model.SnapshotRecord) error
|
CreateSnapshot(context.Context, model.SnapshotRecord) error
|
||||||
GetSnapshot(context.Context, contracthost.SnapshotID) (*model.SnapshotRecord, error)
|
GetSnapshot(context.Context, contracthost.SnapshotID) (*model.SnapshotRecord, error)
|
||||||
|
ListSnapshots(context.Context) ([]model.SnapshotRecord, error)
|
||||||
ListSnapshotsByMachine(context.Context, contracthost.MachineID) ([]model.SnapshotRecord, error)
|
ListSnapshotsByMachine(context.Context, contracthost.MachineID) ([]model.SnapshotRecord, error)
|
||||||
DeleteSnapshot(context.Context, contracthost.SnapshotID) error
|
DeleteSnapshot(context.Context, contracthost.SnapshotID) error
|
||||||
|
CreatePublishedPort(context.Context, model.PublishedPortRecord) error
|
||||||
|
GetPublishedPort(context.Context, contracthost.PublishedPortID) (*model.PublishedPortRecord, error)
|
||||||
|
ListPublishedPorts(context.Context, contracthost.MachineID) ([]model.PublishedPortRecord, error)
|
||||||
|
DeletePublishedPort(context.Context, contracthost.PublishedPortID) error
|
||||||
}
|
}
|
||||||
|
|
|
||||||
49
main.go
49
main.go
|
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
@ -10,6 +11,8 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
appconfig "github.com/getcompanion-ai/computer-host/internal/config"
|
appconfig "github.com/getcompanion-ai/computer-host/internal/config"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/daemon"
|
"github.com/getcompanion-ai/computer-host/internal/daemon"
|
||||||
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
"github.com/getcompanion-ai/computer-host/internal/firecracker"
|
||||||
|
|
@ -56,19 +59,49 @@ func main() {
|
||||||
exit(err)
|
exit(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
listener, err := net.Listen("unix", cfg.SocketPath)
|
unixListener, err := net.Listen("unix", cfg.SocketPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
exit(err)
|
exit(err)
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer func() {
|
||||||
|
_ = unixListener.Close()
|
||||||
server := &http.Server{Handler: handler.Routes()}
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
_ = server.Shutdown(context.Background())
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
servers := []*http.Server{{Handler: handler.Routes()}}
|
||||||
|
listeners := []net.Listener{unixListener}
|
||||||
|
if cfg.HTTPAddr != "" {
|
||||||
|
httpListener, err := net.Listen("tcp", cfg.HTTPAddr)
|
||||||
|
if err != nil {
|
||||||
|
exit(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = httpListener.Close()
|
||||||
|
}()
|
||||||
|
servers = append(servers, &http.Server{Handler: handler.Routes()})
|
||||||
|
listeners = append(listeners, httpListener)
|
||||||
|
}
|
||||||
|
|
||||||
|
group, groupCtx := errgroup.WithContext(ctx)
|
||||||
|
for i := range servers {
|
||||||
|
server := servers[i]
|
||||||
|
listener := listeners[i]
|
||||||
|
group.Go(func() error {
|
||||||
|
err := server.Serve(listener)
|
||||||
|
if errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
group.Go(func() error {
|
||||||
|
<-groupCtx.Done()
|
||||||
|
for _, server := range servers {
|
||||||
|
_ = server.Shutdown(context.Background())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := group.Wait(); err != nil {
|
||||||
exit(err)
|
exit(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue