diff --git a/contract/storage.go b/contract/storage.go index 59325ff..553077b 100644 --- a/contract/storage.go +++ b/contract/storage.go @@ -3,8 +3,8 @@ package host import "time" type ArtifactRef struct { - ID ArtifactID `json:"id"` - Version ArtifactVersion `json:"version"` + KernelImageURL string `json:"kernel_image_url"` + RootFSURL string `json:"rootfs_url"` } type Volume struct { diff --git a/contract/types.go b/contract/types.go index fa5a379..b13ed42 100644 --- a/contract/types.go +++ b/contract/types.go @@ -1,9 +1,5 @@ package host -type ArtifactID string - -type ArtifactVersion string - type MachineID string type MachinePhase string @@ -13,12 +9,9 @@ type VolumeID string type VolumeKind string const ( - MachinePhasePending MachinePhase = "pending" - MachinePhaseRunning MachinePhase = "running" - MachinePhaseStopping MachinePhase = "stopping" - MachinePhaseStopped MachinePhase = "stopped" - MachinePhaseFailed MachinePhase = "failed" - MachinePhaseDeleting MachinePhase = "deleting" + MachinePhaseRunning MachinePhase = "running" + MachinePhaseStopped MachinePhase = "stopped" + MachinePhaseFailed MachinePhase = "failed" ) const ( diff --git a/internal/config/config.go b/internal/config/config.go index 380643c..c4b285e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,28 +3,40 @@ package config import ( "fmt" "os" + "path/filepath" "strings" "github.com/getcompanion-ai/computer-host/internal/firecracker" ) -// Config contains the minimum host-local settings required to boot a machine. +const defaultSocketName = "firecracker-host.sock" + +// Config contains the host-local daemon settings. type Config struct { RootDir string + StatePath string + OperationsPath string + ArtifactsDir string + MachineDisksDir string + RuntimeDir string + SocketPath string FirecrackerBinaryPath string JailerBinaryPath string - KernelImagePath string - RootFSPath string } -// Load loads and validates the firecracker-host configuration from the environment. +// Load loads and validates the firecracker-host daemon configuration from the environment. func Load() (Config, error) { + rootDir := filepath.Clean(strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_ROOT_DIR"))) cfg := Config{ - RootDir: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_ROOT_DIR")), + RootDir: rootDir, + StatePath: filepath.Join(rootDir, "state", "state.json"), + OperationsPath: filepath.Join(rootDir, "state", "ops.json"), + ArtifactsDir: filepath.Join(rootDir, "artifacts"), + MachineDisksDir: filepath.Join(rootDir, "machine-disks"), + RuntimeDir: filepath.Join(rootDir, "runtime"), + SocketPath: filepath.Join(rootDir, defaultSocketName), FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")), JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")), - KernelImagePath: strings.TrimSpace(os.Getenv("FIRECRACKER_GUEST_KERNEL_PATH")), - RootFSPath: strings.TrimSpace(os.Getenv("FIRECRACKER_GUEST_ROOTFS_PATH")), } if err := cfg.Validate(); err != nil { return Config{}, err @@ -43,19 +55,31 @@ func (c Config) Validate() error { if c.JailerBinaryPath == "" { return fmt.Errorf("JAILER_BINARY_PATH is required") } - if c.KernelImagePath == "" { - return fmt.Errorf("FIRECRACKER_GUEST_KERNEL_PATH is required") + if strings.TrimSpace(c.StatePath) == "" { + return fmt.Errorf("state path is required") } - if c.RootFSPath == "" { - return fmt.Errorf("FIRECRACKER_GUEST_ROOTFS_PATH is required") + if strings.TrimSpace(c.OperationsPath) == "" { + return fmt.Errorf("operations path is required") + } + if strings.TrimSpace(c.ArtifactsDir) == "" { + return fmt.Errorf("artifacts dir is required") + } + if strings.TrimSpace(c.MachineDisksDir) == "" { + return fmt.Errorf("machine disks dir is required") + } + if strings.TrimSpace(c.RuntimeDir) == "" { + return fmt.Errorf("runtime dir is required") + } + if strings.TrimSpace(c.SocketPath) == "" { + return fmt.Errorf("socket path is required") } return nil } -// FirecrackerRuntimeConfig converts the host config into the runtime wrapper's concrete runtime config. +// FirecrackerRuntimeConfig converts the daemon config into the Firecracker runtime config. func (c Config) FirecrackerRuntimeConfig() firecracker.RuntimeConfig { return firecracker.RuntimeConfig{ - RootDir: c.RootDir, + RootDir: c.RuntimeDir, FirecrackerBinaryPath: c.FirecrackerBinaryPath, JailerBinaryPath: c.JailerBinaryPath, } diff --git a/internal/daemon/create.go b/internal/daemon/create.go new file mode 100644 index 0000000..bf9dce2 --- /dev/null +++ b/internal/daemon/create.go @@ -0,0 +1,221 @@ +package daemon + +import ( + "context" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/getcompanion-ai/computer-host/internal/firecracker" + "github.com/getcompanion-ai/computer-host/internal/model" + "github.com/getcompanion-ai/computer-host/internal/store" + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachineRequest) (*contracthost.CreateMachineResponse, error) { + if err := validateMachineID(req.MachineID); err != nil { + return nil, err + } + if err := validateArtifactRef(req.Artifact); err != nil { + return nil, err + } + + unlock := d.lockMachine(req.MachineID) + defer unlock() + + if _, err := d.store.GetMachine(ctx, req.MachineID); err == nil { + return nil, fmt.Errorf("machine %q already exists", req.MachineID) + } else if err != nil && err != store.ErrNotFound { + return nil, err + } + + if err := d.store.UpsertOperation(ctx, model.OperationRecord{ + MachineID: req.MachineID, + Type: model.MachineOperationCreate, + StartedAt: time.Now().UTC(), + }); err != nil { + return nil, err + } + + clearOperation := false + defer func() { + if clearOperation { + _ = d.store.DeleteOperation(context.Background(), req.MachineID) + } + }() + + artifact, err := d.ensureArtifact(ctx, req.Artifact) + if err != nil { + return nil, err + } + + userVolumes, err := d.loadAttachableUserVolumes(ctx, req.MachineID, req.UserVolumeIDs) + if err != nil { + return nil, err + } + + systemVolumePath := d.systemVolumePath(req.MachineID) + if err := os.MkdirAll(filepath.Dir(systemVolumePath), 0o755); err != nil { + return nil, fmt.Errorf("create system volume dir for %q: %w", req.MachineID, err) + } + if err := cloneFile(artifact.RootFSPath, systemVolumePath); err != nil { + return nil, err + } + + spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath) + if err != nil { + return nil, err + } + usedNetworks, err := d.listRunningNetworks(ctx, req.MachineID) + if err != nil { + return nil, err + } + + state, err := d.runtime.Boot(ctx, spec, usedNetworks) + if err != nil { + return nil, err + } + + now := time.Now().UTC() + systemVolumeRecord := model.VolumeRecord{ + ID: d.systemVolumeID(req.MachineID), + Kind: contracthost.VolumeKindSystem, + AttachedMachineID: machineIDPtr(req.MachineID), + SourceArtifact: &req.Artifact, + Pool: model.StoragePoolMachineDisks, + Path: systemVolumePath, + CreatedAt: now, + } + if err := d.store.CreateVolume(ctx, systemVolumeRecord); err != nil { + _ = d.runtime.Delete(context.Background(), *state) + return nil, err + } + + attachedUserVolumeIDs := make([]contracthost.VolumeID, 0, len(userVolumes)) + for _, volume := range userVolumes { + volume.AttachedMachineID = machineIDPtr(req.MachineID) + if err := d.store.UpdateVolume(ctx, volume); err != nil { + for _, attachedVolumeID := range attachedUserVolumeIDs { + attachedVolume, getErr := d.store.GetVolume(context.Background(), attachedVolumeID) + if getErr == nil { + attachedVolume.AttachedMachineID = nil + _ = d.store.UpdateVolume(context.Background(), *attachedVolume) + } + } + _ = d.store.DeleteVolume(context.Background(), systemVolumeRecord.ID) + _ = d.runtime.Delete(context.Background(), *state) + return nil, err + } + attachedUserVolumeIDs = append(attachedUserVolumeIDs, volume.ID) + } + + record := model.MachineRecord{ + ID: req.MachineID, + Artifact: req.Artifact, + SystemVolumeID: systemVolumeRecord.ID, + UserVolumeIDs: append([]contracthost.VolumeID(nil), attachedUserVolumeIDs...), + RuntimeHost: state.RuntimeHost, + TapDevice: state.TapName, + Ports: defaultMachinePorts(), + Phase: contracthost.MachinePhaseRunning, + PID: state.PID, + SocketPath: state.SocketPath, + CreatedAt: now, + StartedAt: state.StartedAt, + } + if err := d.store.CreateMachine(ctx, record); err != nil { + for _, volume := range userVolumes { + volume.AttachedMachineID = nil + _ = d.store.UpdateVolume(context.Background(), volume) + } + _ = d.store.DeleteVolume(context.Background(), systemVolumeRecord.ID) + _ = d.runtime.Delete(context.Background(), *state) + return nil, err + } + + clearOperation = true + return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil +} + +func (d *Daemon) buildMachineSpec(machineID contracthost.MachineID, artifact *model.ArtifactRecord, userVolumes []model.VolumeRecord, systemVolumePath string) (firecracker.MachineSpec, error) { + drives := make([]firecracker.DriveSpec, 0, len(userVolumes)) + for i, volume := range userVolumes { + drives = append(drives, firecracker.DriveSpec{ + ID: fmt.Sprintf("user-%d", i), + Path: volume.Path, + ReadOnly: false, + }) + } + + spec := firecracker.MachineSpec{ + ID: firecracker.MachineID(machineID), + VCPUs: defaultGuestVCPUs, + MemoryMiB: defaultGuestMemoryMiB, + KernelImagePath: artifact.KernelImagePath, + RootFSPath: systemVolumePath, + KernelArgs: defaultGuestKernelArgs, + Drives: drives, + } + if err := spec.Validate(); err != nil { + return firecracker.MachineSpec{}, err + } + return spec, nil +} + +func (d *Daemon) ensureArtifact(ctx context.Context, ref contracthost.ArtifactRef) (*model.ArtifactRecord, error) { + key := artifactKey(ref) + unlock := d.lockArtifact(key) + defer unlock() + + if artifact, err := d.store.GetArtifact(ctx, ref); err == nil { + return artifact, nil + } else if err != store.ErrNotFound { + return nil, err + } + + dir := filepath.Join(d.config.ArtifactsDir, key) + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, fmt.Errorf("create artifact dir %q: %w", dir, err) + } + + kernelPath := filepath.Join(dir, "kernel") + rootFSPath := filepath.Join(dir, "rootfs") + if err := downloadFile(ctx, ref.KernelImageURL, kernelPath); err != nil { + return nil, err + } + if err := downloadFile(ctx, ref.RootFSURL, rootFSPath); err != nil { + return nil, err + } + + artifact := model.ArtifactRecord{ + Ref: ref, + LocalKey: key, + LocalDir: dir, + KernelImagePath: kernelPath, + RootFSPath: rootFSPath, + CreatedAt: time.Now().UTC(), + } + if err := d.store.PutArtifact(ctx, artifact); err != nil { + return nil, err + } + return &artifact, nil +} + +func (d *Daemon) loadAttachableUserVolumes(ctx context.Context, machineID contracthost.MachineID, volumeIDs []contracthost.VolumeID) ([]model.VolumeRecord, error) { + volumes := make([]model.VolumeRecord, 0, len(volumeIDs)) + for _, volumeID := range volumeIDs { + volume, err := d.store.GetVolume(ctx, volumeID) + if err != nil { + return nil, err + } + if volume.Kind != contracthost.VolumeKindUser { + return nil, fmt.Errorf("volume %q is not a user volume", volumeID) + } + if volume.AttachedMachineID != nil && *volume.AttachedMachineID != machineID { + return nil, fmt.Errorf("volume %q is already attached to machine %q", volumeID, *volume.AttachedMachineID) + } + volumes = append(volumes, *volume) + } + return volumes, nil +} diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index d3dd3b2..076273c 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -1,12 +1,95 @@ package daemon import ( + "context" + "fmt" + "os" + "sync" + + appconfig "github.com/getcompanion-ai/computer-host/internal/config" + "github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/store" + contracthost "github.com/getcompanion-ai/computer-host/contract" ) -type Runtime interface{} +const ( + defaultGuestKernelArgs = "console=ttyS0 reboot=k panic=1 pci=off" + defaultGuestMemoryMiB = int64(512) + defaultGuestVCPUs = int64(1) + defaultSSHPort = uint16(2222) + defaultVNCPort = uint16(6080) + defaultCopyBufferSize = 1024 * 1024 +) + +type Runtime interface { + Boot(context.Context, firecracker.MachineSpec, []firecracker.NetworkAllocation) (*firecracker.MachineState, error) + Inspect(firecracker.MachineState) (*firecracker.MachineState, error) + Delete(context.Context, firecracker.MachineState) error +} type Daemon struct { - Store store.Store - Runtime Runtime + config appconfig.Config + store store.Store + runtime Runtime + + locksMu sync.Mutex + machineLocks map[contracthost.MachineID]*sync.Mutex + artifactLocks map[string]*sync.Mutex +} + +func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + if store == nil { + return nil, fmt.Errorf("store is required") + } + if runtime == nil { + return nil, fmt.Errorf("runtime is required") + } + for _, dir := range []string{cfg.ArtifactsDir, cfg.MachineDisksDir, cfg.RuntimeDir} { + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, fmt.Errorf("create daemon dir %q: %w", dir, err) + } + } + return &Daemon{ + config: cfg, + store: store, + runtime: runtime, + machineLocks: make(map[contracthost.MachineID]*sync.Mutex), + artifactLocks: make(map[string]*sync.Mutex), + }, nil +} + +func (d *Daemon) Health(ctx context.Context) (*contracthost.HealthResponse, error) { + if _, err := d.store.ListMachines(ctx); err != nil { + return nil, err + } + return &contracthost.HealthResponse{OK: true}, nil +} + +func (d *Daemon) lockMachine(machineID contracthost.MachineID) func() { + d.locksMu.Lock() + lock, ok := d.machineLocks[machineID] + if !ok { + lock = &sync.Mutex{} + d.machineLocks[machineID] = lock + } + d.locksMu.Unlock() + + lock.Lock() + return lock.Unlock +} + +func (d *Daemon) lockArtifact(key string) func() { + d.locksMu.Lock() + lock, ok := d.artifactLocks[key] + if !ok { + lock = &sync.Mutex{} + d.artifactLocks[key] = lock + } + d.locksMu.Unlock() + + lock.Lock() + return lock.Unlock } diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go new file mode 100644 index 0000000..a06da53 --- /dev/null +++ b/internal/daemon/daemon_test.go @@ -0,0 +1,211 @@ +package daemon + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + appconfig "github.com/getcompanion-ai/computer-host/internal/config" + "github.com/getcompanion-ai/computer-host/internal/firecracker" + "github.com/getcompanion-ai/computer-host/internal/store" + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +type fakeRuntime struct { + bootState firecracker.MachineState + bootCalls int + deleteCalls []firecracker.MachineState + lastSpec firecracker.MachineSpec +} + +func (f *fakeRuntime) Boot(_ context.Context, spec firecracker.MachineSpec, _ []firecracker.NetworkAllocation) (*firecracker.MachineState, error) { + f.bootCalls++ + f.lastSpec = spec + state := f.bootState + return &state, nil +} + +func (f *fakeRuntime) Inspect(state firecracker.MachineState) (*firecracker.MachineState, error) { + copy := state + return ©, nil +} + +func (f *fakeRuntime) Delete(_ context.Context, state firecracker.MachineState) error { + f.deleteCalls = append(f.deleteCalls, state) + return nil +} + +func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) { + t.Parallel() + + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + startedAt := time.Unix(1700000005, 0).UTC() + runtime := &fakeRuntime{ + bootState: firecracker.MachineState{ + ID: "vm-1", + Phase: firecracker.PhaseRunning, + PID: 4321, + RuntimeHost: "172.16.0.2", + SocketPath: filepath.Join(cfg.RuntimeDir, "machines", "vm-1", "root", "run", "firecracker.sock"), + TapName: "fctap0", + StartedAt: &startedAt, + }, + } + + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + + kernelPayload := []byte("kernel-image") + rootFSPayload := []byte("rootfs-image") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/kernel": + _, _ = w.Write(kernelPayload) + case "/rootfs": + _, _ = w.Write(rootFSPayload) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + response, err := hostDaemon.CreateMachine(context.Background(), contracthost.CreateMachineRequest{ + MachineID: "vm-1", + Artifact: contracthost.ArtifactRef{ + KernelImageURL: server.URL + "/kernel", + RootFSURL: server.URL + "/rootfs", + }, + }) + if err != nil { + t.Fatalf("create machine: %v", err) + } + + if response.Machine.Phase != contracthost.MachinePhaseRunning { + t.Fatalf("machine phase mismatch: got %q", response.Machine.Phase) + } + if response.Machine.RuntimeHost != "172.16.0.2" { + t.Fatalf("runtime host mismatch: got %q", response.Machine.RuntimeHost) + } + if len(response.Machine.Ports) != 2 { + t.Fatalf("machine ports mismatch: got %d want 2", len(response.Machine.Ports)) + } + if runtime.bootCalls != 1 { + t.Fatalf("boot call count mismatch: got %d want 1", runtime.bootCalls) + } + if runtime.lastSpec.KernelImagePath == "" || runtime.lastSpec.RootFSPath == "" { + t.Fatalf("runtime spec paths not populated: %#v", runtime.lastSpec) + } + if _, err := os.Stat(runtime.lastSpec.KernelImagePath); err != nil { + t.Fatalf("kernel artifact not staged: %v", err) + } + if _, err := os.Stat(runtime.lastSpec.RootFSPath); err != nil { + t.Fatalf("system disk not staged: %v", err) + } + + artifact, err := fileStore.GetArtifact(context.Background(), response.Machine.Artifact) + if err != nil { + t.Fatalf("get artifact: %v", err) + } + if artifact.KernelImagePath == "" || artifact.RootFSPath == "" { + t.Fatalf("artifact paths missing: %#v", artifact) + } + if payload, err := os.ReadFile(artifact.KernelImagePath); err != nil { + t.Fatalf("read kernel artifact: %v", err) + } else if string(payload) != string(kernelPayload) { + t.Fatalf("kernel artifact payload mismatch: got %q", string(payload)) + } + + machine, err := fileStore.GetMachine(context.Background(), "vm-1") + if err != nil { + t.Fatalf("get machine: %v", err) + } + if machine.SystemVolumeID != "vm-1-system" { + t.Fatalf("system volume mismatch: got %q", machine.SystemVolumeID) + } + + operations, err := fileStore.ListOperations(context.Background()) + if err != nil { + t.Fatalf("list operations: %v", err) + } + if len(operations) != 0 { + t.Fatalf("operation journal should be empty after success: got %d entries", len(operations)) + } +} + +func TestCreateMachineRejectsNonHTTPArtifactURLs(t *testing.T) { + t.Parallel() + + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + hostDaemon, err := New(cfg, fileStore, &fakeRuntime{}) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + + _, err = hostDaemon.CreateMachine(context.Background(), contracthost.CreateMachineRequest{ + MachineID: "vm-1", + Artifact: contracthost.ArtifactRef{ + KernelImageURL: "file:///kernel", + RootFSURL: "https://example.com/rootfs", + }, + }) + if err == nil { + t.Fatal("expected create machine to fail for non-http artifact url") + } + if got := err.Error(); got != "artifact.kernel_image_url must use http or https" { + t.Fatalf("unexpected error: %q", got) + } +} + +func TestDeleteMachineMissingIsNoOp(t *testing.T) { + t.Parallel() + + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + runtime := &fakeRuntime{} + hostDaemon, err := New(cfg, fileStore, runtime) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + + if err := hostDaemon.DeleteMachine(context.Background(), "missing"); err != nil { + t.Fatalf("delete missing machine: %v", err) + } + if len(runtime.deleteCalls) != 0 { + t.Fatalf("delete runtime should not be called for missing machine") + } +} + +func testConfig(root string) appconfig.Config { + return appconfig.Config{ + RootDir: root, + StatePath: filepath.Join(root, "state", "state.json"), + OperationsPath: filepath.Join(root, "state", "ops.json"), + ArtifactsDir: filepath.Join(root, "artifacts"), + MachineDisksDir: filepath.Join(root, "machine-disks"), + RuntimeDir: filepath.Join(root, "runtime"), + SocketPath: filepath.Join(root, "firecracker-host.sock"), + FirecrackerBinaryPath: "/usr/bin/firecracker", + JailerBinaryPath: "/usr/bin/jailer", + } +} diff --git a/internal/daemon/files.go b/internal/daemon/files.go new file mode 100644 index 0000000..fc535c3 --- /dev/null +++ b/internal/daemon/files.go @@ -0,0 +1,274 @@ +package daemon + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + + "github.com/getcompanion-ai/computer-host/internal/firecracker" + "github.com/getcompanion-ai/computer-host/internal/model" + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +func (d *Daemon) systemVolumeID(machineID contracthost.MachineID) contracthost.VolumeID { + return contracthost.VolumeID(fmt.Sprintf("%s-system", machineID)) +} + +func (d *Daemon) systemVolumePath(machineID contracthost.MachineID) string { + return filepath.Join(d.config.MachineDisksDir, string(machineID), "system.img") +} + +func (d *Daemon) machineRuntimeBaseDir(machineID contracthost.MachineID) string { + return filepath.Join(d.config.RuntimeDir, "machines", string(machineID)) +} + +func artifactKey(ref contracthost.ArtifactRef) string { + sum := sha256.Sum256([]byte(ref.KernelImageURL + "\n" + ref.RootFSURL)) + return hex.EncodeToString(sum[:]) +} + +func cloneFile(source string, target string) error { + if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { + return fmt.Errorf("create target dir for %q: %w", target, err) + } + + sourceFile, err := os.Open(source) + if err != nil { + return fmt.Errorf("open source file %q: %w", source, err) + } + defer sourceFile.Close() + + sourceInfo, err := sourceFile.Stat() + if err != nil { + return fmt.Errorf("stat source file %q: %w", source, err) + } + + tmpPath := target + ".tmp" + targetFile, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + if err != nil { + return fmt.Errorf("open target file %q: %w", tmpPath, err) + } + + if _, err := writeSparseFile(targetFile, sourceFile); err != nil { + targetFile.Close() + return fmt.Errorf("copy %q to %q: %w", source, tmpPath, err) + } + if err := targetFile.Truncate(sourceInfo.Size()); err != nil { + targetFile.Close() + return fmt.Errorf("truncate target file %q: %w", tmpPath, err) + } + if err := targetFile.Sync(); err != nil { + targetFile.Close() + return fmt.Errorf("sync target file %q: %w", tmpPath, err) + } + if err := targetFile.Close(); err != nil { + return fmt.Errorf("close target file %q: %w", tmpPath, err) + } + if err := os.Rename(tmpPath, target); err != nil { + return fmt.Errorf("rename target file %q to %q: %w", tmpPath, target, err) + } + if err := syncDir(filepath.Dir(target)); err != nil { + return err + } + return nil +} + +func downloadFile(ctx context.Context, rawURL string, path string) error { + if _, err := os.Stat(path); err == nil { + return nil + } else if !os.IsNotExist(err) { + return fmt.Errorf("stat download target %q: %w", path, err) + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("create download dir for %q: %w", path, err) + } + + request, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return fmt.Errorf("build download request for %q: %w", rawURL, err) + } + response, err := http.DefaultClient.Do(request) + if err != nil { + return fmt.Errorf("download %q: %w", rawURL, err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return fmt.Errorf("download %q: status %d", rawURL, response.StatusCode) + } + + tmpPath := path + ".tmp" + file, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + if err != nil { + return fmt.Errorf("open download target %q: %w", tmpPath, err) + } + + size, err := writeSparseFile(file, response.Body) + if err != nil { + file.Close() + return fmt.Errorf("write download target %q: %w", tmpPath, err) + } + if err := file.Truncate(size); err != nil { + file.Close() + return fmt.Errorf("truncate download target %q: %w", tmpPath, err) + } + if err := file.Sync(); err != nil { + file.Close() + return fmt.Errorf("sync download target %q: %w", tmpPath, err) + } + if err := file.Close(); err != nil { + return fmt.Errorf("close download target %q: %w", tmpPath, err) + } + if err := os.Rename(tmpPath, path); err != nil { + return fmt.Errorf("rename download target %q to %q: %w", tmpPath, path, err) + } + if err := syncDir(filepath.Dir(path)); err != nil { + return err + } + return nil +} + +func writeSparseFile(targetFile *os.File, source io.Reader) (int64, error) { + buffer := make([]byte, defaultCopyBufferSize) + var size int64 + + for { + count, err := source.Read(buffer) + if count > 0 { + chunk := buffer[:count] + if isZeroChunk(chunk) { + if _, seekErr := targetFile.Seek(int64(count), io.SeekCurrent); seekErr != nil { + return size, seekErr + } + } else { + if _, writeErr := targetFile.Write(chunk); writeErr != nil { + return size, writeErr + } + } + size += int64(count) + } + if err == nil { + continue + } + if err == io.EOF { + return size, nil + } + return size, err + } +} + +func isZeroChunk(chunk []byte) bool { + for _, value := range chunk { + if value != 0 { + return false + } + } + return true +} + +func defaultMachinePorts() []contracthost.MachinePort { + return []contracthost.MachinePort{ + {Name: contracthost.MachinePortNameSSH, Port: defaultSSHPort, Protocol: contracthost.PortProtocolTCP}, + {Name: contracthost.MachinePortNameVNC, Port: defaultVNCPort, Protocol: contracthost.PortProtocolTCP}, + } +} + +func machineIDPtr(machineID contracthost.MachineID) *contracthost.MachineID { + value := machineID + return &value +} + +func machineToContract(record model.MachineRecord) contracthost.Machine { + return contracthost.Machine{ + ID: record.ID, + Artifact: record.Artifact, + SystemVolumeID: record.SystemVolumeID, + UserVolumeIDs: append([]contracthost.VolumeID(nil), record.UserVolumeIDs...), + RuntimeHost: record.RuntimeHost, + Ports: append([]contracthost.MachinePort(nil), record.Ports...), + Phase: record.Phase, + Error: record.Error, + CreatedAt: record.CreatedAt, + StartedAt: record.StartedAt, + } +} + +func machineToRuntimeState(record model.MachineRecord) firecracker.MachineState { + phase := firecracker.PhaseStopped + switch record.Phase { + case contracthost.MachinePhaseRunning: + phase = firecracker.PhaseRunning + case contracthost.MachinePhaseFailed: + phase = firecracker.PhaseFailed + } + return firecracker.MachineState{ + ID: firecracker.MachineID(record.ID), + Phase: phase, + PID: record.PID, + RuntimeHost: record.RuntimeHost, + SocketPath: record.SocketPath, + TapName: record.TapDevice, + StartedAt: record.StartedAt, + Error: record.Error, + } +} + +func validateArtifactRef(ref contracthost.ArtifactRef) error { + if err := validateDownloadURL("artifact.kernel_image_url", ref.KernelImageURL); err != nil { + return err + } + if err := validateDownloadURL("artifact.rootfs_url", ref.RootFSURL); err != nil { + return err + } + return nil +} + +func validateMachineID(machineID contracthost.MachineID) error { + value := strings.TrimSpace(string(machineID)) + if value == "" { + return fmt.Errorf("machine_id is required") + } + if filepath.Base(value) != value { + return fmt.Errorf("machine_id %q must not contain path separators", machineID) + } + return nil +} + +func validateDownloadURL(field string, raw string) error { + value := strings.TrimSpace(raw) + if value == "" { + return fmt.Errorf("%s is required", field) + } + parsed, err := url.Parse(value) + if err != nil { + return fmt.Errorf("%s is invalid: %w", field, err) + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return fmt.Errorf("%s must use http or https", field) + } + if strings.TrimSpace(parsed.Host) == "" { + return fmt.Errorf("%s host is required", field) + } + return nil +} + +func syncDir(path string) error { + dir, err := os.Open(path) + if err != nil { + return fmt.Errorf("open dir %q: %w", path, err) + } + if err := dir.Sync(); err != nil { + dir.Close() + return fmt.Errorf("sync dir %q: %w", path, err) + } + if err := dir.Close(); err != nil { + return fmt.Errorf("close dir %q: %w", path, err) + } + return nil +} diff --git a/internal/daemon/files_test.go b/internal/daemon/files_test.go new file mode 100644 index 0000000..fe197a6 --- /dev/null +++ b/internal/daemon/files_test.go @@ -0,0 +1,94 @@ +package daemon + +import ( + "bytes" + "io" + "os" + "path/filepath" + "syscall" + "testing" +) + +func TestCloneFilePreservesSparseDiskUsage(t *testing.T) { + root := t.TempDir() + sourcePath := filepath.Join(root, "source.img") + targetPath := filepath.Join(root, "target.img") + + sourceFile, err := os.OpenFile(sourcePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + if err != nil { + t.Fatalf("open source file: %v", err) + } + if _, err := sourceFile.Write([]byte("head")); err != nil { + sourceFile.Close() + t.Fatalf("write source prefix: %v", err) + } + if _, err := sourceFile.Seek(32<<20, io.SeekStart); err != nil { + sourceFile.Close() + t.Fatalf("seek source hole: %v", err) + } + if _, err := sourceFile.Write([]byte("tail")); err != nil { + sourceFile.Close() + t.Fatalf("write source suffix: %v", err) + } + if err := sourceFile.Close(); err != nil { + t.Fatalf("close source file: %v", err) + } + + sourceInfo, err := os.Stat(sourcePath) + if err != nil { + t.Fatalf("stat source file: %v", err) + } + sourceUsage, err := allocatedBytes(sourcePath) + if err != nil { + t.Fatalf("allocated bytes for source: %v", err) + } + if sourceUsage >= sourceInfo.Size()/2 { + t.Skip("temp filesystem does not expose sparse allocation savings") + } + + if err := cloneFile(sourcePath, targetPath); err != nil { + t.Fatalf("clone sparse file: %v", err) + } + + targetInfo, err := os.Stat(targetPath) + if err != nil { + t.Fatalf("stat target file: %v", err) + } + if targetInfo.Size() != sourceInfo.Size() { + t.Fatalf("target size mismatch: got %d want %d", targetInfo.Size(), sourceInfo.Size()) + } + + targetUsage, err := allocatedBytes(targetPath) + if err != nil { + t.Fatalf("allocated bytes for target: %v", err) + } + if targetUsage >= targetInfo.Size()/2 { + t.Fatalf("target file is not sparse enough: allocated=%d size=%d", targetUsage, targetInfo.Size()) + } + + targetData, err := os.ReadFile(targetPath) + if err != nil { + t.Fatalf("read target file: %v", err) + } + if !bytes.Equal(targetData[:4], []byte("head")) { + t.Fatalf("target prefix mismatch: %q", string(targetData[:4])) + } + if !bytes.Equal(targetData[len(targetData)-4:], []byte("tail")) { + t.Fatalf("target suffix mismatch: %q", string(targetData[len(targetData)-4:])) + } + if !bytes.Equal(targetData[4:4+(1<<20)], make([]byte, 1<<20)) { + t.Fatal("target hole contents were not zeroed") + } +} + +func allocatedBytes(path string) (int64, error) { + info, err := os.Stat(path) + if err != nil { + return 0, err + } + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok { + return 0, syscall.EINVAL + } + return stat.Blocks * 512, nil +} diff --git a/internal/daemon/lifecycle.go b/internal/daemon/lifecycle.go new file mode 100644 index 0000000..c444934 --- /dev/null +++ b/internal/daemon/lifecycle.go @@ -0,0 +1,327 @@ +package daemon + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/getcompanion-ai/computer-host/internal/firecracker" + "github.com/getcompanion-ai/computer-host/internal/model" + "github.com/getcompanion-ai/computer-host/internal/store" + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +func (d *Daemon) GetMachine(ctx context.Context, id contracthost.MachineID) (*contracthost.GetMachineResponse, error) { + record, err := d.reconcileMachine(ctx, id) + if err != nil { + return nil, err + } + return &contracthost.GetMachineResponse{Machine: machineToContract(*record)}, nil +} + +func (d *Daemon) ListMachines(ctx context.Context) (*contracthost.ListMachinesResponse, error) { + records, err := d.store.ListMachines(ctx) + if err != nil { + return nil, err + } + + machines := make([]contracthost.Machine, 0, len(records)) + for _, record := range records { + reconciled, err := d.reconcileMachine(ctx, record.ID) + if err != nil { + return nil, err + } + machines = append(machines, machineToContract(*reconciled)) + } + return &contracthost.ListMachinesResponse{Machines: machines}, nil +} + +func (d *Daemon) StopMachine(ctx context.Context, id contracthost.MachineID) error { + unlock := d.lockMachine(id) + defer unlock() + + record, err := d.store.GetMachine(ctx, id) + if err != nil { + return err + } + if record.Phase == contracthost.MachinePhaseStopped { + return nil + } + + if err := d.store.UpsertOperation(ctx, model.OperationRecord{ + MachineID: id, + Type: model.MachineOperationStop, + StartedAt: time.Now().UTC(), + }); err != nil { + return err + } + + clearOperation := false + defer func() { + if clearOperation { + _ = d.store.DeleteOperation(context.Background(), id) + } + }() + + if err := d.stopMachineRecord(ctx, record); err != nil { + return err + } + + clearOperation = true + return nil +} + +func (d *Daemon) DeleteMachine(ctx context.Context, id contracthost.MachineID) error { + unlock := d.lockMachine(id) + defer unlock() + + record, err := d.store.GetMachine(ctx, id) + if err == store.ErrNotFound { + return nil + } + if err != nil { + return err + } + + if err := d.store.UpsertOperation(ctx, model.OperationRecord{ + MachineID: id, + Type: model.MachineOperationDelete, + StartedAt: time.Now().UTC(), + }); err != nil { + return err + } + + clearOperation := false + defer func() { + if clearOperation { + _ = d.store.DeleteOperation(context.Background(), id) + } + }() + + if err := d.deleteMachineRecord(ctx, record); err != nil { + return err + } + + clearOperation = true + return nil +} + +func (d *Daemon) Reconcile(ctx context.Context) error { + operations, err := d.store.ListOperations(ctx) + if err != nil { + return err + } + for _, operation := range operations { + switch operation.Type { + case model.MachineOperationCreate: + if err := d.reconcileCreate(ctx, operation.MachineID); err != nil { + return err + } + case model.MachineOperationStop: + if err := d.reconcileStop(ctx, operation.MachineID); err != nil { + return err + } + case model.MachineOperationDelete: + if err := d.reconcileDelete(ctx, operation.MachineID); err != nil { + return err + } + default: + return fmt.Errorf("unsupported operation type %q", operation.Type) + } + } + + records, err := d.store.ListMachines(ctx) + if err != nil { + return err + } + for _, record := range records { + if _, err := d.reconcileMachine(ctx, record.ID); err != nil { + return err + } + } + return nil +} + +func (d *Daemon) listRunningNetworks(ctx context.Context, ignore contracthost.MachineID) ([]firecracker.NetworkAllocation, error) { + records, err := d.store.ListMachines(ctx) + if err != nil { + return nil, err + } + + networks := make([]firecracker.NetworkAllocation, 0, len(records)) + for _, record := range records { + if record.ID == ignore || record.Phase != contracthost.MachinePhaseRunning { + continue + } + if strings.TrimSpace(record.RuntimeHost) == "" || strings.TrimSpace(record.TapDevice) == "" { + continue + } + network, err := firecracker.AllocationFromGuestIP(record.RuntimeHost, record.TapDevice) + if err != nil { + return nil, err + } + networks = append(networks, network) + } + return networks, nil +} + +func (d *Daemon) reconcileCreate(ctx context.Context, machineID contracthost.MachineID) error { + _, err := d.store.GetMachine(ctx, machineID) + if err == nil { + if _, err := d.reconcileMachine(ctx, machineID); err != nil { + return err + } + return d.store.DeleteOperation(ctx, machineID) + } + if err != store.ErrNotFound { + return err + } + + if err := os.Remove(d.systemVolumePath(machineID)); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("cleanup system volume for %q: %w", machineID, err) + } + if err := d.store.DeleteVolume(ctx, d.systemVolumeID(machineID)); err != nil && err != store.ErrNotFound { + return err + } + if err := d.detachVolumesForMachine(ctx, machineID); err != nil { + return err + } + _ = os.RemoveAll(filepath.Dir(d.systemVolumePath(machineID))) + if err := os.RemoveAll(d.machineRuntimeBaseDir(machineID)); err != nil { + return fmt.Errorf("cleanup runtime dir for %q: %w", machineID, err) + } + return d.store.DeleteOperation(ctx, machineID) +} + +func (d *Daemon) reconcileStop(ctx context.Context, machineID contracthost.MachineID) error { + record, err := d.store.GetMachine(ctx, machineID) + if err == store.ErrNotFound { + return d.store.DeleteOperation(ctx, machineID) + } + if err != nil { + return err + } + if err := d.stopMachineRecord(ctx, record); err != nil { + return err + } + return d.store.DeleteOperation(ctx, machineID) +} + +func (d *Daemon) reconcileDelete(ctx context.Context, machineID contracthost.MachineID) error { + record, err := d.store.GetMachine(ctx, machineID) + if err == store.ErrNotFound { + if err := os.Remove(d.systemVolumePath(machineID)); err != nil && !os.IsNotExist(err) { + return err + } + if err := d.store.DeleteVolume(ctx, d.systemVolumeID(machineID)); err != nil && err != store.ErrNotFound { + return err + } + if err := d.detachVolumesForMachine(ctx, machineID); err != nil { + return err + } + _ = os.RemoveAll(filepath.Dir(d.systemVolumePath(machineID))) + _ = os.RemoveAll(d.machineRuntimeBaseDir(machineID)) + return d.store.DeleteOperation(ctx, machineID) + } + if err != nil { + return err + } + if err := d.deleteMachineRecord(ctx, record); err != nil { + return err + } + return d.store.DeleteOperation(ctx, machineID) +} + +func (d *Daemon) reconcileMachine(ctx context.Context, machineID contracthost.MachineID) (*model.MachineRecord, error) { + unlock := d.lockMachine(machineID) + defer unlock() + + record, err := d.store.GetMachine(ctx, machineID) + if err != nil { + return nil, err + } + if record.Phase != contracthost.MachinePhaseRunning { + return record, nil + } + + state, err := d.runtime.Inspect(machineToRuntimeState(*record)) + if err != nil { + return nil, err + } + if state.Phase == firecracker.PhaseRunning { + return record, nil + } + + if err := d.runtime.Delete(ctx, *state); err != nil { + return nil, err + } + record.Phase = contracthost.MachinePhaseFailed + record.Error = state.Error + record.PID = 0 + record.SocketPath = "" + record.RuntimeHost = "" + record.TapDevice = "" + record.StartedAt = nil + if err := d.store.UpdateMachine(ctx, *record); err != nil { + return nil, err + } + return record, nil +} + +func (d *Daemon) deleteMachineRecord(ctx context.Context, record *model.MachineRecord) error { + if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil { + return err + } + if err := d.detachVolumesForMachine(ctx, record.ID); err != nil { + return err + } + + systemVolume, err := d.store.GetVolume(ctx, record.SystemVolumeID) + if err != nil { + return err + } + if err := os.Remove(systemVolume.Path); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("remove system volume %q: %w", systemVolume.Path, err) + } + if err := os.RemoveAll(filepath.Dir(systemVolume.Path)); err != nil { + return fmt.Errorf("remove machine disk dir %q: %w", filepath.Dir(systemVolume.Path), err) + } + if err := d.store.DeleteVolume(ctx, record.SystemVolumeID); err != nil { + return err + } + return d.store.DeleteMachine(ctx, record.ID) +} + +func (d *Daemon) stopMachineRecord(ctx context.Context, record *model.MachineRecord) error { + if err := d.runtime.Delete(ctx, machineToRuntimeState(*record)); err != nil { + return err + } + record.Phase = contracthost.MachinePhaseStopped + record.Error = "" + record.PID = 0 + record.SocketPath = "" + record.RuntimeHost = "" + record.TapDevice = "" + record.StartedAt = nil + return d.store.UpdateMachine(ctx, *record) +} + +func (d *Daemon) detachVolumesForMachine(ctx context.Context, machineID contracthost.MachineID) error { + volumes, err := d.store.ListVolumes(ctx) + if err != nil { + return err + } + for _, volume := range volumes { + if volume.AttachedMachineID == nil || *volume.AttachedMachineID != machineID { + continue + } + volume.AttachedMachineID = nil + if err := d.store.UpdateVolume(ctx, volume); err != nil { + return err + } + } + return nil +} diff --git a/internal/firecracker/launch.go b/internal/firecracker/launch.go index 040a14b..549101a 100644 --- a/internal/firecracker/launch.go +++ b/internal/firecracker/launch.go @@ -14,7 +14,7 @@ import ( const ( defaultCgroupVersion = "2" - defaultFirecrackerInitTimeout = 3 * time.Second + defaultFirecrackerInitTimeout = 10 * time.Second defaultFirecrackerPollInterval = 10 * time.Millisecond defaultRootDriveID = "root_drive" defaultVSockRunDir = "/run" @@ -120,20 +120,34 @@ func waitForSocket(ctx context.Context, client *apiClient, socketPath string) er ticker := time.NewTicker(defaultFirecrackerPollInterval) defer ticker.Stop() + var lastStatErr error + var lastPingErr error + for { select { case <-waitContext.Done(): - return waitContext.Err() + switch { + case lastPingErr != nil: + return fmt.Errorf("%w (socket=%q last_ping_err=%v)", waitContext.Err(), socketPath, lastPingErr) + case lastStatErr != nil: + return fmt.Errorf("%w (socket=%q last_stat_err=%v)", waitContext.Err(), socketPath, lastStatErr) + default: + return fmt.Errorf("%w (socket=%q)", waitContext.Err(), socketPath) + } case <-ticker.C: if _, err := os.Stat(socketPath); err != nil { if os.IsNotExist(err) { + lastStatErr = err continue } return fmt.Errorf("stat socket %q: %w", socketPath, err) } + lastStatErr = nil if err := client.Ping(waitContext); err != nil { + lastPingErr = err continue } + lastPingErr = nil return nil } } diff --git a/internal/firecracker/network.go b/internal/firecracker/network.go index 5af94ea..7849eee 100644 --- a/internal/firecracker/network.go +++ b/internal/firecracker/network.go @@ -4,6 +4,7 @@ import ( "context" "encoding/binary" "fmt" + "net" "net/netip" "os/exec" "strings" @@ -47,6 +48,30 @@ func (n NetworkAllocation) GuestIP() netip.Addr { return n.GuestCIDR.Addr() } +// AllocationFromGuestIP reconstructs the host-side allocation from a guest IP and tap name. +func AllocationFromGuestIP(guestIP string, tapName string) (NetworkAllocation, error) { + parsed := net.ParseIP(strings.TrimSpace(guestIP)) + if parsed == nil { + return NetworkAllocation{}, fmt.Errorf("parse guest ip %q", guestIP) + } + addr, ok := netip.AddrFromSlice(parsed.To4()) + if !ok { + return NetworkAllocation{}, fmt.Errorf("guest ip %q must be IPv4", guestIP) + } + + base := ipv4ToUint32(addr) - 2 + hostIP := uint32ToIPv4(base + 1) + guest := uint32ToIPv4(base + 2) + return NetworkAllocation{ + InterfaceID: defaultInterfaceID, + TapName: strings.TrimSpace(tapName), + HostCIDR: netip.PrefixFrom(hostIP, defaultNetworkPrefixBits), + GuestCIDR: netip.PrefixFrom(guest, defaultNetworkPrefixBits), + GatewayIP: hostIP, + GuestMAC: macForIPv4(guest), + }, nil +} + // NewNetworkAllocator returns a new /30 allocator rooted at the provided IPv4 prefix. func NewNetworkAllocator(cidr string) (*NetworkAllocator, error) { cidr = strings.TrimSpace(cidr) diff --git a/internal/firecracker/paths.go b/internal/firecracker/paths.go index 643600b..118b95e 100644 --- a/internal/firecracker/paths.go +++ b/internal/firecracker/paths.go @@ -3,6 +3,7 @@ package firecracker import ( "fmt" "path/filepath" + "strconv" "strings" ) @@ -45,3 +46,7 @@ func buildMachinePaths(rootDir string, id MachineID, firecrackerBinaryPath strin SocketPath: filepath.Join(chrootRootDir, defaultFirecrackerSocketDir, defaultFirecrackerSocketName), }, nil } + +func procSocketPath(pid int) string { + return filepath.Join("/proc", strconv.Itoa(pid), "root", defaultFirecrackerSocketDir, defaultFirecrackerSocketName) +} diff --git a/internal/firecracker/runtime.go b/internal/firecracker/runtime.go index 05a09a9..785c8bd 100644 --- a/internal/firecracker/runtime.go +++ b/internal/firecracker/runtime.go @@ -8,7 +8,6 @@ import ( "os/exec" "path/filepath" "strings" - "sync" "syscall" "time" ) @@ -27,26 +26,9 @@ type Runtime struct { jailerBinaryPath string networkAllocator *NetworkAllocator networkProvisioner NetworkProvisioner - - mu sync.RWMutex - machines map[MachineID]*managedMachine } -type managedMachine struct { - cmd *exec.Cmd - entered bool - exited chan struct{} - network NetworkAllocation - paths machinePaths - spec MachineSpec - state MachineState - stopping bool -} - -const ( - defaultVSockCIDStart = uint32(3) - defaultVSockID = "vsock0" -) +const debugPreserveFailureEnv = "FIRECRACKER_DEBUG_PRESERVE_FAILURE" func NewRuntime(cfg RuntimeConfig) (*Runtime, error) { rootDir := filepath.Clean(strings.TrimSpace(cfg.RootDir)) @@ -79,64 +61,28 @@ func NewRuntime(cfg RuntimeConfig) (*Runtime, error) { jailerBinaryPath: jailerBinaryPath, networkAllocator: allocator, networkProvisioner: NewIPTapProvisioner(), - machines: make(map[MachineID]*managedMachine), }, nil } -func (r *Runtime) Boot(ctx context.Context, spec MachineSpec) (*MachineState, error) { +func (r *Runtime) Boot(ctx context.Context, spec MachineSpec, usedNetworks []NetworkAllocation) (*MachineState, error) { if err := spec.Validate(); err != nil { return nil, err } - r.mu.Lock() - if _, exists := r.machines[spec.ID]; exists { - r.mu.Unlock() - return nil, fmt.Errorf("machine %q already exists", spec.ID) - } - - usedNetworks := make([]NetworkAllocation, 0, len(r.machines)) - usedVSockCIDs := make(map[uint32]struct{}, len(r.machines)) - for _, machine := range r.machines { - if machine == nil { - continue - } - usedNetworks = append(usedNetworks, machine.network) - if machine.spec.Vsock != nil { - usedVSockCIDs[machine.spec.Vsock.CID] = struct{}{} - } - } - - spec, err := r.resolveVSock(spec, usedVSockCIDs) - if err != nil { - r.mu.Unlock() - return nil, err - } - - r.machines[spec.ID] = &managedMachine{ - spec: spec, - state: MachineState{ - ID: spec.ID, - Phase: PhaseProvisioning, - }, - entered: true, - } - r.mu.Unlock() - cleanup := func(network NetworkAllocation, paths machinePaths, command *exec.Cmd) { + if preserveFailureArtifacts() { + fmt.Fprintf(os.Stderr, "firecracker debug: preserving failure artifacts machine=%s pid=%d socket=%s base=%s\n", spec.ID, pidOf(command), paths.SocketPath, paths.BaseDir) + return + } cleanupStartedProcess(command) _ = r.networkProvisioner.Remove(context.Background(), network) - _ = removeIfExists(hostVSockPath(paths, spec)) if paths.BaseDir != "" { _ = os.RemoveAll(paths.BaseDir) } - r.mu.Lock() - delete(r.machines, spec.ID) - r.mu.Unlock() } network, err := r.networkAllocator.Allocate(usedNetworks) if err != nil { - cleanup(NetworkAllocation{}, machinePaths{}, nil) return nil, err } @@ -159,9 +105,14 @@ func (r *Runtime) Boot(ctx context.Context, spec MachineSpec) (*MachineState, er cleanup(network, paths, nil) return nil, err } + socketPath := paths.SocketPath + if pid := pidOf(command); pid > 0 { + socketPath = procSocketPath(pid) + } + fmt.Fprintf(os.Stderr, "firecracker debug: launched machine=%s pid=%d socket=%s jailer_base=%s\n", spec.ID, pidOf(command), socketPath, paths.JailerBaseDir) - client := newAPIClient(paths.SocketPath) - if err := waitForSocket(ctx, client, paths.SocketPath); err != nil { + client := newAPIClient(socketPath) + if err := waitForSocket(ctx, client, socketPath); err != nil { cleanup(network, paths, command) return nil, fmt.Errorf("wait for firecracker socket: %w", err) } @@ -187,162 +138,74 @@ func (r *Runtime) Boot(ctx context.Context, spec MachineSpec) (*MachineState, er Phase: PhaseRunning, PID: pid, RuntimeHost: network.GuestIP().String(), - SocketPath: paths.SocketPath, + SocketPath: socketPath, TapName: network.TapName, StartedAt: &now, } - - r.mu.Lock() - entry := r.machines[spec.ID] - entry.cmd = command - entry.exited = make(chan struct{}) - entry.network = network - entry.paths = paths - entry.state = state - r.mu.Unlock() - - go r.watchMachine(spec.ID, command, entry.exited) - - out := state - return &out, nil -} - -func (r *Runtime) Inspect(id MachineID) (*MachineState, error) { - r.mu.RLock() - entry, ok := r.machines[id] - r.mu.RUnlock() - if !ok || entry == nil { - return nil, ErrMachineNotFound - } - - state := entry.state - if state.PID > 0 && !processExists(state.PID) { - state.Phase = PhaseStopped - state.PID = 0 - } return &state, nil } -func (r *Runtime) Stop(ctx context.Context, id MachineID) error { - r.mu.RLock() - entry, ok := r.machines[id] - r.mu.RUnlock() - if !ok || entry == nil { - return ErrMachineNotFound +func (r *Runtime) Inspect(state MachineState) (*MachineState, error) { + if state.PID > 0 && !processExists(state.PID) { + state.Phase = PhaseFailed + state.PID = 0 + state.Error = "firecracker process not found" } - if entry.cmd == nil || entry.cmd.Process == nil { - return fmt.Errorf("machine %q has no firecracker process", id) - } - if entry.state.Phase == PhaseStopped { + return &state, nil +} + +func (r *Runtime) Stop(ctx context.Context, state MachineState) error { + if state.PID < 1 || !processExists(state.PID) { return nil } - r.mu.Lock() - entry.stopping = true - process := entry.cmd.Process - exited := entry.exited - r.mu.Unlock() - + process, err := os.FindProcess(state.PID) + if err != nil { + return fmt.Errorf("find process for machine %q: %w", state.ID, err) + } if err := process.Signal(syscall.SIGTERM); err != nil && !errors.Is(err, os.ErrProcessDone) { - return fmt.Errorf("stop machine %q: %w", id, err) + return fmt.Errorf("stop machine %q: %w", state.ID, err) } - select { - case <-exited: - return nil - case <-ctx.Done(): - return ctx.Err() + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for { + if !processExists(state.PID) { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } } } -func (r *Runtime) Delete(ctx context.Context, id MachineID) error { - r.mu.RLock() - entry, ok := r.machines[id] - r.mu.RUnlock() - if !ok || entry == nil { - return ErrMachineNotFound +func (r *Runtime) Delete(ctx context.Context, state MachineState) error { + if err := r.Stop(ctx, state); err != nil { + return err } - - if entry.state.Phase == PhaseRunning { - if err := r.Stop(ctx, id); err != nil && !errors.Is(err, context.Canceled) { + if strings.TrimSpace(state.RuntimeHost) != "" && strings.TrimSpace(state.TapName) != "" { + network, err := AllocationFromGuestIP(state.RuntimeHost, state.TapName) + if err != nil { + return err + } + if err := r.networkProvisioner.Remove(ctx, network); err != nil { return err } } - if err := r.networkProvisioner.Remove(ctx, entry.network); err != nil { - return err - } - if err := removeIfExists(hostVSockPath(entry.paths, entry.spec)); err != nil { - return err - } - if err := os.RemoveAll(entry.paths.BaseDir); err != nil { - return fmt.Errorf("remove machine dir %q: %w", entry.paths.BaseDir, err) - } - r.mu.Lock() - delete(r.machines, id) - r.mu.Unlock() + paths, err := buildMachinePaths(r.rootDir, state.ID, r.firecrackerBinaryPath) + if err != nil { + return err + } + if err := os.RemoveAll(paths.BaseDir); err != nil { + return fmt.Errorf("remove machine dir %q: %w", paths.BaseDir, err) + } return nil } -func (r *Runtime) watchMachine(id MachineID, command *exec.Cmd, exited chan struct{}) { - err := command.Wait() - close(exited) - - r.mu.Lock() - defer r.mu.Unlock() - - entry, ok := r.machines[id] - if !ok || entry == nil || entry.cmd != command { - return - } - - entry.state.PID = 0 - if entry.stopping { - entry.state.Phase = PhaseStopped - entry.state.Error = "" - entry.stopping = false - return - } - if err != nil { - entry.state.Phase = PhaseError - entry.state.Error = err.Error() - return - } - - entry.state.Phase = PhaseStopped - entry.state.Error = "" -} - -func (r *Runtime) resolveVSock(spec MachineSpec, used map[uint32]struct{}) (MachineSpec, error) { - if spec.Vsock != nil { - if _, exists := used[spec.Vsock.CID]; exists { - return MachineSpec{}, fmt.Errorf("vsock cid %d already in use", spec.Vsock.CID) - } - return spec, nil - } - - cid, err := nextVSockCID(used) - if err != nil { - return MachineSpec{}, err - } - - spec.Vsock = &VsockSpec{ - ID: defaultVSockID, - CID: cid, - Path: string(spec.ID) + ".sock", - } - return spec, nil -} - -func nextVSockCID(used map[uint32]struct{}) (uint32, error) { - for cid := defaultVSockCIDStart; cid != 0; cid++ { - if _, exists := used[cid]; !exists { - return cid, nil - } - } - return 0, fmt.Errorf("vsock cid space exhausted") -} - func processExists(pid int) bool { if pid < 1 { return false @@ -351,12 +214,19 @@ func processExists(pid int) bool { return err == nil || err == syscall.EPERM } -func removeIfExists(path string) error { - if path == "" { - return nil +func pidOf(command *exec.Cmd) int { + if command == nil || command.Process == nil { + return 0 + } + return command.Process.Pid +} + +func preserveFailureArtifacts() bool { + value := strings.TrimSpace(os.Getenv(debugPreserveFailureEnv)) + switch strings.ToLower(value) { + case "1", "true", "yes", "on": + return true + default: + return false } - if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) { - return fmt.Errorf("remove path %q: %w", path, err) - } - return nil } diff --git a/internal/firecracker/state.go b/internal/firecracker/state.go index 33dbf55..da37de7 100644 --- a/internal/firecracker/state.go +++ b/internal/firecracker/state.go @@ -18,14 +18,10 @@ type MachineState struct { } const ( - // PhaseProvisioning means host-local resources are still being prepared. - PhaseProvisioning Phase = "provisioning" // PhaseRunning means the Firecracker process is live. PhaseRunning Phase = "running" // PhaseStopped means the VM is no longer running. PhaseStopped Phase = "stopped" - // PhaseMissing means the machine is not known to the runtime. - PhaseMissing Phase = "missing" - // PhaseError means the runtime observed a terminal failure. - PhaseError Phase = "error" + // PhaseFailed means the runtime observed a terminal failure. + PhaseFailed Phase = "failed" ) diff --git a/internal/httpapi/handlers.go b/internal/httpapi/handlers.go index cadc5bb..baa9df9 100644 --- a/internal/httpapi/handlers.go +++ b/internal/httpapi/handlers.go @@ -2,6 +2,10 @@ package httpapi import ( "context" + "encoding/json" + "fmt" + "net/http" + "strings" contracthost "github.com/getcompanion-ai/computer-host/contract" ) @@ -14,3 +18,133 @@ type Service interface { DeleteMachine(context.Context, contracthost.MachineID) error Health(context.Context) (*contracthost.HealthResponse, error) } + +type Handler struct { + service Service +} + +func New(service Service) (*Handler, error) { + if service == nil { + return nil, fmt.Errorf("service is required") + } + return &Handler{service: service}, nil +} + +func (h *Handler) Routes() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/health", h.handleHealth) + mux.HandleFunc("/machines", h.handleMachines) + mux.HandleFunc("/machines/", h.handleMachine) + return mux +} + +func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeMethodNotAllowed(w) + return + } + response, err := h.service.Health(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, err) + return + } + writeJSON(w, http.StatusOK, response) +} + +func (h *Handler) handleMachines(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + response, err := h.service.ListMachines(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, err) + return + } + writeJSON(w, http.StatusOK, response) + case http.MethodPost: + var request contracthost.CreateMachineRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + response, err := h.service.CreateMachine(r.Context(), request) + if err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + writeJSON(w, http.StatusCreated, response) + default: + writeMethodNotAllowed(w) + } +} + +func (h *Handler) handleMachine(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/machines/") + if path == "" { + writeError(w, http.StatusNotFound, fmt.Errorf("machine id is required")) + return + } + parts := strings.Split(path, "/") + machineID := contracthost.MachineID(parts[0]) + + if len(parts) == 1 { + switch r.Method { + case http.MethodGet: + response, err := h.service.GetMachine(r.Context(), machineID) + if err != nil { + writeError(w, statusForError(err), err) + return + } + writeJSON(w, http.StatusOK, response) + case http.MethodDelete: + if err := h.service.DeleteMachine(r.Context(), machineID); err != nil { + writeError(w, statusForError(err), err) + return + } + w.WriteHeader(http.StatusNoContent) + default: + writeMethodNotAllowed(w) + } + return + } + + if len(parts) == 2 && parts[1] == "stop" { + if r.Method != http.MethodPost { + writeMethodNotAllowed(w) + return + } + if err := h.service.StopMachine(r.Context(), machineID); err != nil { + writeError(w, statusForError(err), err) + return + } + w.WriteHeader(http.StatusNoContent) + return + } + + writeError(w, http.StatusNotFound, fmt.Errorf("route not found")) +} + +func statusForError(err error) int { + message := strings.ToLower(err.Error()) + switch { + case strings.Contains(message, "not found"): + return http.StatusNotFound + case strings.Contains(message, "already exists"): + return http.StatusConflict + default: + return http.StatusBadRequest + } +} + +func writeError(w http.ResponseWriter, status int, err error) { + writeJSON(w, status, map[string]string{"error": err.Error()}) +} + +func writeJSON(w http.ResponseWriter, status int, value any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(value) +} + +func writeMethodNotAllowed(w http.ResponseWriter) { + writeError(w, http.StatusMethodNotAllowed, fmt.Errorf("method not allowed")) +} diff --git a/internal/model/types.go b/internal/model/types.go index 21bac4d..8454b3d 100644 --- a/internal/model/types.go +++ b/internal/model/types.go @@ -17,6 +17,8 @@ const ( type ArtifactRecord struct { Ref contracthost.ArtifactRef + LocalKey string + LocalDir string KernelImagePath string RootFSPath string CreatedAt time.Time @@ -32,6 +34,8 @@ type MachineRecord struct { Ports []contracthost.MachinePort Phase contracthost.MachinePhase Error string + PID int + SocketPath string CreatedAt time.Time StartedAt *time.Time } @@ -45,3 +49,17 @@ type VolumeRecord struct { Path string CreatedAt time.Time } + +type MachineOperation string + +const ( + MachineOperationCreate MachineOperation = "create" + MachineOperationStop MachineOperation = "stop" + MachineOperationDelete MachineOperation = "delete" +) + +type OperationRecord struct { + MachineID contracthost.MachineID + Type MachineOperation + StartedAt time.Time +} diff --git a/internal/service/service.go b/internal/service/service.go deleted file mode 100644 index f85737f..0000000 --- a/internal/service/service.go +++ /dev/null @@ -1,98 +0,0 @@ -package service - -import ( - "context" - "fmt" - "strings" - - appconfig "github.com/getcompanion-ai/computer-host/internal/config" - "github.com/getcompanion-ai/computer-host/internal/firecracker" -) - -// MachineRuntime is the minimum runtime surface the host-local service needs. -type MachineRuntime interface { - Boot(context.Context, firecracker.MachineSpec) (*firecracker.MachineState, error) - Inspect(firecracker.MachineID) (*firecracker.MachineState, error) - Stop(context.Context, firecracker.MachineID) error - Delete(context.Context, firecracker.MachineID) error -} - -// Service manages local machine lifecycle requests on a single host. -type Service struct { - config appconfig.Config - runtime MachineRuntime -} - -// CreateMachineRequest contains the minimum machine creation inputs above the raw runtime layer. -type CreateMachineRequest struct { - ID firecracker.MachineID -} - -const ( - defaultGuestKernelArgs = "console=ttyS0 reboot=k panic=1 pci=off" - defaultGuestMemoryMiB = int64(512) - defaultGuestVCPUs = int64(1) -) - -// New constructs a new host-local service from the app config. -func New(cfg appconfig.Config) (*Service, error) { - if err := cfg.Validate(); err != nil { - return nil, err - } - - runtime, err := firecracker.NewRuntime(cfg.FirecrackerRuntimeConfig()) - if err != nil { - return nil, err - } - - return &Service{ - config: cfg, - runtime: runtime, - }, nil -} - -// CreateMachine boots a new local machine from the single supported host default shape. -func (s *Service) CreateMachine(ctx context.Context, req CreateMachineRequest) (*firecracker.MachineState, error) { - spec, err := s.buildMachineSpec(req) - if err != nil { - return nil, err - } - return s.runtime.Boot(ctx, spec) -} - -// GetMachine returns the current local state for a machine. -func (s *Service) GetMachine(id firecracker.MachineID) (*firecracker.MachineState, error) { - return s.runtime.Inspect(id) -} - -// StopMachine stops a running local machine. -func (s *Service) StopMachine(ctx context.Context, id firecracker.MachineID) error { - return s.runtime.Stop(ctx, id) -} - -// DeleteMachine removes a local machine and its host-local resources. -func (s *Service) DeleteMachine(ctx context.Context, id firecracker.MachineID) error { - return s.runtime.Delete(ctx, id) -} - -func (s *Service) buildMachineSpec(req CreateMachineRequest) (firecracker.MachineSpec, error) { - if s == nil { - return firecracker.MachineSpec{}, fmt.Errorf("service is required") - } - if strings.TrimSpace(string(req.ID)) == "" { - return firecracker.MachineSpec{}, fmt.Errorf("machine id is required") - } - - spec := firecracker.MachineSpec{ - ID: req.ID, - VCPUs: defaultGuestVCPUs, - MemoryMiB: defaultGuestMemoryMiB, - KernelImagePath: s.config.KernelImagePath, - RootFSPath: s.config.RootFSPath, - KernelArgs: defaultGuestKernelArgs, - } - if err := spec.Validate(); err != nil { - return firecracker.MachineSpec{}, err - } - return spec, nil -} diff --git a/internal/store/file_store.go b/internal/store/file_store.go new file mode 100644 index 0000000..bac36f1 --- /dev/null +++ b/internal/store/file_store.go @@ -0,0 +1,413 @@ +package store + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + + "github.com/getcompanion-ai/computer-host/internal/model" + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +type FileStore struct { + mu sync.Mutex + statePath string + operationsPath string +} + +type persistedOperations struct { + Operations []model.OperationRecord `json:"operations"` +} + +type persistedState struct { + Artifacts []model.ArtifactRecord `json:"artifacts"` + Machines []model.MachineRecord `json:"machines"` + Volumes []model.VolumeRecord `json:"volumes"` +} + +func NewFileStore(statePath string, operationsPath string) (*FileStore, error) { + store := &FileStore{ + statePath: filepath.Clean(statePath), + operationsPath: filepath.Clean(operationsPath), + } + if err := initializeJSONFile(store.statePath, emptyPersistedState()); err != nil { + return nil, err + } + if err := initializeJSONFile(store.operationsPath, emptyPersistedOperations()); err != nil { + return nil, err + } + return store, nil +} + +func (s *FileStore) PutArtifact(_ context.Context, record model.ArtifactRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateState(func(state *persistedState) error { + for i := range state.Artifacts { + if state.Artifacts[i].Ref == record.Ref { + state.Artifacts[i] = record + return nil + } + } + state.Artifacts = append(state.Artifacts, record) + return nil + }) +} + +func (s *FileStore) GetArtifact(_ context.Context, ref contracthost.ArtifactRef) (*model.ArtifactRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + state, err := s.readState() + if err != nil { + return nil, err + } + for i := range state.Artifacts { + if state.Artifacts[i].Ref == ref { + record := state.Artifacts[i] + return &record, nil + } + } + return nil, ErrNotFound +} + +func (s *FileStore) ListArtifacts(_ context.Context) ([]model.ArtifactRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + state, err := s.readState() + if err != nil { + return nil, err + } + return append([]model.ArtifactRecord(nil), state.Artifacts...), nil +} + +func (s *FileStore) CreateMachine(_ context.Context, record model.MachineRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateState(func(state *persistedState) error { + for _, machine := range state.Machines { + if machine.ID == record.ID { + return fmt.Errorf("store: machine %q already exists", record.ID) + } + } + state.Machines = append(state.Machines, record) + return nil + }) +} + +func (s *FileStore) GetMachine(_ context.Context, id contracthost.MachineID) (*model.MachineRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + state, err := s.readState() + if err != nil { + return nil, err + } + for i := range state.Machines { + if state.Machines[i].ID == id { + record := state.Machines[i] + return &record, nil + } + } + return nil, ErrNotFound +} + +func (s *FileStore) ListMachines(_ context.Context) ([]model.MachineRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + state, err := s.readState() + if err != nil { + return nil, err + } + return append([]model.MachineRecord(nil), state.Machines...), nil +} + +func (s *FileStore) UpdateMachine(_ context.Context, record model.MachineRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateState(func(state *persistedState) error { + for i := range state.Machines { + if state.Machines[i].ID == record.ID { + state.Machines[i] = record + return nil + } + } + return ErrNotFound + }) +} + +func (s *FileStore) DeleteMachine(_ context.Context, id contracthost.MachineID) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateState(func(state *persistedState) error { + for i := range state.Machines { + if state.Machines[i].ID == id { + state.Machines = append(state.Machines[:i], state.Machines[i+1:]...) + return nil + } + } + return ErrNotFound + }) +} + +func (s *FileStore) CreateVolume(_ context.Context, record model.VolumeRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateState(func(state *persistedState) error { + for _, volume := range state.Volumes { + if volume.ID == record.ID { + return fmt.Errorf("store: volume %q already exists", record.ID) + } + } + state.Volumes = append(state.Volumes, record) + return nil + }) +} + +func (s *FileStore) GetVolume(_ context.Context, id contracthost.VolumeID) (*model.VolumeRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + state, err := s.readState() + if err != nil { + return nil, err + } + for i := range state.Volumes { + if state.Volumes[i].ID == id { + record := state.Volumes[i] + return &record, nil + } + } + return nil, ErrNotFound +} + +func (s *FileStore) ListVolumes(_ context.Context) ([]model.VolumeRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + state, err := s.readState() + if err != nil { + return nil, err + } + return append([]model.VolumeRecord(nil), state.Volumes...), nil +} + +func (s *FileStore) UpdateVolume(_ context.Context, record model.VolumeRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateState(func(state *persistedState) error { + for i := range state.Volumes { + if state.Volumes[i].ID == record.ID { + state.Volumes[i] = record + return nil + } + } + return ErrNotFound + }) +} + +func (s *FileStore) DeleteVolume(_ context.Context, id contracthost.VolumeID) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateState(func(state *persistedState) error { + for i := range state.Volumes { + if state.Volumes[i].ID == id { + state.Volumes = append(state.Volumes[:i], state.Volumes[i+1:]...) + return nil + } + } + return ErrNotFound + }) +} + +func (s *FileStore) UpsertOperation(_ context.Context, record model.OperationRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateOperations(func(operations *persistedOperations) error { + for i := range operations.Operations { + if operations.Operations[i].MachineID == record.MachineID { + operations.Operations[i] = record + return nil + } + } + operations.Operations = append(operations.Operations, record) + return nil + }) +} + +func (s *FileStore) ListOperations(_ context.Context) ([]model.OperationRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + operations, err := s.readOperations() + if err != nil { + return nil, err + } + return append([]model.OperationRecord(nil), operations.Operations...), nil +} + +func (s *FileStore) DeleteOperation(_ context.Context, machineID contracthost.MachineID) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.updateOperations(func(operations *persistedOperations) error { + for i := range operations.Operations { + if operations.Operations[i].MachineID == machineID { + operations.Operations = append(operations.Operations[:i], operations.Operations[i+1:]...) + return nil + } + } + return nil + }) +} + +func (s *FileStore) readOperations() (*persistedOperations, error) { + var operations persistedOperations + if err := readJSONFile(s.operationsPath, &operations); err != nil { + return nil, err + } + normalizeOperations(&operations) + return &operations, nil +} + +func (s *FileStore) readState() (*persistedState, error) { + var state persistedState + if err := readJSONFile(s.statePath, &state); err != nil { + return nil, err + } + normalizeState(&state) + return &state, nil +} + +func (s *FileStore) updateOperations(update func(*persistedOperations) error) error { + operations, err := s.readOperations() + if err != nil { + return err + } + if err := update(operations); err != nil { + return err + } + return writeJSONFileAtomically(s.operationsPath, operations) +} + +func (s *FileStore) updateState(update func(*persistedState) error) error { + state, err := s.readState() + if err != nil { + return err + } + if err := update(state); err != nil { + return err + } + return writeJSONFileAtomically(s.statePath, state) +} + +func initializeJSONFile(path string, value any) error { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("create store dir for %q: %w", path, err) + } + if _, err := os.Stat(path); err == nil { + return nil + } else if !os.IsNotExist(err) { + return fmt.Errorf("stat store file %q: %w", path, err) + } + return writeJSONFileAtomically(path, value) +} + +func readJSONFile(path string, value any) error { + data, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("read store file %q: %w", path, err) + } + if err := json.Unmarshal(data, value); err != nil { + return fmt.Errorf("decode store file %q: %w", path, err) + } + return nil +} + +func writeJSONFileAtomically(path string, value any) error { + payload, err := json.MarshalIndent(value, "", " ") + if err != nil { + return fmt.Errorf("marshal store file %q: %w", path, err) + } + payload = append(payload, '\n') + + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("create store dir for %q: %w", path, err) + } + + tmpPath := path + ".tmp" + file, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + if err != nil { + return fmt.Errorf("open temp store file %q: %w", tmpPath, err) + } + if _, err := file.Write(payload); err != nil { + file.Close() + return fmt.Errorf("write temp store file %q: %w", tmpPath, err) + } + if err := file.Sync(); err != nil { + file.Close() + return fmt.Errorf("sync temp store file %q: %w", tmpPath, err) + } + if err := file.Close(); err != nil { + return fmt.Errorf("close temp store file %q: %w", tmpPath, err) + } + if err := os.Rename(tmpPath, path); err != nil { + return fmt.Errorf("rename temp store file %q to %q: %w", tmpPath, path, err) + } + + dir, err := os.Open(filepath.Dir(path)) + if err != nil { + return fmt.Errorf("open store dir for %q: %w", path, err) + } + if err := dir.Sync(); err != nil { + dir.Close() + return fmt.Errorf("sync store dir for %q: %w", path, err) + } + if err := dir.Close(); err != nil { + return fmt.Errorf("close store dir for %q: %w", path, err) + } + return nil +} + +func emptyPersistedState() persistedState { + return persistedState{ + Artifacts: []model.ArtifactRecord{}, + Machines: []model.MachineRecord{}, + Volumes: []model.VolumeRecord{}, + } +} + +func emptyPersistedOperations() persistedOperations { + return persistedOperations{Operations: []model.OperationRecord{}} +} + +func normalizeState(state *persistedState) { + if state.Artifacts == nil { + state.Artifacts = []model.ArtifactRecord{} + } + if state.Machines == nil { + state.Machines = []model.MachineRecord{} + } + if state.Volumes == nil { + state.Volumes = []model.VolumeRecord{} + } +} + +func normalizeOperations(operations *persistedOperations) { + if operations.Operations == nil { + operations.Operations = []model.OperationRecord{} + } +} diff --git a/internal/store/file_store_test.go b/internal/store/file_store_test.go new file mode 100644 index 0000000..573b551 --- /dev/null +++ b/internal/store/file_store_test.go @@ -0,0 +1,130 @@ +package store + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/getcompanion-ai/computer-host/internal/model" + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +func TestFileStorePersistsStateAndOperations(t *testing.T) { + t.Parallel() + + root := t.TempDir() + statePath := filepath.Join(root, "state", "state.json") + opsPath := filepath.Join(root, "state", "ops.json") + + ctx := context.Background() + first, err := NewFileStore(statePath, opsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + artifact := model.ArtifactRecord{ + Ref: contracthost.ArtifactRef{ + KernelImageURL: "https://example.com/kernel", + RootFSURL: "https://example.com/rootfs", + }, + LocalKey: "artifact-key", + LocalDir: filepath.Join(root, "artifacts", "artifact-key"), + KernelImagePath: filepath.Join(root, "artifacts", "artifact-key", "kernel"), + RootFSPath: filepath.Join(root, "artifacts", "artifact-key", "rootfs"), + CreatedAt: time.Unix(1700000000, 0).UTC(), + } + if err := first.PutArtifact(ctx, artifact); err != nil { + t.Fatalf("put artifact: %v", err) + } + + machine := model.MachineRecord{ + ID: "vm-1", + Artifact: artifact.Ref, + SystemVolumeID: "vm-1-system", + RuntimeHost: "172.16.0.2", + TapDevice: "fctap0", + Ports: []contracthost.MachinePort{ + {Name: contracthost.MachinePortNameSSH, Port: 22, Protocol: contracthost.PortProtocolTCP}, + }, + Phase: contracthost.MachinePhaseRunning, + PID: 1234, + SocketPath: filepath.Join(root, "runtime", "machine.sock"), + CreatedAt: time.Unix(1700000001, 0).UTC(), + StartedAt: timePtr(time.Unix(1700000002, 0).UTC()), + } + if err := first.CreateMachine(ctx, machine); err != nil { + t.Fatalf("create machine: %v", err) + } + + volume := model.VolumeRecord{ + ID: "vm-1-system", + Kind: contracthost.VolumeKindSystem, + AttachedMachineID: machineIDPtr("vm-1"), + Path: filepath.Join(root, "machine-disks", "vm-1", "system.img"), + CreatedAt: time.Unix(1700000003, 0).UTC(), + } + if err := first.CreateVolume(ctx, volume); err != nil { + t.Fatalf("create volume: %v", err) + } + + operation := model.OperationRecord{ + MachineID: "vm-1", + Type: model.MachineOperationCreate, + StartedAt: time.Unix(1700000004, 0).UTC(), + } + if err := first.UpsertOperation(ctx, operation); err != nil { + t.Fatalf("upsert operation: %v", err) + } + + second, err := NewFileStore(statePath, opsPath) + if err != nil { + t.Fatalf("reopen file store: %v", err) + } + + gotArtifact, err := second.GetArtifact(ctx, artifact.Ref) + if err != nil { + t.Fatalf("get artifact after reopen: %v", err) + } + if gotArtifact.LocalKey != artifact.LocalKey { + t.Fatalf("artifact local key mismatch: got %q want %q", gotArtifact.LocalKey, artifact.LocalKey) + } + + gotMachine, err := second.GetMachine(ctx, machine.ID) + if err != nil { + t.Fatalf("get machine after reopen: %v", err) + } + if gotMachine.Phase != contracthost.MachinePhaseRunning { + t.Fatalf("machine phase mismatch: got %q", gotMachine.Phase) + } + if gotMachine.RuntimeHost != machine.RuntimeHost { + t.Fatalf("runtime host mismatch: got %q want %q", gotMachine.RuntimeHost, machine.RuntimeHost) + } + + gotVolume, err := second.GetVolume(ctx, volume.ID) + if err != nil { + t.Fatalf("get volume after reopen: %v", err) + } + if gotVolume.AttachedMachineID == nil || *gotVolume.AttachedMachineID != "vm-1" { + t.Fatalf("attached machine mismatch: got %#v", gotVolume.AttachedMachineID) + } + + operations, err := second.ListOperations(ctx) + if err != nil { + t.Fatalf("list operations after reopen: %v", err) + } + if len(operations) != 1 { + t.Fatalf("operation count mismatch: got %d want 1", len(operations)) + } + if operations[0].Type != model.MachineOperationCreate { + t.Fatalf("operation type mismatch: got %q", operations[0].Type) + } +} + +func timePtr(value time.Time) *time.Time { + return &value +} + +func machineIDPtr(value contracthost.MachineID) *contracthost.MachineID { + return &value +} diff --git a/internal/store/store.go b/internal/store/store.go index 0604683..542e6dd 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -2,11 +2,14 @@ package store import ( "context" + "errors" "github.com/getcompanion-ai/computer-host/internal/model" contracthost "github.com/getcompanion-ai/computer-host/contract" ) +var ErrNotFound = errors.New("store: not found") + type Store interface { PutArtifact(context.Context, model.ArtifactRecord) error GetArtifact(context.Context, contracthost.ArtifactRef) (*model.ArtifactRecord, error) @@ -21,4 +24,7 @@ type Store interface { ListVolumes(context.Context) ([]model.VolumeRecord, error) UpdateVolume(context.Context, model.VolumeRecord) error DeleteVolume(context.Context, contracthost.VolumeID) error + UpsertOperation(context.Context, model.OperationRecord) error + ListOperations(context.Context) ([]model.OperationRecord, error) + DeleteOperation(context.Context, contracthost.MachineID) error } diff --git a/main.go b/main.go index 99f76ee..892b9fa 100644 --- a/main.go +++ b/main.go @@ -2,76 +2,75 @@ package main import ( "context" - "encoding/json" - "flag" "fmt" - "io" + "net" + "net/http" "os" "os/signal" - "strings" + "path/filepath" "syscall" appconfig "github.com/getcompanion-ai/computer-host/internal/config" + "github.com/getcompanion-ai/computer-host/internal/daemon" "github.com/getcompanion-ai/computer-host/internal/firecracker" - "github.com/getcompanion-ai/computer-host/internal/service" + "github.com/getcompanion-ai/computer-host/internal/httpapi" + "github.com/getcompanion-ai/computer-host/internal/store" ) -type options struct { - MachineID firecracker.MachineID -} - func main() { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - opts, err := parseOptions(os.Args[1:], os.Stderr) - if err != nil { - exit(err) - } - cfg, err := appconfig.Load() if err != nil { exit(err) } - svc, err := service.New(cfg) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) if err != nil { exit(err) } - state, err := svc.CreateMachine(ctx, service.CreateMachineRequest{ID: opts.MachineID}) + runtime, err := firecracker.NewRuntime(cfg.FirecrackerRuntimeConfig()) if err != nil { exit(err) } - if err := writeJSON(os.Stdout, state); err != nil { + hostDaemon, err := daemon.New(cfg, fileStore, runtime) + if err != nil { exit(err) } -} - -func parseOptions(args []string, stderr io.Writer) (options, error) { - fs := flag.NewFlagSet("firecracker-host", flag.ContinueOnError) - fs.SetOutput(stderr) - - var machineID string - fs.StringVar(&machineID, "machine-id", "", "machine id to boot") - - if err := fs.Parse(args); err != nil { - return options{}, err + if err := hostDaemon.Reconcile(ctx); err != nil { + exit(err) } - machineID = strings.TrimSpace(machineID) - if machineID == "" { - return options{}, fmt.Errorf("-machine-id is required") + handler, err := httpapi.New(hostDaemon) + if err != nil { + exit(err) } - return options{MachineID: firecracker.MachineID(machineID)}, nil -} + if err := os.MkdirAll(filepath.Dir(cfg.SocketPath), 0o755); err != nil { + exit(err) + } + if err := os.Remove(cfg.SocketPath); err != nil && !os.IsNotExist(err) { + exit(err) + } -func writeJSON(w io.Writer, value any) error { - encoder := json.NewEncoder(w) - encoder.SetIndent("", " ") - return encoder.Encode(value) + listener, err := net.Listen("unix", cfg.SocketPath) + if err != nil { + exit(err) + } + defer listener.Close() + + server := &http.Server{Handler: handler.Routes()} + go func() { + <-ctx.Done() + _ = server.Shutdown(context.Background()) + }() + + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + exit(err) + } } func exit(err error) {