host daemon (#2)

* feat: host daemon api scaffold

* fix: use sparse writes

* fix: unix socket length (<108 bytes)
This commit is contained in:
Hari 2026-04-08 11:23:19 -04:00 committed by GitHub
parent 4028bb5a1d
commit e2f9e54970
21 changed files with 2111 additions and 372 deletions

View file

@ -3,8 +3,8 @@ package host
import "time" import "time"
type ArtifactRef struct { type ArtifactRef struct {
ID ArtifactID `json:"id"` KernelImageURL string `json:"kernel_image_url"`
Version ArtifactVersion `json:"version"` RootFSURL string `json:"rootfs_url"`
} }
type Volume struct { type Volume struct {

View file

@ -1,9 +1,5 @@
package host package host
type ArtifactID string
type ArtifactVersion string
type MachineID string type MachineID string
type MachinePhase string type MachinePhase string
@ -13,12 +9,9 @@ type VolumeID string
type VolumeKind string type VolumeKind string
const ( const (
MachinePhasePending MachinePhase = "pending"
MachinePhaseRunning MachinePhase = "running" MachinePhaseRunning MachinePhase = "running"
MachinePhaseStopping MachinePhase = "stopping"
MachinePhaseStopped MachinePhase = "stopped" MachinePhaseStopped MachinePhase = "stopped"
MachinePhaseFailed MachinePhase = "failed" MachinePhaseFailed MachinePhase = "failed"
MachinePhaseDeleting MachinePhase = "deleting"
) )
const ( const (

View file

@ -3,28 +3,40 @@ package config
import ( import (
"fmt" "fmt"
"os" "os"
"path/filepath"
"strings" "strings"
"github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/firecracker"
) )
// Config contains the minimum host-local settings required to boot a machine. const defaultSocketName = "firecracker-host.sock"
// Config contains the host-local daemon settings.
type Config struct { type Config struct {
RootDir string RootDir string
StatePath string
OperationsPath string
ArtifactsDir string
MachineDisksDir string
RuntimeDir string
SocketPath string
FirecrackerBinaryPath string FirecrackerBinaryPath string
JailerBinaryPath string JailerBinaryPath string
KernelImagePath string
RootFSPath string
} }
// Load loads and validates the firecracker-host configuration from the environment. // Load loads and validates the firecracker-host daemon configuration from the environment.
func Load() (Config, error) { func Load() (Config, error) {
rootDir := filepath.Clean(strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_ROOT_DIR")))
cfg := Config{ cfg := Config{
RootDir: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_ROOT_DIR")), RootDir: rootDir,
StatePath: filepath.Join(rootDir, "state", "state.json"),
OperationsPath: filepath.Join(rootDir, "state", "ops.json"),
ArtifactsDir: filepath.Join(rootDir, "artifacts"),
MachineDisksDir: filepath.Join(rootDir, "machine-disks"),
RuntimeDir: filepath.Join(rootDir, "runtime"),
SocketPath: filepath.Join(rootDir, defaultSocketName),
FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")), FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")),
JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")), JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")),
KernelImagePath: strings.TrimSpace(os.Getenv("FIRECRACKER_GUEST_KERNEL_PATH")),
RootFSPath: strings.TrimSpace(os.Getenv("FIRECRACKER_GUEST_ROOTFS_PATH")),
} }
if err := cfg.Validate(); err != nil { if err := cfg.Validate(); err != nil {
return Config{}, err return Config{}, err
@ -43,19 +55,31 @@ func (c Config) Validate() error {
if c.JailerBinaryPath == "" { if c.JailerBinaryPath == "" {
return fmt.Errorf("JAILER_BINARY_PATH is required") return fmt.Errorf("JAILER_BINARY_PATH is required")
} }
if c.KernelImagePath == "" { if strings.TrimSpace(c.StatePath) == "" {
return fmt.Errorf("FIRECRACKER_GUEST_KERNEL_PATH is required") return fmt.Errorf("state path is required")
} }
if c.RootFSPath == "" { if strings.TrimSpace(c.OperationsPath) == "" {
return fmt.Errorf("FIRECRACKER_GUEST_ROOTFS_PATH is required") return fmt.Errorf("operations path is required")
}
if strings.TrimSpace(c.ArtifactsDir) == "" {
return fmt.Errorf("artifacts dir is required")
}
if strings.TrimSpace(c.MachineDisksDir) == "" {
return fmt.Errorf("machine disks dir is required")
}
if strings.TrimSpace(c.RuntimeDir) == "" {
return fmt.Errorf("runtime dir is required")
}
if strings.TrimSpace(c.SocketPath) == "" {
return fmt.Errorf("socket path is required")
} }
return nil return nil
} }
// FirecrackerRuntimeConfig converts the host config into the runtime wrapper's concrete runtime config. // FirecrackerRuntimeConfig converts the daemon config into the Firecracker runtime config.
func (c Config) FirecrackerRuntimeConfig() firecracker.RuntimeConfig { func (c Config) FirecrackerRuntimeConfig() firecracker.RuntimeConfig {
return firecracker.RuntimeConfig{ return firecracker.RuntimeConfig{
RootDir: c.RootDir, RootDir: c.RuntimeDir,
FirecrackerBinaryPath: c.FirecrackerBinaryPath, FirecrackerBinaryPath: c.FirecrackerBinaryPath,
JailerBinaryPath: c.JailerBinaryPath, JailerBinaryPath: c.JailerBinaryPath,
} }

221
internal/daemon/create.go Normal file
View file

@ -0,0 +1,221 @@
package daemon
import (
"context"
"fmt"
"os"
"path/filepath"
"time"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachineRequest) (*contracthost.CreateMachineResponse, error) {
if err := validateMachineID(req.MachineID); err != nil {
return nil, err
}
if err := validateArtifactRef(req.Artifact); err != nil {
return nil, err
}
unlock := d.lockMachine(req.MachineID)
defer unlock()
if _, err := d.store.GetMachine(ctx, req.MachineID); err == nil {
return nil, fmt.Errorf("machine %q already exists", req.MachineID)
} else if err != nil && err != store.ErrNotFound {
return nil, err
}
if err := d.store.UpsertOperation(ctx, model.OperationRecord{
MachineID: req.MachineID,
Type: model.MachineOperationCreate,
StartedAt: time.Now().UTC(),
}); err != nil {
return nil, err
}
clearOperation := false
defer func() {
if clearOperation {
_ = d.store.DeleteOperation(context.Background(), req.MachineID)
}
}()
artifact, err := d.ensureArtifact(ctx, req.Artifact)
if err != nil {
return nil, err
}
userVolumes, err := d.loadAttachableUserVolumes(ctx, req.MachineID, req.UserVolumeIDs)
if err != nil {
return nil, err
}
systemVolumePath := d.systemVolumePath(req.MachineID)
if err := os.MkdirAll(filepath.Dir(systemVolumePath), 0o755); err != nil {
return nil, fmt.Errorf("create system volume dir for %q: %w", req.MachineID, err)
}
if err := cloneFile(artifact.RootFSPath, systemVolumePath); err != nil {
return nil, err
}
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath)
if err != nil {
return nil, err
}
usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID)
if err != nil {
return nil, err
}
state, err := d.runtime.Boot(ctx, spec, usedNetworks)
if err != nil {
return nil, err
}
now := time.Now().UTC()
systemVolumeRecord := model.VolumeRecord{
ID: d.systemVolumeID(req.MachineID),
Kind: contracthost.VolumeKindSystem,
AttachedMachineID: machineIDPtr(req.MachineID),
SourceArtifact: &req.Artifact,
Pool: model.StoragePoolMachineDisks,
Path: systemVolumePath,
CreatedAt: now,
}
if err := d.store.CreateVolume(ctx, systemVolumeRecord); err != nil {
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
attachedUserVolumeIDs := make([]contracthost.VolumeID, 0, len(userVolumes))
for _, volume := range userVolumes {
volume.AttachedMachineID = machineIDPtr(req.MachineID)
if err := d.store.UpdateVolume(ctx, volume); err != nil {
for _, attachedVolumeID := range attachedUserVolumeIDs {
attachedVolume, getErr := d.store.GetVolume(context.Background(), attachedVolumeID)
if getErr == nil {
attachedVolume.AttachedMachineID = nil
_ = d.store.UpdateVolume(context.Background(), *attachedVolume)
}
}
_ = d.store.DeleteVolume(context.Background(), systemVolumeRecord.ID)
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
attachedUserVolumeIDs = append(attachedUserVolumeIDs, volume.ID)
}
record := model.MachineRecord{
ID: req.MachineID,
Artifact: req.Artifact,
SystemVolumeID: systemVolumeRecord.ID,
UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...),
RuntimeHost: state.RuntimeHost,
TapDevice: state.TapName,
Ports: defaultMachinePorts(),
Phase: contracthost.MachinePhaseRunning,
PID: state.PID,
SocketPath: state.SocketPath,
CreatedAt: now,
StartedAt: state.StartedAt,
}
if err := d.store.CreateMachine(ctx, record); err != nil {
for _, volume := range userVolumes {
volume.AttachedMachineID = nil
_ = d.store.UpdateVolume(context.Background(), volume)
}
_ = d.store.DeleteVolume(context.Background(), systemVolumeRecord.ID)
_ = d.runtime.Delete(context.Background(), *state)
return nil, err
}
clearOperation = true
return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil
}
func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string) (firecracker.MachineSpec, error) {
drives := make([]firecracker.DriveSpec, 0, len(userVolumes))
for i, volume := range userVolumes {
drives = append(drives, firecracker.DriveSpec{
ID: fmt.Sprintf("user-%d", i),
Path: volume.Path,
ReadOnly: false,
})
}
spec := firecracker.MachineSpec{
ID: firecracker.MachineID(machineID),
VCPUs: defaultGuestVCPUs,
MemoryMiB: defaultGuestMemoryMiB,
KernelImagePath: artifact.KernelImagePath,
RootFSPath: systemVolumePath,
KernelArgs: defaultGuestKernelArgs,
Drives: drives,
}
if err := spec.Validate(); err != nil {
return firecracker.MachineSpec{}, err
}
return spec, nil
}
func (d *Daemon) ensureArtifact(ctx context.Context, ref contracthost.ArtifactRef) (*model.ArtifactRecord, error) {
key := artifactKey(ref)
unlock := d.lockArtifact(key)
defer unlock()
if artifact, err := d.store.GetArtifact(ctx, ref); err == nil {
return artifact, nil
} else if err != store.ErrNotFound {
return nil, err
}
dir := filepath.Join(d.config.ArtifactsDir, key)
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, fmt.Errorf("create artifact dir %q: %w", dir, err)
}
kernelPath := filepath.Join(dir, "kernel")
rootFSPath := filepath.Join(dir, "rootfs")
if err := downloadFile(ctx, ref.KernelImageURL, kernelPath); err != nil {
return nil, err
}
if err := downloadFile(ctx, ref.RootFSURL, rootFSPath); err != nil {
return nil, err
}
artifact := model.ArtifactRecord{
Ref: ref,
LocalKey: key,
LocalDir: dir,
KernelImagePath: kernelPath,
RootFSPath: rootFSPath,
CreatedAt: time.Now().UTC(),
}
if err := d.store.PutArtifact(ctx, artifact); err != nil {
return nil, err
}
return &artifact, nil
}
func (d *Daemon) loadAttachableUserVolumes(ctx context.Context, machineID contracthost.MachineID, volumeIDs []contracthost.VolumeID) ([]model.VolumeRecord, error) {
volumes := make([]model.VolumeRecord, 0, len(volumeIDs))
for _, volumeID := range volumeIDs {
volume, err := d.store.GetVolume(ctx, volumeID)
if err != nil {
return nil, err
}
if volume.Kind != contracthost.VolumeKindUser {
return nil, fmt.Errorf("volume %q is not a user volume", volumeID)
}
if volume.AttachedMachineID != nil && *volume.AttachedMachineID != machineID {
return nil, fmt.Errorf("volume %q is already attached to machine %q", volumeID, *volume.AttachedMachineID)
}
volumes = append(volumes, *volume)
}
return volumes, nil
}

View file

@ -1,12 +1,95 @@
package daemon package daemon
import ( import (
"context"
"fmt"
"os"
"sync"
appconfig "github.com/getcompanion-ai/computer-host/internal/config"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/store" "github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
) )
type Runtime interface{} const (
defaultGuestKernelArgs = "console=ttyS0 reboot=k panic=1 pci=off"
defaultGuestMemoryMiB = int64(512)
defaultGuestVCPUs = int64(1)
defaultSSHPort = uint16(2222)
defaultVNCPort = uint16(6080)
defaultCopyBufferSize = 1024 * 1024
)
type Runtime interface {
Boot(context.Context, firecracker.MachineSpec, []firecracker.NetworkAllocation) (*firecracker.MachineState, error)
Inspect(firecracker.MachineState) (*firecracker.MachineState, error)
Delete(context.Context, firecracker.MachineState) error
}
type Daemon struct { type Daemon struct {
Store store.Store config appconfig.Config
Runtime Runtime store store.Store
runtime Runtime
locksMu sync.Mutex
machineLocks map[contracthost.MachineID]*sync.Mutex
artifactLocks map[string]*sync.Mutex
}
func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, error) {
if err := cfg.Validate(); err != nil {
return nil, err
}
if store == nil {
return nil, fmt.Errorf("store is required")
}
if runtime == nil {
return nil, fmt.Errorf("runtime is required")
}
for _, dir := range []string{cfg.ArtifactsDir, cfg.MachineDisksDir, cfg.RuntimeDir} {
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, fmt.Errorf("create daemon dir %q: %w", dir, err)
}
}
return &Daemon{
config: cfg,
store: store,
runtime: runtime,
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
artifactLocks: make(map[string]*sync.Mutex),
}, nil
}
func (d *Daemon) Health(ctx context.Context) (*contracthost.HealthResponse, error) {
if _, err := d.store.ListMachines(ctx); err != nil {
return nil, err
}
return &contracthost.HealthResponse{OK: true}, nil
}
func (d *Daemon) lockMachine(machineID contracthost.MachineID) func() {
d.locksMu.Lock()
lock, ok := d.machineLocks[machineID]
if !ok {
lock = &sync.Mutex{}
d.machineLocks[machineID] = lock
}
d.locksMu.Unlock()
lock.Lock()
return lock.Unlock
}
func (d *Daemon) lockArtifact(key string) func() {
d.locksMu.Lock()
lock, ok := d.artifactLocks[key]
if !ok {
lock = &sync.Mutex{}
d.artifactLocks[key] = lock
}
d.locksMu.Unlock()
lock.Lock()
return lock.Unlock
} }

View file

@ -0,0 +1,211 @@
package daemon
import (
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
appconfig "github.com/getcompanion-ai/computer-host/internal/config"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
type fakeRuntime struct {
bootState firecracker.MachineState
bootCalls int
deleteCalls []firecracker.MachineState
lastSpec firecracker.MachineSpec
}
func (f *fakeRuntime) Boot(_ context.Context, spec firecracker.MachineSpec, _ []firecracker.NetworkAllocation) (*firecracker.MachineState, error) {
f.bootCalls++
f.lastSpec = spec
state := f.bootState
return &state, nil
}
func (f *fakeRuntime) Inspect(state firecracker.MachineState) (*firecracker.MachineState, error) {
copy := state
return &copy, nil
}
func (f *fakeRuntime) Delete(_ context.Context, state firecracker.MachineState) error {
f.deleteCalls = append(f.deleteCalls, state)
return nil
}
func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
t.Parallel()
root := t.TempDir()
cfg := testConfig(root)
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
if err != nil {
t.Fatalf("create file store: %v", err)
}
startedAt := time.Unix(1700000005, 0).UTC()
runtime := &fakeRuntime{
bootState: firecracker.MachineState{
ID: "vm-1",
Phase: firecracker.PhaseRunning,
PID: 4321,
RuntimeHost: "172.16.0.2",
SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "vm-1", "root", "run", "firecracker.sock"),
TapName: "fctap0",
StartedAt: &startedAt,
},
}
hostDaemon, err := New(cfg, fileStore, runtime)
if err != nil {
t.Fatalf("create daemon: %v", err)
}
kernelPayload := []byte("kernel-image")
rootFSPayload := []byte("rootfs-image")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/kernel":
_, _ = w.Write(kernelPayload)
case "/rootfs":
_, _ = w.Write(rootFSPayload)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
response, err := hostDaemon.CreateMachine(context.Background(), contracthost.CreateMachineRequest{
MachineID: "vm-1",
Artifact: contracthost.ArtifactRef{
KernelImageURL: server.URL + "/kernel",
RootFSURL: server.URL + "/rootfs",
},
})
if err != nil {
t.Fatalf("create machine: %v", err)
}
if response.Machine.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("machine phase mismatch: got %q", response.Machine.Phase)
}
if response.Machine.RuntimeHost != "172.16.0.2" {
t.Fatalf("runtime host mismatch: got %q", response.Machine.RuntimeHost)
}
if len(response.Machine.Ports) != 2 {
t.Fatalf("machine ports mismatch: got %d want 2", len(response.Machine.Ports))
}
if runtime.bootCalls != 1 {
t.Fatalf("boot call count mismatch: got %d want 1", runtime.bootCalls)
}
if runtime.lastSpec.KernelImagePath == "" || runtime.lastSpec.RootFSPath == "" {
t.Fatalf("runtime spec paths not populated: %#v", runtime.lastSpec)
}
if _, err := os.Stat(runtime.lastSpec.KernelImagePath); err != nil {
t.Fatalf("kernel artifact not staged: %v", err)
}
if _, err := os.Stat(runtime.lastSpec.RootFSPath); err != nil {
t.Fatalf("system disk not staged: %v", err)
}
artifact, err := fileStore.GetArtifact(context.Background(), response.Machine.Artifact)
if err != nil {
t.Fatalf("get artifact: %v", err)
}
if artifact.KernelImagePath == "" || artifact.RootFSPath == "" {
t.Fatalf("artifact paths missing: %#v", artifact)
}
if payload, err := os.ReadFile(artifact.KernelImagePath); err != nil {
t.Fatalf("read kernel artifact: %v", err)
} else if string(payload) != string(kernelPayload) {
t.Fatalf("kernel artifact payload mismatch: got %q", string(payload))
}
machine, err := fileStore.GetMachine(context.Background(), "vm-1")
if err != nil {
t.Fatalf("get machine: %v", err)
}
if machine.SystemVolumeID != "vm-1-system" {
t.Fatalf("system volume mismatch: got %q", machine.SystemVolumeID)
}
operations, err := fileStore.ListOperations(context.Background())
if err != nil {
t.Fatalf("list operations: %v", err)
}
if len(operations) != 0 {
t.Fatalf("operation journal should be empty after success: got %d entries", len(operations))
}
}
func TestCreateMachineRejectsNonHTTPArtifactURLs(t *testing.T) {
t.Parallel()
root := t.TempDir()
cfg := testConfig(root)
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
if err != nil {
t.Fatalf("create file store: %v", err)
}
hostDaemon, err := New(cfg, fileStore, &fakeRuntime{})
if err != nil {
t.Fatalf("create daemon: %v", err)
}
_, err = hostDaemon.CreateMachine(context.Background(), contracthost.CreateMachineRequest{
MachineID: "vm-1",
Artifact: contracthost.ArtifactRef{
KernelImageURL: "file:///kernel",
RootFSURL: "https://example.com/rootfs",
},
})
if err == nil {
t.Fatal("expected create machine to fail for non-http artifact url")
}
if got := err.Error(); got != "artifact.kernel_image_url must use http or https" {
t.Fatalf("unexpected error: %q", got)
}
}
func TestDeleteMachineMissingIsNoOp(t *testing.T) {
t.Parallel()
root := t.TempDir()
cfg := testConfig(root)
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
if err != nil {
t.Fatalf("create file store: %v", err)
}
runtime := &fakeRuntime{}
hostDaemon, err := New(cfg, fileStore, runtime)
if err != nil {
t.Fatalf("create daemon: %v", err)
}
if err := hostDaemon.DeleteMachine(context.Background(), "missing"); err != nil {
t.Fatalf("delete missing machine: %v", err)
}
if len(runtime.deleteCalls) != 0 {
t.Fatalf("delete runtime should not be called for missing machine")
}
}
func testConfig(root string) appconfig.Config {
return appconfig.Config{
RootDir: root,
StatePath: filepath.Join(root, "state", "state.json"),
OperationsPath: filepath.Join(root, "state", "ops.json"),
ArtifactsDir: filepath.Join(root, "artifacts"),
MachineDisksDir: filepath.Join(root, "machine-disks"),
RuntimeDir: filepath.Join(root, "runtime"),
SocketPath: filepath.Join(root, "firecracker-host.sock"),
FirecrackerBinaryPath: "/usr/bin/firecracker",
JailerBinaryPath: "/usr/bin/jailer",
}
}

274
internal/daemon/files.go Normal file
View file

@ -0,0 +1,274 @@
package daemon
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func (d *Daemon) systemVolumeID(machineID contracthost.MachineID) contracthost.VolumeID {
return contracthost.VolumeID(fmt.Sprintf("%s-system", machineID))
}
func (d *Daemon) systemVolumePath(machineID contracthost.MachineID) string {
return filepath.Join(d.config.MachineDisksDir, string(machineID), "system.img")
}
func (d *Daemon) machineRuntimeBaseDir(machineID contracthost.MachineID) string {
return filepath.Join(d.config.RuntimeDir, "machines", string(machineID))
}
func artifactKey(ref contracthost.ArtifactRef) string {
sum := sha256.Sum256([]byte(ref.KernelImageURL + "\n" + ref.RootFSURL))
return hex.EncodeToString(sum[:])
}
func cloneFile(source string, target string) error {
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
return fmt.Errorf("create target dir for %q: %w", target, err)
}
sourceFile, err := os.Open(source)
if err != nil {
return fmt.Errorf("open source file %q: %w", source, err)
}
defer sourceFile.Close()
sourceInfo, err := sourceFile.Stat()
if err != nil {
return fmt.Errorf("stat source file %q: %w", source, err)
}
tmpPath := target + ".tmp"
targetFile, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
if err != nil {
return fmt.Errorf("open target file %q: %w", tmpPath, err)
}
if _, err := writeSparseFile(targetFile, sourceFile); err != nil {
targetFile.Close()
return fmt.Errorf("copy %q to %q: %w", source, tmpPath, err)
}
if err := targetFile.Truncate(sourceInfo.Size()); err != nil {
targetFile.Close()
return fmt.Errorf("truncate target file %q: %w", tmpPath, err)
}
if err := targetFile.Sync(); err != nil {
targetFile.Close()
return fmt.Errorf("sync target file %q: %w", tmpPath, err)
}
if err := targetFile.Close(); err != nil {
return fmt.Errorf("close target file %q: %w", tmpPath, err)
}
if err := os.Rename(tmpPath, target); err != nil {
return fmt.Errorf("rename target file %q to %q: %w", tmpPath, target, err)
}
if err := syncDir(filepath.Dir(target)); err != nil {
return err
}
return nil
}
func downloadFile(ctx context.Context, rawURL string, path string) error {
if _, err := os.Stat(path); err == nil {
return nil
} else if !os.IsNotExist(err) {
return fmt.Errorf("stat download target %q: %w", path, err)
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return fmt.Errorf("create download dir for %q: %w", path, err)
}
request, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return fmt.Errorf("build download request for %q: %w", rawURL, err)
}
response, err := http.DefaultClient.Do(request)
if err != nil {
return fmt.Errorf("download %q: %w", rawURL, err)
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return fmt.Errorf("download %q: status %d", rawURL, response.StatusCode)
}
tmpPath := path + ".tmp"
file, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
if err != nil {
return fmt.Errorf("open download target %q: %w", tmpPath, err)
}
size, err := writeSparseFile(file, response.Body)
if err != nil {
file.Close()
return fmt.Errorf("write download target %q: %w", tmpPath, err)
}
if err := file.Truncate(size); err != nil {
file.Close()
return fmt.Errorf("truncate download target %q: %w", tmpPath, err)
}
if err := file.Sync(); err != nil {
file.Close()
return fmt.Errorf("sync download target %q: %w", tmpPath, err)
}
if err := file.Close(); err != nil {
return fmt.Errorf("close download target %q: %w", tmpPath, err)
}
if err := os.Rename(tmpPath, path); err != nil {
return fmt.Errorf("rename download target %q to %q: %w", tmpPath, path, err)
}
if err := syncDir(filepath.Dir(path)); err != nil {
return err
}
return nil
}
func writeSparseFile(targetFile *os.File, source io.Reader) (int64, error) {
buffer := make([]byte, defaultCopyBufferSize)
var size int64
for {
count, err := source.Read(buffer)
if count > 0 {
chunk := buffer[:count]
if isZeroChunk(chunk) {
if _, seekErr := targetFile.Seek(int64(count), io.SeekCurrent); seekErr != nil {
return size, seekErr
}
} else {
if _, writeErr := targetFile.Write(chunk); writeErr != nil {
return size, writeErr
}
}
size += int64(count)
}
if err == nil {
continue
}
if err == io.EOF {
return size, nil
}
return size, err
}
}
func isZeroChunk(chunk []byte) bool {
for _, value := range chunk {
if value != 0 {
return false
}
}
return true
}
func defaultMachinePorts() []contracthost.MachinePort {
return []contracthost.MachinePort{
{Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, Protocol: contracthost.PortProtocolTCP},
{Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, Protocol: contracthost.PortProtocolTCP},
}
}
func machineIDPtr(machineID contracthost.MachineID) *contracthost.MachineID {
value := machineID
return &value
}
func machineToContract(record model.MachineRecord) contracthost.Machine {
return contracthost.Machine{
ID: record.ID,
Artifact: record.Artifact,
SystemVolumeID: record.SystemVolumeID,
UserVolumeIDs: append([]contracthost.VolumeID(nil), record.UserVolumeIDs...),
RuntimeHost: record.RuntimeHost,
Ports: append([]contracthost.MachinePort(nil), record.Ports...),
Phase: record.Phase,
Error: record.Error,
CreatedAt: record.CreatedAt,
StartedAt: record.StartedAt,
}
}
func machineToRuntimeState(record model.MachineRecord) firecracker.MachineState {
phase := firecracker.PhaseStopped
switch record.Phase {
case contracthost.MachinePhaseRunning:
phase = firecracker.PhaseRunning
case contracthost.MachinePhaseFailed:
phase = firecracker.PhaseFailed
}
return firecracker.MachineState{
ID: firecracker.MachineID(record.ID),
Phase: phase,
PID: record.PID,
RuntimeHost: record.RuntimeHost,
SocketPath: record.SocketPath,
TapName: record.TapDevice,
StartedAt: record.StartedAt,
Error: record.Error,
}
}
func validateArtifactRef(ref contracthost.ArtifactRef) error {
if err := validateDownloadURL("artifact.kernel_image_url", ref.KernelImageURL); err != nil {
return err
}
if err := validateDownloadURL("artifact.rootfs_url", ref.RootFSURL); err != nil {
return err
}
return nil
}
func validateMachineID(machineID contracthost.MachineID) error {
value := strings.TrimSpace(string(machineID))
if value == "" {
return fmt.Errorf("machine_id is required")
}
if filepath.Base(value) != value {
return fmt.Errorf("machine_id %q must not contain path separators", machineID)
}
return nil
}
func validateDownloadURL(field string, raw string) error {
value := strings.TrimSpace(raw)
if value == "" {
return fmt.Errorf("%s is required", field)
}
parsed, err := url.Parse(value)
if err != nil {
return fmt.Errorf("%s is invalid: %w", field, err)
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return fmt.Errorf("%s must use http or https", field)
}
if strings.TrimSpace(parsed.Host) == "" {
return fmt.Errorf("%s host is required", field)
}
return nil
}
func syncDir(path string) error {
dir, err := os.Open(path)
if err != nil {
return fmt.Errorf("open dir %q: %w", path, err)
}
if err := dir.Sync(); err != nil {
dir.Close()
return fmt.Errorf("sync dir %q: %w", path, err)
}
if err := dir.Close(); err != nil {
return fmt.Errorf("close dir %q: %w", path, err)
}
return nil
}

View file

@ -0,0 +1,94 @@
package daemon
import (
"bytes"
"io"
"os"
"path/filepath"
"syscall"
"testing"
)
func TestCloneFilePreservesSparseDiskUsage(t *testing.T) {
root := t.TempDir()
sourcePath := filepath.Join(root, "source.img")
targetPath := filepath.Join(root, "target.img")
sourceFile, err := os.OpenFile(sourcePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
if err != nil {
t.Fatalf("open source file: %v", err)
}
if _, err := sourceFile.Write([]byte("head")); err != nil {
sourceFile.Close()
t.Fatalf("write source prefix: %v", err)
}
if _, err := sourceFile.Seek(32<<20, io.SeekStart); err != nil {
sourceFile.Close()
t.Fatalf("seek source hole: %v", err)
}
if _, err := sourceFile.Write([]byte("tail")); err != nil {
sourceFile.Close()
t.Fatalf("write source suffix: %v", err)
}
if err := sourceFile.Close(); err != nil {
t.Fatalf("close source file: %v", err)
}
sourceInfo, err := os.Stat(sourcePath)
if err != nil {
t.Fatalf("stat source file: %v", err)
}
sourceUsage, err := allocatedBytes(sourcePath)
if err != nil {
t.Fatalf("allocated bytes for source: %v", err)
}
if sourceUsage >= sourceInfo.Size()/2 {
t.Skip("temp filesystem does not expose sparse allocation savings")
}
if err := cloneFile(sourcePath, targetPath); err != nil {
t.Fatalf("clone sparse file: %v", err)
}
targetInfo, err := os.Stat(targetPath)
if err != nil {
t.Fatalf("stat target file: %v", err)
}
if targetInfo.Size() != sourceInfo.Size() {
t.Fatalf("target size mismatch: got %d want %d", targetInfo.Size(), sourceInfo.Size())
}
targetUsage, err := allocatedBytes(targetPath)
if err != nil {
t.Fatalf("allocated bytes for target: %v", err)
}
if targetUsage >= targetInfo.Size()/2 {
t.Fatalf("target file is not sparse enough: allocated=%d size=%d", targetUsage, targetInfo.Size())
}
targetData, err := os.ReadFile(targetPath)
if err != nil {
t.Fatalf("read target file: %v", err)
}
if !bytes.Equal(targetData[:4], []byte("head")) {
t.Fatalf("target prefix mismatch: %q", string(targetData[:4]))
}
if !bytes.Equal(targetData[len(targetData)-4:], []byte("tail")) {
t.Fatalf("target suffix mismatch: %q", string(targetData[len(targetData)-4:]))
}
if !bytes.Equal(targetData[4:4+(1<<20)], make([]byte, 1<<20)) {
t.Fatal("target hole contents were not zeroed")
}
}
func allocatedBytes(path string) (int64, error) {
info, err := os.Stat(path)
if err != nil {
return 0, err
}
stat, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return 0, syscall.EINVAL
}
return stat.Blocks * 512, nil
}

View file

@ -0,0 +1,327 @@
package daemon
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func (d *Daemon) GetMachine(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) {
record, err := d.reconcileMachine(ctx, id)
if err != nil {
return nil, err
}
return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil
}
func (d *Daemon) ListMachines(ctx context.Context) (*contracthost.ListMachinesResponse, error) {
records, err := d.store.ListMachines(ctx)
if err != nil {
return nil, err
}
machines := make([]contracthost.Machine, 0, len(records))
for _, record := range records {
reconciled, err := d.reconcileMachine(ctx, record.ID)
if err != nil {
return nil, err
}
machines = append(machines, machineToContract(*reconciled))
}
return &contracthost.ListMachinesResponse{Machines: machines}, nil
}
func (d *Daemon) StopMachine(ctx context.Context, id contracthost.MachineID) error {
unlock := d.lockMachine(id)
defer unlock()
record, err := d.store.GetMachine(ctx, id)
if err != nil {
return err
}
if record.Phase == contracthost.MachinePhaseStopped {
return nil
}
if err := d.store.UpsertOperation(ctx, model.OperationRecord{
MachineID: id,
Type: model.MachineOperationStop,
StartedAt: time.Now().UTC(),
}); err != nil {
return err
}
clearOperation := false
defer func() {
if clearOperation {
_ = d.store.DeleteOperation(context.Background(), id)
}
}()
if err := d.stopMachineRecord(ctx, record); err != nil {
return err
}
clearOperation = true
return nil
}
func (d *Daemon) DeleteMachine(ctx context.Context, id contracthost.MachineID) error {
unlock := d.lockMachine(id)
defer unlock()
record, err := d.store.GetMachine(ctx, id)
if err == store.ErrNotFound {
return nil
}
if err != nil {
return err
}
if err := d.store.UpsertOperation(ctx, model.OperationRecord{
MachineID: id,
Type: model.MachineOperationDelete,
StartedAt: time.Now().UTC(),
}); err != nil {
return err
}
clearOperation := false
defer func() {
if clearOperation {
_ = d.store.DeleteOperation(context.Background(), id)
}
}()
if err := d.deleteMachineRecord(ctx, record); err != nil {
return err
}
clearOperation = true
return nil
}
func (d *Daemon) Reconcile(ctx context.Context) error {
operations, err := d.store.ListOperations(ctx)
if err != nil {
return err
}
for _, operation := range operations {
switch operation.Type {
case model.MachineOperationCreate:
if err := d.reconcileCreate(ctx, operation.MachineID); err != nil {
return err
}
case model.MachineOperationStop:
if err := d.reconcileStop(ctx, operation.MachineID); err != nil {
return err
}
case model.MachineOperationDelete:
if err := d.reconcileDelete(ctx, operation.MachineID); err != nil {
return err
}
default:
return fmt.Errorf("unsupported operation type %q", operation.Type)
}
}
records, err := d.store.ListMachines(ctx)
if err != nil {
return err
}
for _, record := range records {
if _, err := d.reconcileMachine(ctx, record.ID); err != nil {
return err
}
}
return nil
}
func (d *Daemon) listRunningNetworks(ctx context.Context, ignore contracthost.MachineID) ([]firecracker.NetworkAllocation, error) {
records, err := d.store.ListMachines(ctx)
if err != nil {
return nil, err
}
networks := make([]firecracker.NetworkAllocation, 0, len(records))
for _, record := range records {
if record.ID == ignore || record.Phase != contracthost.MachinePhaseRunning {
continue
}
if strings.TrimSpace(record.RuntimeHost) == "" || strings.TrimSpace(record.TapDevice) == "" {
continue
}
network, err := firecracker.AllocationFromGuestIP(record.RuntimeHost, record.TapDevice)
if err != nil {
return nil, err
}
networks = append(networks, network)
}
return networks, nil
}
func (d *Daemon) reconcileCreate(ctx context.Context, machineID contracthost.MachineID) error {
_, err := d.store.GetMachine(ctx, machineID)
if err == nil {
if _, err := d.reconcileMachine(ctx, machineID); err != nil {
return err
}
return d.store.DeleteOperation(ctx, machineID)
}
if err != store.ErrNotFound {
return err
}
if err := os.Remove(d.systemVolumePath(machineID)); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("cleanup system volume for %q: %w", machineID, err)
}
if err := d.store.DeleteVolume(ctx, d.systemVolumeID(machineID)); err != nil && err != store.ErrNotFound {
return err
}
if err := d.detachVolumesForMachine(ctx, machineID); err != nil {
return err
}
_ = os.RemoveAll(filepath.Dir(d.systemVolumePath(machineID)))
if err := os.RemoveAll(d.machineRuntimeBaseDir(machineID)); err != nil {
return fmt.Errorf("cleanup runtime dir for %q: %w", machineID, err)
}
return d.store.DeleteOperation(ctx, machineID)
}
func (d *Daemon) reconcileStop(ctx context.Context, machineID contracthost.MachineID) error {
record, err := d.store.GetMachine(ctx, machineID)
if err == store.ErrNotFound {
return d.store.DeleteOperation(ctx, machineID)
}
if err != nil {
return err
}
if err := d.stopMachineRecord(ctx, record); err != nil {
return err
}
return d.store.DeleteOperation(ctx, machineID)
}
func (d *Daemon) reconcileDelete(ctx context.Context, machineID contracthost.MachineID) error {
record, err := d.store.GetMachine(ctx, machineID)
if err == store.ErrNotFound {
if err := os.Remove(d.systemVolumePath(machineID)); err != nil && !os.IsNotExist(err) {
return err
}
if err := d.store.DeleteVolume(ctx, d.systemVolumeID(machineID)); err != nil && err != store.ErrNotFound {
return err
}
if err := d.detachVolumesForMachine(ctx, machineID); err != nil {
return err
}
_ = os.RemoveAll(filepath.Dir(d.systemVolumePath(machineID)))
_ = os.RemoveAll(d.machineRuntimeBaseDir(machineID))
return d.store.DeleteOperation(ctx, machineID)
}
if err != nil {
return err
}
if err := d.deleteMachineRecord(ctx, record); err != nil {
return err
}
return d.store.DeleteOperation(ctx, machineID)
}
func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.MachineID) (*model.MachineRecord, error) {
unlock := d.lockMachine(machineID)
defer unlock()
record, err := d.store.GetMachine(ctx, machineID)
if err != nil {
return nil, err
}
if record.Phase != contracthost.MachinePhaseRunning {
return record, nil
}
state, err := d.runtime.Inspect(machineToRuntimeState(*record))
if err != nil {
return nil, err
}
if state.Phase == firecracker.PhaseRunning {
return record, nil
}
if err := d.runtime.Delete(ctx, *state); err != nil {
return nil, err
}
record.Phase = contracthost.MachinePhaseFailed
record.Error = state.Error
record.PID = 0
record.SocketPath = ""
record.RuntimeHost = ""
record.TapDevice = ""
record.StartedAt = nil
if err := d.store.UpdateMachine(ctx, *record); err != nil {
return nil, err
}
return record, nil
}
func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error {
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
return err
}
if err := d.detachVolumesForMachine(ctx, record.ID); err != nil {
return err
}
systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID)
if err != nil {
return err
}
if err := os.Remove(systemVolume.Path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("remove system volume %q: %w", systemVolume.Path, err)
}
if err := os.RemoveAll(filepath.Dir(systemVolume.Path)); err != nil {
return fmt.Errorf("remove machine disk dir %q: %w", filepath.Dir(systemVolume.Path), err)
}
if err := d.store.DeleteVolume(ctx, record.SystemVolumeID); err != nil {
return err
}
return d.store.DeleteMachine(ctx, record.ID)
}
func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error {
if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil {
return err
}
record.Phase = contracthost.MachinePhaseStopped
record.Error = ""
record.PID = 0
record.SocketPath = ""
record.RuntimeHost = ""
record.TapDevice = ""
record.StartedAt = nil
return d.store.UpdateMachine(ctx, *record)
}
func (d *Daemon) detachVolumesForMachine(ctx context.Context, machineID contracthost.MachineID) error {
volumes, err := d.store.ListVolumes(ctx)
if err != nil {
return err
}
for _, volume := range volumes {
if volume.AttachedMachineID == nil || *volume.AttachedMachineID != machineID {
continue
}
volume.AttachedMachineID = nil
if err := d.store.UpdateVolume(ctx, volume); err != nil {
return err
}
}
return nil
}

View file

@ -14,7 +14,7 @@ import (
const ( const (
defaultCgroupVersion = "2" defaultCgroupVersion = "2"
defaultFirecrackerInitTimeout = 3 * time.Second defaultFirecrackerInitTimeout = 10 * time.Second
defaultFirecrackerPollInterval = 10 * time.Millisecond defaultFirecrackerPollInterval = 10 * time.Millisecond
defaultRootDriveID = "root_drive" defaultRootDriveID = "root_drive"
defaultVSockRunDir = "/run" defaultVSockRunDir = "/run"
@ -120,20 +120,34 @@ func waitForSocket(ctx context.Context, client *apiClient, socketPath string) er
ticker := time.NewTicker(defaultFirecrackerPollInterval) ticker := time.NewTicker(defaultFirecrackerPollInterval)
defer ticker.Stop() defer ticker.Stop()
var lastStatErr error
var lastPingErr error
for { for {
select { select {
case <-waitContext.Done(): case <-waitContext.Done():
return waitContext.Err() switch {
case lastPingErr != nil:
return fmt.Errorf("%w (socket=%q last_ping_err=%v)", waitContext.Err(), socketPath, lastPingErr)
case lastStatErr != nil:
return fmt.Errorf("%w (socket=%q last_stat_err=%v)", waitContext.Err(), socketPath, lastStatErr)
default:
return fmt.Errorf("%w (socket=%q)", waitContext.Err(), socketPath)
}
case <-ticker.C: case <-ticker.C:
if _, err := os.Stat(socketPath); err != nil { if _, err := os.Stat(socketPath); err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
lastStatErr = err
continue continue
} }
return fmt.Errorf("stat socket %q: %w", socketPath, err) return fmt.Errorf("stat socket %q: %w", socketPath, err)
} }
lastStatErr = nil
if err := client.Ping(waitContext); err != nil { if err := client.Ping(waitContext); err != nil {
lastPingErr = err
continue continue
} }
lastPingErr = nil
return nil return nil
} }
} }

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"os/exec" "os/exec"
"strings" "strings"
@ -47,6 +48,30 @@ func (n NetworkAllocation) GuestIP() netip.Addr {
return n.GuestCIDR.Addr() return n.GuestCIDR.Addr()
} }
// AllocationFromGuestIP reconstructs the host-side allocation from a guest IP and tap name.
func AllocationFromGuestIP(guestIP string, tapName string) (NetworkAllocation, error) {
parsed := net.ParseIP(strings.TrimSpace(guestIP))
if parsed == nil {
return NetworkAllocation{}, fmt.Errorf("parse guest ip %q", guestIP)
}
addr, ok := netip.AddrFromSlice(parsed.To4())
if !ok {
return NetworkAllocation{}, fmt.Errorf("guest ip %q must be IPv4", guestIP)
}
base := ipv4ToUint32(addr) - 2
hostIP := uint32ToIPv4(base + 1)
guest := uint32ToIPv4(base + 2)
return NetworkAllocation{
InterfaceID: defaultInterfaceID,
TapName: strings.TrimSpace(tapName),
HostCIDR: netip.PrefixFrom(hostIP, defaultNetworkPrefixBits),
GuestCIDR: netip.PrefixFrom(guest, defaultNetworkPrefixBits),
GatewayIP: hostIP,
GuestMAC: macForIPv4(guest),
}, nil
}
// NewNetworkAllocator returns a new /30 allocator rooted at the provided IPv4 prefix. // NewNetworkAllocator returns a new /30 allocator rooted at the provided IPv4 prefix.
func NewNetworkAllocator(cidr string) (*NetworkAllocator, error) { func NewNetworkAllocator(cidr string) (*NetworkAllocator, error) {
cidr = strings.TrimSpace(cidr) cidr = strings.TrimSpace(cidr)

View file

@ -3,6 +3,7 @@ package firecracker
import ( import (
"fmt" "fmt"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
) )
@ -45,3 +46,7 @@ func buildMachinePaths(rootDir string, id MachineID, firecrackerBinaryPath strin
SocketPath: filepath.Join(chrootRootDir, defaultFirecrackerSocketDir, defaultFirecrackerSocketName), SocketPath: filepath.Join(chrootRootDir, defaultFirecrackerSocketDir, defaultFirecrackerSocketName),
}, nil }, nil
} }
func procSocketPath(pid int) string {
return filepath.Join("/proc", strconv.Itoa(pid), "root", defaultFirecrackerSocketDir, defaultFirecrackerSocketName)
}

View file

@ -8,7 +8,6 @@ import (
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
) )
@ -27,26 +26,9 @@ type Runtime struct {
jailerBinaryPath string jailerBinaryPath string
networkAllocator *NetworkAllocator networkAllocator *NetworkAllocator
networkProvisioner NetworkProvisioner networkProvisioner NetworkProvisioner
mu sync.RWMutex
machines map[MachineID]*managedMachine
} }
type managedMachine struct { const debugPreserveFailureEnv = "FIRECRACKER_DEBUG_PRESERVE_FAILURE"
cmd *exec.Cmd
entered bool
exited chan struct{}
network NetworkAllocation
paths machinePaths
spec MachineSpec
state MachineState
stopping bool
}
const (
defaultVSockCIDStart = uint32(3)
defaultVSockID = "vsock0"
)
func NewRuntime(cfg RuntimeConfig) (*Runtime, error) { func NewRuntime(cfg RuntimeConfig) (*Runtime, error) {
rootDir := filepath.Clean(strings.TrimSpace(cfg.RootDir)) rootDir := filepath.Clean(strings.TrimSpace(cfg.RootDir))
@ -79,64 +61,28 @@ func NewRuntime(cfg RuntimeConfig) (*Runtime, error) {
jailerBinaryPath: jailerBinaryPath, jailerBinaryPath: jailerBinaryPath,
networkAllocator: allocator, networkAllocator: allocator,
networkProvisioner: NewIPTapProvisioner(), networkProvisioner: NewIPTapProvisioner(),
machines: make(map[MachineID]*managedMachine),
}, nil }, nil
} }
func (r *Runtime) Boot(ctx context.Context, spec MachineSpec) (*MachineState, error) { func (r *Runtime) Boot(ctx context.Context, spec MachineSpec, usedNetworks []NetworkAllocation) (*MachineState, error) {
if err := spec.Validate(); err != nil { if err := spec.Validate(); err != nil {
return nil, err return nil, err
} }
r.mu.Lock()
if _, exists := r.machines[spec.ID]; exists {
r.mu.Unlock()
return nil, fmt.Errorf("machine %q already exists", spec.ID)
}
usedNetworks := make([]NetworkAllocation, 0, len(r.machines))
usedVSockCIDs := make(map[uint32]struct{}, len(r.machines))
for _, machine := range r.machines {
if machine == nil {
continue
}
usedNetworks = append(usedNetworks, machine.network)
if machine.spec.Vsock != nil {
usedVSockCIDs[machine.spec.Vsock.CID] = struct{}{}
}
}
spec, err := r.resolveVSock(spec, usedVSockCIDs)
if err != nil {
r.mu.Unlock()
return nil, err
}
r.machines[spec.ID] = &managedMachine{
spec: spec,
state: MachineState{
ID: spec.ID,
Phase: PhaseProvisioning,
},
entered: true,
}
r.mu.Unlock()
cleanup := func(network NetworkAllocation, paths machinePaths, command *exec.Cmd) { cleanup := func(network NetworkAllocation, paths machinePaths, command *exec.Cmd) {
if preserveFailureArtifacts() {
fmt.Fprintf(os.Stderr, "firecracker debug: preserving failure artifacts machine=%s pid=%d socket=%s base=%s\n", spec.ID, pidOf(command), paths.SocketPath, paths.BaseDir)
return
}
cleanupStartedProcess(command) cleanupStartedProcess(command)
_ = r.networkProvisioner.Remove(context.Background(), network) _ = r.networkProvisioner.Remove(context.Background(), network)
_ = removeIfExists(hostVSockPath(paths, spec))
if paths.BaseDir != "" { if paths.BaseDir != "" {
_ = os.RemoveAll(paths.BaseDir) _ = os.RemoveAll(paths.BaseDir)
} }
r.mu.Lock()
delete(r.machines, spec.ID)
r.mu.Unlock()
} }
network, err := r.networkAllocator.Allocate(usedNetworks) network, err := r.networkAllocator.Allocate(usedNetworks)
if err != nil { if err != nil {
cleanup(NetworkAllocation{}, machinePaths{}, nil)
return nil, err return nil, err
} }
@ -159,9 +105,14 @@ func (r *Runtime) Boot(ctx context.Context, spec MachineSpec) (*MachineState, er
cleanup(network, paths, nil) cleanup(network, paths, nil)
return nil, err return nil, err
} }
socketPath := paths.SocketPath
if pid := pidOf(command); pid > 0 {
socketPath = procSocketPath(pid)
}
fmt.Fprintf(os.Stderr, "firecracker debug: launched machine=%s pid=%d socket=%s jailer_base=%s\n", spec.ID, pidOf(command), socketPath, paths.JailerBaseDir)
client := newAPIClient(paths.SocketPath) client := newAPIClient(socketPath)
if err := waitForSocket(ctx, client, paths.SocketPath); err != nil { if err := waitForSocket(ctx, client, socketPath); err != nil {
cleanup(network, paths, command) cleanup(network, paths, command)
return nil, fmt.Errorf("wait for firecracker socket: %w", err) return nil, fmt.Errorf("wait for firecracker socket: %w", err)
} }
@ -187,162 +138,74 @@ func (r *Runtime) Boot(ctx context.Context, spec MachineSpec) (*MachineState, er
Phase: PhaseRunning, Phase: PhaseRunning,
PID: pid, PID: pid,
RuntimeHost: network.GuestIP().String(), RuntimeHost: network.GuestIP().String(),
SocketPath: paths.SocketPath, SocketPath: socketPath,
TapName: network.TapName, TapName: network.TapName,
StartedAt: &now, StartedAt: &now,
} }
r.mu.Lock()
entry := r.machines[spec.ID]
entry.cmd = command
entry.exited = make(chan struct{})
entry.network = network
entry.paths = paths
entry.state = state
r.mu.Unlock()
go r.watchMachine(spec.ID, command, entry.exited)
out := state
return &out, nil
}
func (r *Runtime) Inspect(id MachineID) (*MachineState, error) {
r.mu.RLock()
entry, ok := r.machines[id]
r.mu.RUnlock()
if !ok || entry == nil {
return nil, ErrMachineNotFound
}
state := entry.state
if state.PID > 0 && !processExists(state.PID) {
state.Phase = PhaseStopped
state.PID = 0
}
return &state, nil return &state, nil
} }
func (r *Runtime) Stop(ctx context.Context, id MachineID) error { func (r *Runtime) Inspect(state MachineState) (*MachineState, error) {
r.mu.RLock() if state.PID > 0 && !processExists(state.PID) {
entry, ok := r.machines[id] state.Phase = PhaseFailed
r.mu.RUnlock() state.PID = 0
if !ok || entry == nil { state.Error = "firecracker process not found"
return ErrMachineNotFound
} }
if entry.cmd == nil || entry.cmd.Process == nil { return &state, nil
return fmt.Errorf("machine %q has no firecracker process", id) }
}
if entry.state.Phase == PhaseStopped { func (r *Runtime) Stop(ctx context.Context, state MachineState) error {
if state.PID < 1 || !processExists(state.PID) {
return nil return nil
} }
r.mu.Lock() process, err := os.FindProcess(state.PID)
entry.stopping = true if err != nil {
process := entry.cmd.Process return fmt.Errorf("find process for machine %q: %w", state.ID, err)
exited := entry.exited }
r.mu.Unlock()
if err := process.Signal(syscall.SIGTERM); err != nil && !errors.Is(err, os.ErrProcessDone) { if err := process.Signal(syscall.SIGTERM); err != nil && !errors.Is(err, os.ErrProcessDone) {
return fmt.Errorf("stop machine %q: %w", id, err) return fmt.Errorf("stop machine %q: %w", state.ID, err)
} }
select { ticker := time.NewTicker(50 * time.Millisecond)
case <-exited: defer ticker.Stop()
for {
if !processExists(state.PID) {
return nil return nil
}
select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-ticker.C:
}
} }
} }
func (r *Runtime) Delete(ctx context.Context, id MachineID) error { func (r *Runtime) Delete(ctx context.Context, state MachineState) error {
r.mu.RLock() if err := r.Stop(ctx, state); err != nil {
entry, ok := r.machines[id] return err
r.mu.RUnlock() }
if !ok || entry == nil { if strings.TrimSpace(state.RuntimeHost) != "" && strings.TrimSpace(state.TapName) != "" {
return ErrMachineNotFound network, err := AllocationFromGuestIP(state.RuntimeHost, state.TapName)
if err != nil {
return err
}
if err := r.networkProvisioner.Remove(ctx, network); err != nil {
return err
}
} }
if entry.state.Phase == PhaseRunning { paths, err := buildMachinePaths(r.rootDir, state.ID, r.firecrackerBinaryPath)
if err := r.Stop(ctx, id); err != nil && !errors.Is(err, context.Canceled) { if err != nil {
return err return err
} }
if err := os.RemoveAll(paths.BaseDir); err != nil {
return fmt.Errorf("remove machine dir %q: %w", paths.BaseDir, err)
} }
if err := r.networkProvisioner.Remove(ctx, entry.network); err != nil {
return err
}
if err := removeIfExists(hostVSockPath(entry.paths, entry.spec)); err != nil {
return err
}
if err := os.RemoveAll(entry.paths.BaseDir); err != nil {
return fmt.Errorf("remove machine dir %q: %w", entry.paths.BaseDir, err)
}
r.mu.Lock()
delete(r.machines, id)
r.mu.Unlock()
return nil return nil
} }
func (r *Runtime) watchMachine(id MachineID, command *exec.Cmd, exited chan struct{}) {
err := command.Wait()
close(exited)
r.mu.Lock()
defer r.mu.Unlock()
entry, ok := r.machines[id]
if !ok || entry == nil || entry.cmd != command {
return
}
entry.state.PID = 0
if entry.stopping {
entry.state.Phase = PhaseStopped
entry.state.Error = ""
entry.stopping = false
return
}
if err != nil {
entry.state.Phase = PhaseError
entry.state.Error = err.Error()
return
}
entry.state.Phase = PhaseStopped
entry.state.Error = ""
}
func (r *Runtime) resolveVSock(spec MachineSpec, used map[uint32]struct{}) (MachineSpec, error) {
if spec.Vsock != nil {
if _, exists := used[spec.Vsock.CID]; exists {
return MachineSpec{}, fmt.Errorf("vsock cid %d already in use", spec.Vsock.CID)
}
return spec, nil
}
cid, err := nextVSockCID(used)
if err != nil {
return MachineSpec{}, err
}
spec.Vsock = &VsockSpec{
ID: defaultVSockID,
CID: cid,
Path: string(spec.ID) + ".sock",
}
return spec, nil
}
func nextVSockCID(used map[uint32]struct{}) (uint32, error) {
for cid := defaultVSockCIDStart; cid != 0; cid++ {
if _, exists := used[cid]; !exists {
return cid, nil
}
}
return 0, fmt.Errorf("vsock cid space exhausted")
}
func processExists(pid int) bool { func processExists(pid int) bool {
if pid < 1 { if pid < 1 {
return false return false
@ -351,12 +214,19 @@ func processExists(pid int) bool {
return err == nil || err == syscall.EPERM return err == nil || err == syscall.EPERM
} }
func removeIfExists(path string) error { func pidOf(command *exec.Cmd) int {
if path == "" { if command == nil || command.Process == nil {
return nil return 0
}
return command.Process.Pid
}
func preserveFailureArtifacts() bool {
value := strings.TrimSpace(os.Getenv(debugPreserveFailureEnv))
switch strings.ToLower(value) {
case "1", "true", "yes", "on":
return true
default:
return false
} }
if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("remove path %q: %w", path, err)
}
return nil
} }

View file

@ -18,14 +18,10 @@ type MachineState struct {
} }
const ( const (
// PhaseProvisioning means host-local resources are still being prepared.
PhaseProvisioning Phase = "provisioning"
// PhaseRunning means the Firecracker process is live. // PhaseRunning means the Firecracker process is live.
PhaseRunning Phase = "running" PhaseRunning Phase = "running"
// PhaseStopped means the VM is no longer running. // PhaseStopped means the VM is no longer running.
PhaseStopped Phase = "stopped" PhaseStopped Phase = "stopped"
// PhaseMissing means the machine is not known to the runtime. // PhaseFailed means the runtime observed a terminal failure.
PhaseMissing Phase = "missing" PhaseFailed Phase = "failed"
// PhaseError means the runtime observed a terminal failure.
PhaseError Phase = "error"
) )

View file

@ -2,6 +2,10 @@ package httpapi
import ( import (
"context" "context"
"encoding/json"
"fmt"
"net/http"
"strings"
contracthost "github.com/getcompanion-ai/computer-host/contract" contracthost "github.com/getcompanion-ai/computer-host/contract"
) )
@ -14,3 +18,133 @@ type Service interface {
DeleteMachine(context.Context, contracthost.MachineID) error DeleteMachine(context.Context, contracthost.MachineID) error
Health(context.Context) (*contracthost.HealthResponse, error) Health(context.Context) (*contracthost.HealthResponse, error)
} }
type Handler struct {
service Service
}
func New(service Service) (*Handler, error) {
if service == nil {
return nil, fmt.Errorf("service is required")
}
return &Handler{service: service}, nil
}
func (h *Handler) Routes() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/health", h.handleHealth)
mux.HandleFunc("/machines", h.handleMachines)
mux.HandleFunc("/machines/", h.handleMachine)
return mux
}
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeMethodNotAllowed(w)
return
}
response, err := h.service.Health(r.Context())
if err != nil {
writeError(w, http.StatusInternalServerError, err)
return
}
writeJSON(w, http.StatusOK, response)
}
func (h *Handler) handleMachines(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
response, err := h.service.ListMachines(r.Context())
if err != nil {
writeError(w, http.StatusInternalServerError, err)
return
}
writeJSON(w, http.StatusOK, response)
case http.MethodPost:
var request contracthost.CreateMachineRequest
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
writeError(w, http.StatusBadRequest, err)
return
}
response, err := h.service.CreateMachine(r.Context(), request)
if err != nil {
writeError(w, http.StatusBadRequest, err)
return
}
writeJSON(w, http.StatusCreated, response)
default:
writeMethodNotAllowed(w)
}
}
func (h *Handler) handleMachine(w http.ResponseWriter, r *http.Request) {
path := strings.TrimPrefix(r.URL.Path, "/machines/")
if path == "" {
writeError(w, http.StatusNotFound, fmt.Errorf("machine id is required"))
return
}
parts := strings.Split(path, "/")
machineID := contracthost.MachineID(parts[0])
if len(parts) == 1 {
switch r.Method {
case http.MethodGet:
response, err := h.service.GetMachine(r.Context(), machineID)
if err != nil {
writeError(w, statusForError(err), err)
return
}
writeJSON(w, http.StatusOK, response)
case http.MethodDelete:
if err := h.service.DeleteMachine(r.Context(), machineID); err != nil {
writeError(w, statusForError(err), err)
return
}
w.WriteHeader(http.StatusNoContent)
default:
writeMethodNotAllowed(w)
}
return
}
if len(parts) == 2 && parts[1] == "stop" {
if r.Method != http.MethodPost {
writeMethodNotAllowed(w)
return
}
if err := h.service.StopMachine(r.Context(), machineID); err != nil {
writeError(w, statusForError(err), err)
return
}
w.WriteHeader(http.StatusNoContent)
return
}
writeError(w, http.StatusNotFound, fmt.Errorf("route not found"))
}
func statusForError(err error) int {
message := strings.ToLower(err.Error())
switch {
case strings.Contains(message, "not found"):
return http.StatusNotFound
case strings.Contains(message, "already exists"):
return http.StatusConflict
default:
return http.StatusBadRequest
}
}
func writeError(w http.ResponseWriter, status int, err error) {
writeJSON(w, status, map[string]string{"error": err.Error()})
}
func writeJSON(w http.ResponseWriter, status int, value any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(value)
}
func writeMethodNotAllowed(w http.ResponseWriter) {
writeError(w, http.StatusMethodNotAllowed, fmt.Errorf("method not allowed"))
}

View file

@ -17,6 +17,8 @@ const (
type ArtifactRecord struct { type ArtifactRecord struct {
Ref contracthost.ArtifactRef Ref contracthost.ArtifactRef
LocalKey string
LocalDir string
KernelImagePath string KernelImagePath string
RootFSPath string RootFSPath string
CreatedAt time.Time CreatedAt time.Time
@ -32,6 +34,8 @@ type MachineRecord struct {
Ports []contracthost.MachinePort Ports []contracthost.MachinePort
Phase contracthost.MachinePhase Phase contracthost.MachinePhase
Error string Error string
PID int
SocketPath string
CreatedAt time.Time CreatedAt time.Time
StartedAt *time.Time StartedAt *time.Time
} }
@ -45,3 +49,17 @@ type VolumeRecord struct {
Path string Path string
CreatedAt time.Time CreatedAt time.Time
} }
type MachineOperation string
const (
MachineOperationCreate MachineOperation = "create"
MachineOperationStop MachineOperation = "stop"
MachineOperationDelete MachineOperation = "delete"
)
type OperationRecord struct {
MachineID contracthost.MachineID
Type MachineOperation
StartedAt time.Time
}

View file

@ -1,98 +0,0 @@
package service
import (
"context"
"fmt"
"strings"
appconfig "github.com/getcompanion-ai/computer-host/internal/config"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
)
// MachineRuntime is the minimum runtime surface the host-local service needs.
type MachineRuntime interface {
Boot(context.Context, firecracker.MachineSpec) (*firecracker.MachineState, error)
Inspect(firecracker.MachineID) (*firecracker.MachineState, error)
Stop(context.Context, firecracker.MachineID) error
Delete(context.Context, firecracker.MachineID) error
}
// Service manages local machine lifecycle requests on a single host.
type Service struct {
config appconfig.Config
runtime MachineRuntime
}
// CreateMachineRequest contains the minimum machine creation inputs above the raw runtime layer.
type CreateMachineRequest struct {
ID firecracker.MachineID
}
const (
defaultGuestKernelArgs = "console=ttyS0 reboot=k panic=1 pci=off"
defaultGuestMemoryMiB = int64(512)
defaultGuestVCPUs = int64(1)
)
// New constructs a new host-local service from the app config.
func New(cfg appconfig.Config) (*Service, error) {
if err := cfg.Validate(); err != nil {
return nil, err
}
runtime, err := firecracker.NewRuntime(cfg.FirecrackerRuntimeConfig())
if err != nil {
return nil, err
}
return &Service{
config: cfg,
runtime: runtime,
}, nil
}
// CreateMachine boots a new local machine from the single supported host default shape.
func (s *Service) CreateMachine(ctx context.Context, req CreateMachineRequest) (*firecracker.MachineState, error) {
spec, err := s.buildMachineSpec(req)
if err != nil {
return nil, err
}
return s.runtime.Boot(ctx, spec)
}
// GetMachine returns the current local state for a machine.
func (s *Service) GetMachine(id firecracker.MachineID) (*firecracker.MachineState, error) {
return s.runtime.Inspect(id)
}
// StopMachine stops a running local machine.
func (s *Service) StopMachine(ctx context.Context, id firecracker.MachineID) error {
return s.runtime.Stop(ctx, id)
}
// DeleteMachine removes a local machine and its host-local resources.
func (s *Service) DeleteMachine(ctx context.Context, id firecracker.MachineID) error {
return s.runtime.Delete(ctx, id)
}
func (s *Service) buildMachineSpec(req CreateMachineRequest) (firecracker.MachineSpec, error) {
if s == nil {
return firecracker.MachineSpec{}, fmt.Errorf("service is required")
}
if strings.TrimSpace(string(req.ID)) == "" {
return firecracker.MachineSpec{}, fmt.Errorf("machine id is required")
}
spec := firecracker.MachineSpec{
ID: req.ID,
VCPUs: defaultGuestVCPUs,
MemoryMiB: defaultGuestMemoryMiB,
KernelImagePath: s.config.KernelImagePath,
RootFSPath: s.config.RootFSPath,
KernelArgs: defaultGuestKernelArgs,
}
if err := spec.Validate(); err != nil {
return firecracker.MachineSpec{}, err
}
return spec, nil
}

View file

@ -0,0 +1,413 @@
package store
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"github.com/getcompanion-ai/computer-host/internal/model"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
type FileStore struct {
mu sync.Mutex
statePath string
operationsPath string
}
type persistedOperations struct {
Operations []model.OperationRecord `json:"operations"`
}
type persistedState struct {
Artifacts []model.ArtifactRecord `json:"artifacts"`
Machines []model.MachineRecord `json:"machines"`
Volumes []model.VolumeRecord `json:"volumes"`
}
func NewFileStore(statePath string, operationsPath string) (*FileStore, error) {
store := &FileStore{
statePath: filepath.Clean(statePath),
operationsPath: filepath.Clean(operationsPath),
}
if err := initializeJSONFile(store.statePath, emptyPersistedState()); err != nil {
return nil, err
}
if err := initializeJSONFile(store.operationsPath, emptyPersistedOperations()); err != nil {
return nil, err
}
return store, nil
}
func (s *FileStore) PutArtifact(_ context.Context, record model.ArtifactRecord) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.updateState(func(state *persistedState) error {
for i := range state.Artifacts {
if state.Artifacts[i].Ref == record.Ref {
state.Artifacts[i] = record
return nil
}
}
state.Artifacts = append(state.Artifacts, record)
return nil
})
}
func (s *FileStore) GetArtifact(_ context.Context, ref contracthost.ArtifactRef) (*model.ArtifactRecord, error) {
s.mu.Lock()
defer s.mu.Unlock()
state, err := s.readState()
if err != nil {
return nil, err
}
for i := range state.Artifacts {
if state.Artifacts[i].Ref == ref {
record := state.Artifacts[i]
return &record, nil
}
}
return nil, ErrNotFound
}
func (s *FileStore) ListArtifacts(_ context.Context) ([]model.ArtifactRecord, error) {
s.mu.Lock()
defer s.mu.Unlock()
state, err := s.readState()
if err != nil {
return nil, err
}
return append([]model.ArtifactRecord(nil), state.Artifacts...), nil
}
func (s *FileStore) CreateMachine(_ context.Context, record model.MachineRecord) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.updateState(func(state *persistedState) error {
for _, machine := range state.Machines {
if machine.ID == record.ID {
return fmt.Errorf("store: machine %q already exists", record.ID)
}
}
state.Machines = append(state.Machines, record)
return nil
})
}
func (s *FileStore) GetMachine(_ context.Context, id contracthost.MachineID) (*model.MachineRecord, error) {
s.mu.Lock()
defer s.mu.Unlock()
state, err := s.readState()
if err != nil {
return nil, err
}
for i := range state.Machines {
if state.Machines[i].ID == id {
record := state.Machines[i]
return &record, nil
}
}
return nil, ErrNotFound
}
func (s *FileStore) ListMachines(_ context.Context) ([]model.MachineRecord, error) {
s.mu.Lock()
defer s.mu.Unlock()
state, err := s.readState()
if err != nil {
return nil, err
}
return append([]model.MachineRecord(nil), state.Machines...), nil
}
func (s *FileStore) UpdateMachine(_ context.Context, record model.MachineRecord) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.updateState(func(state *persistedState) error {
for i := range state.Machines {
if state.Machines[i].ID == record.ID {
state.Machines[i] = record
return nil
}
}
return ErrNotFound
})
}
func (s *FileStore) DeleteMachine(_ context.Context, id contracthost.MachineID) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.updateState(func(state *persistedState) error {
for i := range state.Machines {
if state.Machines[i].ID == id {
state.Machines = append(state.Machines[:i], state.Machines[i+1:]...)
return nil
}
}
return ErrNotFound
})
}
func (s *FileStore) CreateVolume(_ context.Context, record model.VolumeRecord) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.updateState(func(state *persistedState) error {
for _, volume := range state.Volumes {
if volume.ID == record.ID {
return fmt.Errorf("store: volume %q already exists", record.ID)
}
}
state.Volumes = append(state.Volumes, record)
return nil
})
}
func (s *FileStore) GetVolume(_ context.Context, id contracthost.VolumeID) (*model.VolumeRecord, error) {
s.mu.Lock()
defer s.mu.Unlock()
state, err := s.readState()
if err != nil {
return nil, err
}
for i := range state.Volumes {
if state.Volumes[i].ID == id {
record := state.Volumes[i]
return &record, nil
}
}
return nil, ErrNotFound
}
func (s *FileStore) ListVolumes(_ context.Context) ([]model.VolumeRecord, error) {
s.mu.Lock()
defer s.mu.Unlock()
state, err := s.readState()
if err != nil {
return nil, err
}
return append([]model.VolumeRecord(nil), state.Volumes...), nil
}
func (s *FileStore) UpdateVolume(_ context.Context, record model.VolumeRecord) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.updateState(func(state *persistedState) error {
for i := range state.Volumes {
if state.Volumes[i].ID == record.ID {
state.Volumes[i] = record
return nil
}
}
return ErrNotFound
})
}
func (s *FileStore) DeleteVolume(_ context.Context, id contracthost.VolumeID) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.updateState(func(state *persistedState) error {
for i := range state.Volumes {
if state.Volumes[i].ID == id {
state.Volumes = append(state.Volumes[:i], state.Volumes[i+1:]...)
return nil
}
}
return ErrNotFound
})
}
func (s *FileStore) UpsertOperation(_ context.Context, record model.OperationRecord) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.updateOperations(func(operations *persistedOperations) error {
for i := range operations.Operations {
if operations.Operations[i].MachineID == record.MachineID {
operations.Operations[i] = record
return nil
}
}
operations.Operations = append(operations.Operations, record)
return nil
})
}
func (s *FileStore) ListOperations(_ context.Context) ([]model.OperationRecord, error) {
s.mu.Lock()
defer s.mu.Unlock()
operations, err := s.readOperations()
if err != nil {
return nil, err
}
return append([]model.OperationRecord(nil), operations.Operations...), nil
}
func (s *FileStore) DeleteOperation(_ context.Context, machineID contracthost.MachineID) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.updateOperations(func(operations *persistedOperations) error {
for i := range operations.Operations {
if operations.Operations[i].MachineID == machineID {
operations.Operations = append(operations.Operations[:i], operations.Operations[i+1:]...)
return nil
}
}
return nil
})
}
func (s *FileStore) readOperations() (*persistedOperations, error) {
var operations persistedOperations
if err := readJSONFile(s.operationsPath, &operations); err != nil {
return nil, err
}
normalizeOperations(&operations)
return &operations, nil
}
func (s *FileStore) readState() (*persistedState, error) {
var state persistedState
if err := readJSONFile(s.statePath, &state); err != nil {
return nil, err
}
normalizeState(&state)
return &state, nil
}
func (s *FileStore) updateOperations(update func(*persistedOperations) error) error {
operations, err := s.readOperations()
if err != nil {
return err
}
if err := update(operations); err != nil {
return err
}
return writeJSONFileAtomically(s.operationsPath, operations)
}
func (s *FileStore) updateState(update func(*persistedState) error) error {
state, err := s.readState()
if err != nil {
return err
}
if err := update(state); err != nil {
return err
}
return writeJSONFileAtomically(s.statePath, state)
}
func initializeJSONFile(path string, value any) error {
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return fmt.Errorf("create store dir for %q: %w", path, err)
}
if _, err := os.Stat(path); err == nil {
return nil
} else if !os.IsNotExist(err) {
return fmt.Errorf("stat store file %q: %w", path, err)
}
return writeJSONFileAtomically(path, value)
}
func readJSONFile(path string, value any) error {
data, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("read store file %q: %w", path, err)
}
if err := json.Unmarshal(data, value); err != nil {
return fmt.Errorf("decode store file %q: %w", path, err)
}
return nil
}
func writeJSONFileAtomically(path string, value any) error {
payload, err := json.MarshalIndent(value, "", " ")
if err != nil {
return fmt.Errorf("marshal store file %q: %w", path, err)
}
payload = append(payload, '\n')
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return fmt.Errorf("create store dir for %q: %w", path, err)
}
tmpPath := path + ".tmp"
file, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
if err != nil {
return fmt.Errorf("open temp store file %q: %w", tmpPath, err)
}
if _, err := file.Write(payload); err != nil {
file.Close()
return fmt.Errorf("write temp store file %q: %w", tmpPath, err)
}
if err := file.Sync(); err != nil {
file.Close()
return fmt.Errorf("sync temp store file %q: %w", tmpPath, err)
}
if err := file.Close(); err != nil {
return fmt.Errorf("close temp store file %q: %w", tmpPath, err)
}
if err := os.Rename(tmpPath, path); err != nil {
return fmt.Errorf("rename temp store file %q to %q: %w", tmpPath, path, err)
}
dir, err := os.Open(filepath.Dir(path))
if err != nil {
return fmt.Errorf("open store dir for %q: %w", path, err)
}
if err := dir.Sync(); err != nil {
dir.Close()
return fmt.Errorf("sync store dir for %q: %w", path, err)
}
if err := dir.Close(); err != nil {
return fmt.Errorf("close store dir for %q: %w", path, err)
}
return nil
}
func emptyPersistedState() persistedState {
return persistedState{
Artifacts: []model.ArtifactRecord{},
Machines: []model.MachineRecord{},
Volumes: []model.VolumeRecord{},
}
}
func emptyPersistedOperations() persistedOperations {
return persistedOperations{Operations: []model.OperationRecord{}}
}
func normalizeState(state *persistedState) {
if state.Artifacts == nil {
state.Artifacts = []model.ArtifactRecord{}
}
if state.Machines == nil {
state.Machines = []model.MachineRecord{}
}
if state.Volumes == nil {
state.Volumes = []model.VolumeRecord{}
}
}
func normalizeOperations(operations *persistedOperations) {
if operations.Operations == nil {
operations.Operations = []model.OperationRecord{}
}
}

View file

@ -0,0 +1,130 @@
package store
import (
"context"
"path/filepath"
"testing"
"time"
"github.com/getcompanion-ai/computer-host/internal/model"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func TestFileStorePersistsStateAndOperations(t *testing.T) {
t.Parallel()
root := t.TempDir()
statePath := filepath.Join(root, "state", "state.json")
opsPath := filepath.Join(root, "state", "ops.json")
ctx := context.Background()
first, err := NewFileStore(statePath, opsPath)
if err != nil {
t.Fatalf("create file store: %v", err)
}
artifact := model.ArtifactRecord{
Ref: contracthost.ArtifactRef{
KernelImageURL: "https://example.com/kernel",
RootFSURL: "https://example.com/rootfs",
},
LocalKey: "artifact-key",
LocalDir: filepath.Join(root, "artifacts", "artifact-key"),
KernelImagePath: filepath.Join(root, "artifacts", "artifact-key", "kernel"),
RootFSPath: filepath.Join(root, "artifacts", "artifact-key", "rootfs"),
CreatedAt: time.Unix(1700000000, 0).UTC(),
}
if err := first.PutArtifact(ctx, artifact); err != nil {
t.Fatalf("put artifact: %v", err)
}
machine := model.MachineRecord{
ID: "vm-1",
Artifact: artifact.Ref,
SystemVolumeID: "vm-1-system",
RuntimeHost: "172.16.0.2",
TapDevice: "fctap0",
Ports: []contracthost.MachinePort{
{Name: contracthost.MachinePortNameSSH, Port: 22, Protocol: contracthost.PortProtocolTCP},
},
Phase: contracthost.MachinePhaseRunning,
PID: 1234,
SocketPath: filepath.Join(root, "runtime", "machine.sock"),
CreatedAt: time.Unix(1700000001, 0).UTC(),
StartedAt: timePtr(time.Unix(1700000002, 0).UTC()),
}
if err := first.CreateMachine(ctx, machine); err != nil {
t.Fatalf("create machine: %v", err)
}
volume := model.VolumeRecord{
ID: "vm-1-system",
Kind: contracthost.VolumeKindSystem,
AttachedMachineID: machineIDPtr("vm-1"),
Path: filepath.Join(root, "machine-disks", "vm-1", "system.img"),
CreatedAt: time.Unix(1700000003, 0).UTC(),
}
if err := first.CreateVolume(ctx, volume); err != nil {
t.Fatalf("create volume: %v", err)
}
operation := model.OperationRecord{
MachineID: "vm-1",
Type: model.MachineOperationCreate,
StartedAt: time.Unix(1700000004, 0).UTC(),
}
if err := first.UpsertOperation(ctx, operation); err != nil {
t.Fatalf("upsert operation: %v", err)
}
second, err := NewFileStore(statePath, opsPath)
if err != nil {
t.Fatalf("reopen file store: %v", err)
}
gotArtifact, err := second.GetArtifact(ctx, artifact.Ref)
if err != nil {
t.Fatalf("get artifact after reopen: %v", err)
}
if gotArtifact.LocalKey != artifact.LocalKey {
t.Fatalf("artifact local key mismatch: got %q want %q", gotArtifact.LocalKey, artifact.LocalKey)
}
gotMachine, err := second.GetMachine(ctx, machine.ID)
if err != nil {
t.Fatalf("get machine after reopen: %v", err)
}
if gotMachine.Phase != contracthost.MachinePhaseRunning {
t.Fatalf("machine phase mismatch: got %q", gotMachine.Phase)
}
if gotMachine.RuntimeHost != machine.RuntimeHost {
t.Fatalf("runtime host mismatch: got %q want %q", gotMachine.RuntimeHost, machine.RuntimeHost)
}
gotVolume, err := second.GetVolume(ctx, volume.ID)
if err != nil {
t.Fatalf("get volume after reopen: %v", err)
}
if gotVolume.AttachedMachineID == nil || *gotVolume.AttachedMachineID != "vm-1" {
t.Fatalf("attached machine mismatch: got %#v", gotVolume.AttachedMachineID)
}
operations, err := second.ListOperations(ctx)
if err != nil {
t.Fatalf("list operations after reopen: %v", err)
}
if len(operations) != 1 {
t.Fatalf("operation count mismatch: got %d want 1", len(operations))
}
if operations[0].Type != model.MachineOperationCreate {
t.Fatalf("operation type mismatch: got %q", operations[0].Type)
}
}
func timePtr(value time.Time) *time.Time {
return &value
}
func machineIDPtr(value contracthost.MachineID) *contracthost.MachineID {
return &value
}

View file

@ -2,11 +2,14 @@ package store
import ( import (
"context" "context"
"errors"
"github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/model"
contracthost "github.com/getcompanion-ai/computer-host/contract" contracthost "github.com/getcompanion-ai/computer-host/contract"
) )
var ErrNotFound = errors.New("store: not found")
type Store interface { type Store interface {
PutArtifact(context.Context, model.ArtifactRecord) error PutArtifact(context.Context, model.ArtifactRecord) error
GetArtifact(context.Context, contracthost.ArtifactRef) (*model.ArtifactRecord, error) GetArtifact(context.Context, contracthost.ArtifactRef) (*model.ArtifactRecord, error)
@ -21,4 +24,7 @@ type Store interface {
ListVolumes(context.Context) ([]model.VolumeRecord, error) ListVolumes(context.Context) ([]model.VolumeRecord, error)
UpdateVolume(context.Context, model.VolumeRecord) error UpdateVolume(context.Context, model.VolumeRecord) error
DeleteVolume(context.Context, contracthost.VolumeID) error DeleteVolume(context.Context, contracthost.VolumeID) error
UpsertOperation(context.Context, model.OperationRecord) error
ListOperations(context.Context) ([]model.OperationRecord, error)
DeleteOperation(context.Context, contracthost.MachineID) error
} }

73
main.go
View file

@ -2,76 +2,75 @@ package main
import ( import (
"context" "context"
"encoding/json"
"flag"
"fmt" "fmt"
"io" "net"
"net/http"
"os" "os"
"os/signal" "os/signal"
"strings" "path/filepath"
"syscall" "syscall"
appconfig "github.com/getcompanion-ai/computer-host/internal/config" appconfig "github.com/getcompanion-ai/computer-host/internal/config"
"github.com/getcompanion-ai/computer-host/internal/daemon"
"github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/service" "github.com/getcompanion-ai/computer-host/internal/httpapi"
"github.com/getcompanion-ai/computer-host/internal/store"
) )
type options struct {
MachineID firecracker.MachineID
}
func main() { func main() {
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop() defer stop()
opts, err := parseOptions(os.Args[1:], os.Stderr)
if err != nil {
exit(err)
}
cfg, err := appconfig.Load() cfg, err := appconfig.Load()
if err != nil { if err != nil {
exit(err) exit(err)
} }
svc, err := service.New(cfg) fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
if err != nil { if err != nil {
exit(err) exit(err)
} }
state, err := svc.CreateMachine(ctx, service.CreateMachineRequest{ID: opts.MachineID}) runtime, err := firecracker.NewRuntime(cfg.FirecrackerRuntimeConfig())
if err != nil { if err != nil {
exit(err) exit(err)
} }
if err := writeJSON(os.Stdout, state); err != nil { hostDaemon, err := daemon.New(cfg, fileStore, runtime)
if err != nil {
exit(err) exit(err)
} }
} if err := hostDaemon.Reconcile(ctx); err != nil {
exit(err)
func parseOptions(args []string, stderr io.Writer) (options, error) {
fs := flag.NewFlagSet("firecracker-host", flag.ContinueOnError)
fs.SetOutput(stderr)
var machineID string
fs.StringVar(&machineID, "machine-id", "", "machine id to boot")
if err := fs.Parse(args); err != nil {
return options{}, err
} }
machineID = strings.TrimSpace(machineID) handler, err := httpapi.New(hostDaemon)
if machineID == "" { if err != nil {
return options{}, fmt.Errorf("-machine-id is required") exit(err)
} }
return options{MachineID: firecracker.MachineID(machineID)}, nil if err := os.MkdirAll(filepath.Dir(cfg.SocketPath), 0o755); err != nil {
} exit(err)
}
if err := os.Remove(cfg.SocketPath); err != nil && !os.IsNotExist(err) {
exit(err)
}
func writeJSON(w io.Writer, value any) error { listener, err := net.Listen("unix", cfg.SocketPath)
encoder := json.NewEncoder(w) if err != nil {
encoder.SetIndent("", " ") exit(err)
return encoder.Encode(value) }
defer listener.Close()
server := &http.Server{Handler: handler.Routes()}
go func() {
<-ctx.Done()
_ = server.Shutdown(context.Background())
}()
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
exit(err)
}
} }
func exit(err error) { func exit(err error) {