feat(contracts): add published ports, snapshot request, and storage report types

This commit is contained in:
Harivansh Rathi 2026-04-09 14:05:59 +00:00
parent 501ae2abd5
commit 26b5d2966d
20 changed files with 893 additions and 81 deletions

View file

@ -3,6 +3,7 @@ package daemon
import (
"context"
"fmt"
"net"
"os"
"sync"
"time"
@ -45,6 +46,9 @@ type Daemon struct {
locksMu sync.Mutex
machineLocks map[contracthost.MachineID]*sync.Mutex
artifactLocks map[string]*sync.Mutex
publishedPortsMu sync.Mutex
publishedPortListeners map[contracthost.PublishedPortID]net.Listener
}
func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, error) {
@ -69,6 +73,7 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
reconfigureGuestIdentity: nil,
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
artifactLocks: make(map[string]*sync.Mutex),
publishedPortListeners: make(map[contracthost.PublishedPortID]net.Listener),
}
daemon.reconfigureGuestIdentity = daemon.reconfigureGuestIdentityOverSSH
if err := daemon.ensureBackendSSHKeyPair(); err != nil {

View file

@ -73,9 +73,13 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
}
sshListener := listenTestPort(t, int(defaultSSHPort))
defer sshListener.Close()
defer func() {
_ = sshListener.Close()
}()
vncListener := listenTestPort(t, int(defaultVNCPort))
defer vncListener.Close()
defer func() {
_ = vncListener.Close()
}()
startedAt := time.Unix(1700000005, 0).UTC()
runtime := &fakeRuntime{
@ -339,9 +343,13 @@ func TestRestoreSnapshotUsesSnapshotMetadataWithoutSourceMachine(t *testing.T) {
}
sshListener := listenTestPort(t, int(defaultSSHPort))
defer sshListener.Close()
defer func() {
_ = sshListener.Close()
}()
vncListener := listenTestPort(t, int(defaultVNCPort))
defer vncListener.Close()
defer func() {
_ = vncListener.Close()
}()
startedAt := time.Unix(1700000099, 0).UTC()
runtime := &fakeRuntime{

View file

@ -53,7 +53,9 @@ func cloneFile(source string, target string) error {
if err != nil {
return fmt.Errorf("open source file %q: %w", source, err)
}
defer sourceFile.Close()
defer func() {
_ = sourceFile.Close()
}()
sourceInfo, err := sourceFile.Stat()
if err != nil {
@ -67,15 +69,15 @@ func cloneFile(source string, target string) error {
}
if _, err := writeSparseFile(targetFile, sourceFile); err != nil {
targetFile.Close()
_ = targetFile.Close()
return fmt.Errorf("copy %q to %q: %w", source, tmpPath, err)
}
if err := targetFile.Truncate(sourceInfo.Size()); err != nil {
targetFile.Close()
_ = targetFile.Close()
return fmt.Errorf("truncate target file %q: %w", tmpPath, err)
}
if err := targetFile.Sync(); err != nil {
targetFile.Close()
_ = targetFile.Close()
return fmt.Errorf("sync target file %q: %w", tmpPath, err)
}
if err := targetFile.Close(); err != nil {
@ -108,7 +110,9 @@ func downloadFile(ctx context.Context, rawURL string, path string) error {
if err != nil {
return fmt.Errorf("download %q: %w", rawURL, err)
}
defer response.Body.Close()
defer func() {
_ = response.Body.Close()
}()
if response.StatusCode != http.StatusOK {
return fmt.Errorf("download %q: status %d", rawURL, response.StatusCode)
}
@ -121,15 +125,15 @@ func downloadFile(ctx context.Context, rawURL string, path string) error {
size, err := writeSparseFile(file, response.Body)
if err != nil {
file.Close()
_ = file.Close()
return fmt.Errorf("write download target %q: %w", tmpPath, err)
}
if err := file.Truncate(size); err != nil {
file.Close()
_ = file.Close()
return fmt.Errorf("truncate download target %q: %w", tmpPath, err)
}
if err := file.Sync(); err != nil {
file.Close()
_ = file.Close()
return fmt.Errorf("sync download target %q: %w", tmpPath, err)
}
if err := file.Close(); err != nil {
@ -267,7 +271,9 @@ func injectGuestConfig(ctx context.Context, imagePath string, config *contractho
if err != nil {
return fmt.Errorf("create guest config staging dir: %w", err)
}
defer os.RemoveAll(stagingDir)
defer func() {
_ = os.RemoveAll(stagingDir)
}()
if len(config.AuthorizedKeys) > 0 {
authorizedKeysPath := filepath.Join(stagingDir, "authorized_keys")
@ -306,7 +312,9 @@ func injectMachineIdentity(ctx context.Context, imagePath string, machineID cont
if err != nil {
return fmt.Errorf("create machine identity staging dir: %w", err)
}
defer os.RemoveAll(stagingDir)
defer func() {
_ = os.RemoveAll(stagingDir)
}()
identityFiles := map[string]string{
"/etc/microagent/machine-name": machineName + "\n",
@ -368,6 +376,18 @@ func machineToContract(record model.MachineRecord) contracthost.Machine {
}
}
func publishedPortToContract(record model.PublishedPortRecord) contracthost.PublishedPort {
return contracthost.PublishedPort{
ID: record.ID,
MachineID: record.MachineID,
Name: record.Name,
Port: record.Port,
HostPort: record.HostPort,
Protocol: record.Protocol,
CreatedAt: record.CreatedAt,
}
}
func machineToRuntimeState(record model.MachineRecord) firecracker.MachineState {
phase := firecracker.PhaseStopped
switch record.Phase {
@ -426,6 +446,13 @@ func validateMachineID(machineID contracthost.MachineID) error {
return nil
}
func validateSnapshotID(snapshotID contracthost.SnapshotID) error {
if strings.TrimSpace(string(snapshotID)) == "" {
return fmt.Errorf("snapshot_id is required")
}
return nil
}
func validateDownloadURL(field string, raw string) error {
value := strings.TrimSpace(raw)
if value == "" {
@ -450,7 +477,7 @@ func syncDir(path string) error {
return fmt.Errorf("open dir %q: %w", path, err)
}
if err := dir.Sync(); err != nil {
dir.Close()
_ = dir.Close()
return fmt.Errorf("sync dir %q: %w", path, err)
}
if err := dir.Close(); err != nil {

View file

@ -19,15 +19,15 @@ func TestCloneFilePreservesSparseDiskUsage(t *testing.T) {
t.Fatalf("open source file: %v", err)
}
if _, err := sourceFile.Write([]byte("head")); err != nil {
sourceFile.Close()
_ = sourceFile.Close()
t.Fatalf("write source prefix: %v", err)
}
if _, err := sourceFile.Seek(32<<20, io.SeekStart); err != nil {
sourceFile.Close()
_ = sourceFile.Close()
t.Fatalf("seek source hole: %v", err)
}
if _, err := sourceFile.Write([]byte("tail")); err != nil {
sourceFile.Close()
_ = sourceFile.Close()
t.Fatalf("write source suffix: %v", err)
}
if err := sourceFile.Close(); err != nil {

View file

@ -39,6 +39,88 @@ func (d *Daemon) ListMachines(ctx context.Context) (*contracthost.ListMachinesRe
return &contracthost.ListMachinesResponse{Machines: machines}, nil
}
func (d *Daemon) StartMachine(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) {
unlock := d.lockMachine(id)
defer unlock()
record, err := d.store.GetMachine(ctx, id)
if err != nil {
return nil, err
}
if record.Phase == contracthost.MachinePhaseRunning {
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
}
if record.Phase != contracthost.MachinePhaseStopped {
return nil, fmt.Errorf("machine %q is not startable from phase %q", id, record.Phase)
}
if err := d.store.UpsertOperation(ctx, model.OperationRecord{
MachineID: id,
Type: model.MachineOperationStart,
StartedAt: time.Now().UTC(),
}); err != nil {
return nil, err
}
clearOperation := false
defer func() {
if clearOperation {
_ = d.store.DeleteOperation(context.Background(), id)
}
}()
systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID)
if err != nil {
return nil, err
}
artifact, err := d.store.GetArtifact(ctx, record.Artifact)
if err != nil {
return nil, err
}
userVolumes, err := d.loadAttachableUserVolumes(ctx, id, record.UserVolumeIDs)
if err != nil {
return nil, err
}
spec, err := d.buildMachineSpec(id, artifact, userVolumes, systemVolume.Path)
if err != nil {
return nil, err
}
usedNetworks, err := d.listRunningNetworks(ctx, id)
if err != nil {
return nil, err
}
state, err := d.runtime.Boot(ctx, spec, usedNetworks)
if err != nil {
return nil, err
}
ports := defaultMachinePorts()
if err := waitForGuestReady(ctx, state.RuntimeHost, ports); err != nil {
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
record.RuntimeHost = state.RuntimeHost
record.TapDevice = state.TapName
record.Ports = ports
record.Phase = contracthost.MachinePhaseRunning
record.Error = ""
record.PID = state.PID
record.SocketPath = state.SocketPath
record.StartedAt = state.StartedAt
if err := d.store.UpdateMachine(ctx, *record); err != nil {
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
d.stopPublishedPortsForMachine(id)
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
clearOperation = true
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
}
func (d *Daemon) StopMachine(ctx context.Context, id contracthost.MachineID) error {
unlock := d.lockMachine(id)
defer unlock()
@ -120,6 +202,10 @@ func (d *Daemon) Reconcile(ctx context.Context) error {
if err := d.reconcileCreate(ctx, operation.MachineID); err != nil {
return err
}
case model.MachineOperationStart:
if err := d.reconcileStart(ctx, operation.MachineID); err != nil {
return err
}
case model.MachineOperationStop:
if err := d.reconcileStop(ctx, operation.MachineID); err != nil {
return err
@ -149,6 +235,13 @@ func (d *Daemon) Reconcile(ctx context.Context) error {
if _, err := d.reconcileMachine(ctx, record.ID); err != nil {
return err
}
if record.Phase == contracthost.MachinePhaseRunning {
if err := d.ensurePublishedPortsForMachine(ctx, record); err != nil {
return err
}
} else {
d.stopPublishedPortsForMachine(record.ID)
}
}
return nil
}
@ -218,6 +311,26 @@ func (d *Daemon) reconcileStop(ctx context.Context, machineID contracthost.Machi
return d.store.DeleteOperation(ctx, machineID)
}
func (d *Daemon) reconcileStart(ctx context.Context, machineID contracthost.MachineID) error {
record, err := d.store.GetMachine(ctx, machineID)
if err == store.ErrNotFound {
return d.store.DeleteOperation(ctx, machineID)
}
if err != nil {
return err
}
if record.Phase == contracthost.MachinePhaseRunning {
if err := d.ensurePublishedPortsForMachine(ctx, *record); err != nil {
return err
}
return d.store.DeleteOperation(ctx, machineID)
}
if _, err := d.StartMachine(ctx, machineID); err != nil {
return err
}
return d.store.DeleteOperation(ctx, machineID)
}
func (d *Daemon) reconcileDelete(ctx context.Context, machineID contracthost.MachineID) error {
record, err := d.store.GetMachine(ctx, machineID)
if err == store.ErrNotFound {
@ -266,6 +379,7 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
if err := d.runtime.Delete(ctx, *state); err != nil {
return nil, err
}
d.stopPublishedPortsForMachine(record.ID)
record.Phase = contracthost.MachinePhaseFailed
record.Error = state.Error
record.PID = 0
@ -280,9 +394,15 @@ func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.Ma
}
func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error {
d.stopPublishedPortsForMachine(record.ID)
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
return err
}
if ports, err := d.store.ListPublishedPorts(ctx, record.ID); err == nil {
for _, port := range ports {
_ = d.store.DeletePublishedPort(ctx, port.ID)
}
}
if err := d.detachVolumesForMachine(ctx, record.ID); err != nil {
return err
}
@ -304,6 +424,7 @@ func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineR
}
func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error {
d.stopPublishedPortsForMachine(record.ID)
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
return err
}

View file

@ -0,0 +1,246 @@
package daemon
import (
"context"
"fmt"
"io"
"net"
"strconv"
"strings"
"time"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
const (
minPublishedHostPort = uint16(20000)
maxPublishedHostPort = uint16(39999)
)
func (d *Daemon) CreatePublishedPort(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreatePublishedPortRequest) (*contracthost.CreatePublishedPortResponse, error) {
if strings.TrimSpace(string(req.PublishedPortID)) == "" {
return nil, fmt.Errorf("published_port_id is required")
}
if req.Port == 0 {
return nil, fmt.Errorf("port must be greater than zero")
}
if req.Protocol == "" {
req.Protocol = contracthost.PortProtocolTCP
}
if req.Protocol != contracthost.PortProtocolTCP {
return nil, fmt.Errorf("unsupported protocol %q", req.Protocol)
}
unlock := d.lockMachine(machineID)
defer unlock()
record, err := d.store.GetMachine(ctx, machineID)
if err != nil {
return nil, err
}
if record.Phase != contracthost.MachinePhaseRunning {
return nil, fmt.Errorf("machine %q is not running", machineID)
}
if _, err := d.store.GetPublishedPort(ctx, req.PublishedPortID); err == nil {
return nil, fmt.Errorf("published port %q already exists", req.PublishedPortID)
} else if err != nil && err != store.ErrNotFound {
return nil, err
}
hostPort, err := d.allocatePublishedHostPort(ctx)
if err != nil {
return nil, err
}
published := model.PublishedPortRecord{
ID: req.PublishedPortID,
MachineID: machineID,
Name: strings.TrimSpace(req.Name),
Port: req.Port,
HostPort: hostPort,
Protocol: req.Protocol,
CreatedAt: time.Now().UTC(),
}
if err := d.startPublishedPortProxy(published, record.RuntimeHost); err != nil {
return nil, err
}
storeCreated := false
defer func() {
if !storeCreated {
d.stopPublishedPortProxy(req.PublishedPortID)
}
}()
if err := d.store.CreatePublishedPort(ctx, published); err != nil {
return nil, err
}
storeCreated = true
return &contracthost.CreatePublishedPortResponse{Port: publishedPortToContract(published)}, nil
}
func (d *Daemon) ListPublishedPorts(ctx context.Context, machineID contracthost.MachineID) (*contracthost.ListPublishedPortsResponse, error) {
ports, err := d.store.ListPublishedPorts(ctx, machineID)
if err != nil {
return nil, err
}
response := &contracthost.ListPublishedPortsResponse{Ports: make([]contracthost.PublishedPort, 0, len(ports))}
for _, port := range ports {
response.Ports = append(response.Ports, publishedPortToContract(port))
}
return response, nil
}
func (d *Daemon) DeletePublishedPort(ctx context.Context, machineID contracthost.MachineID, portID contracthost.PublishedPortID) error {
unlock := d.lockMachine(machineID)
defer unlock()
record, err := d.store.GetPublishedPort(ctx, portID)
if err != nil {
if err == store.ErrNotFound {
return nil
}
return err
}
if record.MachineID != machineID {
return fmt.Errorf("published port %q does not belong to machine %q", portID, machineID)
}
d.stopPublishedPortProxy(portID)
return d.store.DeletePublishedPort(ctx, portID)
}
func (d *Daemon) ensurePublishedPortsForMachine(ctx context.Context, machine model.MachineRecord) error {
if machine.Phase != contracthost.MachinePhaseRunning || strings.TrimSpace(machine.RuntimeHost) == "" {
return nil
}
ports, err := d.store.ListPublishedPorts(ctx, machine.ID)
if err != nil {
return err
}
for _, port := range ports {
if err := d.startPublishedPortProxy(port, machine.RuntimeHost); err != nil {
return err
}
}
return nil
}
func (d *Daemon) stopPublishedPortsForMachine(machineID contracthost.MachineID) {
ports, err := d.store.ListPublishedPorts(context.Background(), machineID)
if err != nil {
return
}
for _, port := range ports {
d.stopPublishedPortProxy(port.ID)
}
}
func (d *Daemon) allocatePublishedHostPort(ctx context.Context) (uint16, error) {
ports, err := d.store.ListPublishedPorts(ctx, "")
if err != nil {
return 0, err
}
used := make(map[uint16]struct{}, len(ports))
for _, port := range ports {
used[port.HostPort] = struct{}{}
}
for hostPort := minPublishedHostPort; hostPort <= maxPublishedHostPort; hostPort++ {
if _, exists := used[hostPort]; exists {
continue
}
return hostPort, nil
}
return 0, fmt.Errorf("no published host ports are available")
}
func (d *Daemon) startPublishedPortProxy(port model.PublishedPortRecord, runtimeHost string) error {
targetHost := strings.TrimSpace(runtimeHost)
if targetHost == "" {
return fmt.Errorf("runtime host is required for published port %q", port.ID)
}
d.publishedPortsMu.Lock()
if _, exists := d.publishedPortListeners[port.ID]; exists {
d.publishedPortsMu.Unlock()
return nil
}
listener, err := net.Listen("tcp", ":"+strconv.Itoa(int(port.HostPort)))
if err != nil {
d.publishedPortsMu.Unlock()
return fmt.Errorf("listen on host port %d: %w", port.HostPort, err)
}
d.publishedPortListeners[port.ID] = listener
d.publishedPortsMu.Unlock()
targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(int(port.Port)))
go d.servePublishedPortProxy(port.ID, listener, targetAddr)
return nil
}
func (d *Daemon) servePublishedPortProxy(portID contracthost.PublishedPortID, listener net.Listener, targetAddr string) {
for {
conn, err := listener.Accept()
if err != nil {
if isClosedNetworkError(err) {
return
}
continue
}
go proxyPublishedPortConnection(conn, targetAddr)
}
}
func proxyPublishedPortConnection(source net.Conn, targetAddr string) {
defer func() {
_ = source.Close()
}()
target, err := net.DialTimeout("tcp", targetAddr, 5*time.Second)
if err != nil {
return
}
defer func() {
_ = target.Close()
}()
done := make(chan struct{}, 1)
go func() {
_, _ = io.Copy(target, source)
closeWrite(target)
done <- struct{}{}
}()
_, _ = io.Copy(source, target)
closeWrite(source)
<-done
}
func closeWrite(conn net.Conn) {
type closeWriter interface {
CloseWrite() error
}
if value, ok := conn.(closeWriter); ok {
_ = value.CloseWrite()
}
}
func isClosedNetworkError(err error) bool {
if err == nil {
return false
}
message := strings.ToLower(err.Error())
return strings.Contains(message, "use of closed network connection")
}
func (d *Daemon) stopPublishedPortProxy(portID contracthost.PublishedPortID) {
d.publishedPortsMu.Lock()
listener, ok := d.publishedPortListeners[portID]
if ok {
delete(d.publishedPortListeners, portID)
}
d.publishedPortsMu.Unlock()
if ok {
_ = listener.Close()
}
}

View file

@ -2,8 +2,6 @@ package daemon
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"os"
@ -18,10 +16,14 @@ import (
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID) (*contracthost.CreateSnapshotResponse, error) {
func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error) {
unlock := d.lockMachine(machineID)
defer unlock()
if err := validateSnapshotID(req.SnapshotID); err != nil {
return nil, err
}
record, err := d.store.GetMachine(ctx, machineID)
if err != nil {
return nil, err
@ -39,7 +41,7 @@ func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.Mach
}
}
snapshotID := contracthost.SnapshotID(generateID())
snapshotID := req.SnapshotID
if err := d.store.UpsertOperation(ctx, model.OperationRecord{
MachineID: machineID,
@ -356,14 +358,6 @@ func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecra
return false
}
func generateID() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("generate id: %v", err))
}
return hex.EncodeToString(b)
}
// moveFile copies src to dst then removes src. Works across filesystem boundaries
// unlike os.Rename, which is needed when moving files out of /proc/<pid>/root/.
func moveFile(src, dst string) error {
@ -371,7 +365,9 @@ func moveFile(src, dst string) error {
if err != nil {
return err
}
defer in.Close()
defer func() {
_ = in.Close()
}()
out, err := os.Create(dst)
if err != nil {
@ -379,7 +375,7 @@ func moveFile(src, dst string) error {
}
if _, err := io.Copy(out, in); err != nil {
out.Close()
_ = out.Close()
_ = os.Remove(dst)
return err
}

View file

@ -0,0 +1,141 @@
package daemon
import (
"context"
"fmt"
"io/fs"
"os"
"path/filepath"
"time"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func (d *Daemon) GetStorageReport(ctx context.Context) (*contracthost.GetStorageReportResponse, error) {
volumes, err := d.store.ListVolumes(ctx)
if err != nil {
return nil, err
}
snapshots, err := d.store.ListSnapshots(ctx)
if err != nil {
return nil, err
}
publishedPorts, err := d.store.ListPublishedPorts(ctx, "")
if err != nil {
return nil, err
}
pools := make([]contracthost.StoragePoolUsage, 0, 5)
totalBytes := int64(0)
addPool := func(pool contracthost.StoragePool, path string) error {
bytes, err := directorySize(path)
if err != nil {
return err
}
pools = append(pools, contracthost.StoragePoolUsage{Pool: pool, Bytes: bytes})
totalBytes += bytes
return nil
}
for _, pool := range []struct {
name contracthost.StoragePool
path string
}{
{name: contracthost.StoragePoolArtifacts, path: d.config.ArtifactsDir},
{name: contracthost.StoragePoolMachineDisks, path: d.config.MachineDisksDir},
{name: contracthost.StoragePoolSnapshots, path: d.config.SnapshotsDir},
{name: contracthost.StoragePoolState, path: filepath.Dir(d.config.StatePath)},
} {
if err := addPool(pool.name, pool.path); err != nil {
return nil, err
}
}
machineUsage := make([]contracthost.MachineStorageUsage, 0, len(volumes))
for _, volume := range volumes {
if volume.AttachedMachineID == nil || volume.Kind != contracthost.VolumeKindSystem {
continue
}
bytes, err := fileSize(volume.Path)
if err != nil {
return nil, err
}
machineUsage = append(machineUsage, contracthost.MachineStorageUsage{
MachineID: *volume.AttachedMachineID,
SystemBytes: bytes,
})
}
snapshotUsage := make([]contracthost.SnapshotStorageUsage, 0, len(snapshots))
for _, snapshot := range snapshots {
bytes, err := fileSize(snapshot.MemFilePath)
if err != nil {
return nil, err
}
stateBytes, err := fileSize(snapshot.StateFilePath)
if err != nil {
return nil, err
}
bytes += stateBytes
for _, diskPath := range snapshot.DiskPaths {
diskBytes, err := fileSize(diskPath)
if err != nil {
return nil, err
}
bytes += diskBytes
}
snapshotUsage = append(snapshotUsage, contracthost.SnapshotStorageUsage{
SnapshotID: snapshot.ID,
Bytes: bytes,
})
}
return &contracthost.GetStorageReportResponse{
Report: contracthost.StorageReport{
GeneratedAt: time.Now().UTC(),
TotalBytes: totalBytes,
Pools: pools,
Machines: machineUsage,
Snapshots: snapshotUsage,
PublishedPorts: int64(len(publishedPorts)),
},
}, nil
}
func directorySize(root string) (int64, error) {
if root == "" {
return 0, nil
}
if _, err := os.Stat(root); err != nil {
if os.IsNotExist(err) {
return 0, nil
}
return 0, fmt.Errorf("stat %q: %w", root, err)
}
var total int64
if err := filepath.WalkDir(root, func(path string, entry fs.DirEntry, err error) error {
if err != nil {
return err
}
if entry.IsDir() {
return nil
}
info, err := entry.Info()
if err != nil {
return err
}
total += info.Size()
return nil
}); err != nil {
return 0, fmt.Errorf("walk %q: %w", root, err)
}
return total, nil
}
func fileSize(path string) (int64, error) {
info, err := os.Stat(path)
if err != nil {
return 0, fmt.Errorf("stat %q: %w", path, err)
}
return info.Size(), nil
}