From a12f54ba5db484cf481e0827e34909b8c5cd9548 Mon Sep 17 00:00:00 2001 From: Harivansh Rathi Date: Wed, 8 Apr 2026 19:43:20 +0000 Subject: [PATCH] feat: add guest config injection and host nat wiring --- contract/machines.go | 17 +++- internal/config/config.go | 6 ++ internal/daemon/create.go | 15 ++++ internal/daemon/daemon_test.go | 1 + internal/daemon/files.go | 80 +++++++++++++++++++ internal/daemon/guest_config_test.go | 97 ++++++++++++++++++++++ internal/firecracker/network.go | 115 ++++++++++++++++++++++++++- internal/firecracker/runtime.go | 7 +- 8 files changed, 332 insertions(+), 6 deletions(-) create mode 100644 internal/daemon/guest_config_test.go diff --git a/contract/machines.go b/contract/machines.go index 7632071..09d1650 100644 --- a/contract/machines.go +++ b/contract/machines.go @@ -15,10 +15,21 @@ type Machine struct { StartedAt *time.Time `json:"started_at,omitempty"` } +type GuestConfig struct { + AuthorizedKeys []string `json:"authorized_keys,omitempty"` + LoginWebhook *GuestLoginWebhook `json:"login_webhook,omitempty"` +} + +type GuestLoginWebhook struct { + URL string `json:"url"` + BearerToken string `json:"bearer_token,omitempty"` +} + type CreateMachineRequest struct { - MachineID MachineID `json:"machine_id"` - Artifact ArtifactRef `json:"artifact"` - UserVolumeIDs []VolumeID `json:"user_volume_ids,omitempty"` + MachineID MachineID `json:"machine_id"` + Artifact ArtifactRef `json:"artifact"` + UserVolumeIDs []VolumeID `json:"user_volume_ids,omitempty"` + GuestConfig *GuestConfig `json:"guest_config,omitempty"` } type CreateMachineResponse struct { diff --git a/internal/config/config.go b/internal/config/config.go index c4b285e..e353293 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,6 +20,7 @@ type Config struct { MachineDisksDir string RuntimeDir string SocketPath string + EgressInterface string FirecrackerBinaryPath string JailerBinaryPath string } @@ -35,6 +36,7 @@ func Load() (Config, error) { MachineDisksDir: filepath.Join(rootDir, "machine-disks"), RuntimeDir: filepath.Join(rootDir, "runtime"), SocketPath: filepath.Join(rootDir, defaultSocketName), + EgressInterface: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_EGRESS_INTERFACE")), FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")), JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")), } @@ -73,6 +75,9 @@ func (c Config) Validate() error { if strings.TrimSpace(c.SocketPath) == "" { return fmt.Errorf("socket path is required") } + if strings.TrimSpace(c.EgressInterface) == "" { + return fmt.Errorf("FIRECRACKER_HOST_EGRESS_INTERFACE is required") + } return nil } @@ -80,6 +85,7 @@ func (c Config) Validate() error { func (c Config) FirecrackerRuntimeConfig() firecracker.RuntimeConfig { return firecracker.RuntimeConfig{ RootDir: c.RuntimeDir, + EgressInterface: c.EgressInterface, FirecrackerBinaryPath: c.FirecrackerBinaryPath, JailerBinaryPath: c.JailerBinaryPath, } diff --git a/internal/daemon/create.go b/internal/daemon/create.go index 667fa99..8907495 100644 --- a/internal/daemon/create.go +++ b/internal/daemon/create.go @@ -20,6 +20,9 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi if err := validateArtifactRef(req.Artifact); err != nil { return nil, err } + if err := validateGuestConfig(req.GuestConfig); err != nil { + return nil, err + } unlock := d.lockMachine(req.MachineID) defer unlock() @@ -62,6 +65,17 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi if err := cloneFile(artifact.RootFSPath, systemVolumePath); err != nil { return nil, err } + removeSystemVolumeOnFailure := true + defer func() { + if !removeSystemVolumeOnFailure { + return + } + _ = os.Remove(systemVolumePath) + _ = os.RemoveAll(filepath.Dir(systemVolumePath)) + }() + if err := injectGuestConfig(ctx, systemVolumePath, req.GuestConfig); err != nil { + return nil, err + } spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath) if err != nil { @@ -140,6 +154,7 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi return nil, err } + removeSystemVolumeOnFailure = false clearOperation = true return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil } diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 880f874..d28abe0 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -215,6 +215,7 @@ func testConfig(root string) appconfig.Config { MachineDisksDir: filepath.Join(root, "machine-disks"), RuntimeDir: filepath.Join(root, "runtime"), SocketPath: filepath.Join(root, "firecracker-host.sock"), + EgressInterface: "eth0", FirecrackerBinaryPath: "/usr/bin/firecracker", JailerBinaryPath: "/usr/bin/jailer", } diff --git a/internal/daemon/files.go b/internal/daemon/files.go index fc535c3..a5c5f35 100644 --- a/internal/daemon/files.go +++ b/internal/daemon/files.go @@ -4,11 +4,13 @@ import ( "context" "crypto/sha256" "encoding/hex" + "encoding/json" "fmt" "io" "net/http" "net/url" "os" + "os/exec" "path/filepath" "strings" @@ -179,6 +181,67 @@ func defaultMachinePorts() []contracthost.MachinePort { } } +func hasGuestConfig(config *contracthost.GuestConfig) bool { + if config == nil { + return false + } + return len(config.AuthorizedKeys) > 0 || config.LoginWebhook != nil +} + +func injectGuestConfig(ctx context.Context, imagePath string, config *contracthost.GuestConfig) error { + if !hasGuestConfig(config) { + return nil + } + stagingDir, err := os.MkdirTemp(filepath.Dir(imagePath), "guest-config-*") + if err != nil { + return fmt.Errorf("create guest config staging dir: %w", err) + } + defer os.RemoveAll(stagingDir) + + if len(config.AuthorizedKeys) > 0 { + authorizedKeysPath := filepath.Join(stagingDir, "authorized_keys") + payload := []byte(strings.Join(config.AuthorizedKeys, "\n") + "\n") + if err := os.WriteFile(authorizedKeysPath, payload, 0o600); err != nil { + return fmt.Errorf("write authorized_keys staging file: %w", err) + } + if err := replaceExt4File(ctx, imagePath, authorizedKeysPath, "/etc/microagent/authorized_keys"); err != nil { + return err + } + } + + if config.LoginWebhook != nil { + guestConfigPath := filepath.Join(stagingDir, "guest-config.json") + payload, err := json.Marshal(config) + if err != nil { + return fmt.Errorf("marshal guest config: %w", err) + } + if err := os.WriteFile(guestConfigPath, append(payload, '\n'), 0o600); err != nil { + return fmt.Errorf("write guest config staging file: %w", err) + } + if err := replaceExt4File(ctx, imagePath, guestConfigPath, "/etc/microagent/guest-config.json"); err != nil { + return err + } + } + return nil +} + +func replaceExt4File(ctx context.Context, imagePath string, sourcePath string, targetPath string) error { + _ = runDebugFS(ctx, imagePath, fmt.Sprintf("rm %s", targetPath)) + if err := runDebugFS(ctx, imagePath, fmt.Sprintf("write %s %s", sourcePath, targetPath)); err != nil { + return fmt.Errorf("inject %q into %q: %w", targetPath, imagePath, err) + } + return nil +} + +func runDebugFS(ctx context.Context, imagePath string, command string) error { + cmd := exec.CommandContext(ctx, "debugfs", "-w", "-R", command, imagePath) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("debugfs %q on %q: %w: %s", command, imagePath, err, strings.TrimSpace(string(output))) + } + return nil +} + func machineIDPtr(machineID contracthost.MachineID) *contracthost.MachineID { value := machineID return &value @@ -229,6 +292,23 @@ func validateArtifactRef(ref contracthost.ArtifactRef) error { return nil } +func validateGuestConfig(config *contracthost.GuestConfig) error { + if config == nil { + return nil + } + for i, key := range config.AuthorizedKeys { + if strings.TrimSpace(key) == "" { + return fmt.Errorf("guest_config.authorized_keys[%d] is required", i) + } + } + if config.LoginWebhook != nil { + if err := validateDownloadURL("guest_config.login_webhook.url", config.LoginWebhook.URL); err != nil { + return err + } + } + return nil +} + func validateMachineID(machineID contracthost.MachineID) error { value := strings.TrimSpace(string(machineID)) if value == "" { diff --git a/internal/daemon/guest_config_test.go b/internal/daemon/guest_config_test.go new file mode 100644 index 0000000..1af1fdc --- /dev/null +++ b/internal/daemon/guest_config_test.go @@ -0,0 +1,97 @@ +package daemon + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + contracthost "github.com/getcompanion-ai/computer-host/contract" +) + +func TestInjectGuestConfigWritesAuthorizedKeysAndWebhook(t *testing.T) { + root := t.TempDir() + imagePath := filepath.Join(root, "rootfs.ext4") + if err := buildTestExt4Image(root, imagePath); err != nil { + t.Fatalf("build ext4 image: %v", err) + } + + config := &contracthost.GuestConfig{ + AuthorizedKeys: []string{ + "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGuestKeyOne test-1", + "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGuestKeyTwo test-2", + }, + LoginWebhook: &contracthost.GuestLoginWebhook{ + URL: "https://example.com/login", + BearerToken: "secret-token", + }, + } + + if err := injectGuestConfig(context.Background(), imagePath, config); err != nil { + t.Fatalf("inject guest config: %v", err) + } + + authorizedKeys, err := readExt4File(imagePath, "/etc/microagent/authorized_keys") + if err != nil { + t.Fatalf("read authorized_keys: %v", err) + } + wantKeys := strings.Join(config.AuthorizedKeys, "\n") + "\n" + if authorizedKeys != wantKeys { + t.Fatalf("authorized_keys mismatch: got %q want %q", authorizedKeys, wantKeys) + } + + guestConfigPayload, err := readExt4File(imagePath, "/etc/microagent/guest-config.json") + if err != nil { + t.Fatalf("read guest-config.json: %v", err) + } + + var guestConfig contracthost.GuestConfig + if err := json.Unmarshal([]byte(guestConfigPayload), &guestConfig); err != nil { + t.Fatalf("unmarshal guest-config.json: %v", err) + } + if guestConfig.LoginWebhook == nil || guestConfig.LoginWebhook.URL != config.LoginWebhook.URL || guestConfig.LoginWebhook.BearerToken != config.LoginWebhook.BearerToken { + t.Fatalf("login webhook mismatch: got %#v want %#v", guestConfig.LoginWebhook, config.LoginWebhook) + } +} + +func buildTestExt4Image(root string, imagePath string) error { + sourceDir := filepath.Join(root, "source") + if err := os.MkdirAll(filepath.Join(sourceDir, "etc", "microagent"), 0o755); err != nil { + return err + } + if err := os.WriteFile(imagePath, nil, 0o644); err != nil { + return err + } + command := exec.Command("truncate", "-s", "16M", imagePath) + output, err := command.CombinedOutput() + if err != nil { + return fmt.Errorf("truncate: %w: %s", err, strings.TrimSpace(string(output))) + } + command = exec.Command("mkfs.ext4", "-q", "-d", sourceDir, "-L", "microagent-root", "-F", imagePath) + output, err = command.CombinedOutput() + if err != nil { + return fmt.Errorf("mkfs.ext4: %w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func readExt4File(imagePath string, targetPath string) (string, error) { + command := exec.Command("debugfs", "-R", "cat "+targetPath, imagePath) + output, err := command.CombinedOutput() + if err != nil { + return "", fmt.Errorf("debugfs cat %q: %w: %s", targetPath, err, strings.TrimSpace(string(output))) + } + lines := strings.Split(string(output), "\n") + filtered := make([]string, 0, len(lines)) + for _, line := range lines { + if strings.HasPrefix(line, "debugfs ") { + continue + } + filtered = append(filtered, line) + } + return strings.TrimPrefix(strings.Join(filtered, "\n"), "\n"), nil +} diff --git a/internal/firecracker/network.go b/internal/firecracker/network.go index 7849eee..68461c0 100644 --- a/internal/firecracker/network.go +++ b/internal/firecracker/network.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/netip" + "os" "os/exec" "strings" ) @@ -15,6 +16,10 @@ const ( defaultNetworkPrefixBits = 30 defaultInterfaceID = "net0" defaultTapPrefix = "fctap" + defaultNFTTableName = "microagentcomputer" + defaultNFTPostrouting = "postrouting" + defaultNFTForward = "forward" + defaultIPForwardPath = "/proc/sys/net/ipv4/ip_forward" ) // NetworkAllocation describes the concrete host-local network values assigned to a machine @@ -40,7 +45,12 @@ type NetworkProvisioner interface { // IPTapProvisioner provisions tap devices through the `ip` CLI. type IPTapProvisioner struct { - runCommand func(context.Context, string, ...string) error + guestCIDR string + egressInterface string + runCommand func(context.Context, string, ...string) error + readCommandOutput func(context.Context, string, ...string) (string, error) + readFile func(string) ([]byte, error) + writeFile func(string, []byte, os.FileMode) error } // GuestIP returns the guest IP address. @@ -143,8 +153,10 @@ func (a *NetworkAllocator) networkForIndex(index int) (NetworkAllocation, error) } // NewIPTapProvisioner returns a provisioner backed by `ip`. -func NewIPTapProvisioner() *IPTapProvisioner { +func NewIPTapProvisioner(guestCIDR string, egressInterface string) *IPTapProvisioner { return &IPTapProvisioner{ + guestCIDR: strings.TrimSpace(guestCIDR), + egressInterface: strings.TrimSpace(egressInterface), runCommand: func(ctx context.Context, name string, args ...string) error { cmd := exec.CommandContext(ctx, name, args...) output, err := cmd.CombinedOutput() @@ -153,6 +165,16 @@ func NewIPTapProvisioner() *IPTapProvisioner { } return nil }, + readCommandOutput: func(ctx context.Context, name string, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, name, args...) + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("%s %s: %w: %s", name, strings.Join(args, " "), err, strings.TrimSpace(string(output))) + } + return string(output), nil + }, + readFile: os.ReadFile, + writeFile: os.WriteFile, } } @@ -164,6 +186,9 @@ func (p *IPTapProvisioner) Ensure(ctx context.Context, network NetworkAllocation if strings.TrimSpace(network.TapName) == "" { return fmt.Errorf("tap name is required") } + if err := p.ensureHostNetworking(ctx); err != nil { + return err + } err := p.runCommand(ctx, "ip", "tuntap", "add", "dev", network.TapName, "mode", "tap") if err != nil { @@ -188,6 +213,92 @@ func (p *IPTapProvisioner) Ensure(ctx context.Context, network NetworkAllocation return nil } +func (p *IPTapProvisioner) ensureHostNetworking(ctx context.Context) error { + if strings.TrimSpace(p.egressInterface) == "" { + return fmt.Errorf("egress interface is required") + } + if strings.TrimSpace(p.guestCIDR) == "" { + return fmt.Errorf("guest cidr is required") + } + if err := p.runCommand(ctx, "ip", "link", "show", "dev", p.egressInterface); err != nil { + return fmt.Errorf("validate egress interface %q: %w", p.egressInterface, err) + } + if err := p.ensureIPForwarding(); err != nil { + return err + } + if err := p.ensureNFTTable(ctx); err != nil { + return err + } + if err := p.ensureNFTChain(ctx, defaultNFTPostrouting, "{ type nat hook postrouting priority srcnat; policy accept; }"); err != nil { + return err + } + if err := p.ensureNFTChain(ctx, defaultNFTForward, "{ type filter hook forward priority filter; policy accept; }"); err != nil { + return err + } + if err := p.ensureNFTRule(ctx, defaultNFTPostrouting, "microagentcomputer-postrouting", "ip saddr %s oifname %q counter masquerade comment %q", p.guestCIDR, p.egressInterface, "microagentcomputer-postrouting"); err != nil { + return err + } + if err := p.ensureNFTRule(ctx, defaultNFTForward, "microagentcomputer-forward-out", "iifname %q oifname %q counter accept comment %q", defaultTapPrefix+"*", p.egressInterface, "microagentcomputer-forward-out"); err != nil { + return err + } + if err := p.ensureNFTRule(ctx, defaultNFTForward, "microagentcomputer-forward-in", "iifname %q oifname %q ct state related,established counter accept comment %q", p.egressInterface, defaultTapPrefix+"*", "microagentcomputer-forward-in"); err != nil { + return err + } + return nil +} + +func (p *IPTapProvisioner) ensureIPForwarding() error { + if p.readFile == nil || p.writeFile == nil { + return fmt.Errorf("ip forwarding helpers are required") + } + payload, err := p.readFile(defaultIPForwardPath) + if err != nil { + return fmt.Errorf("read ip forwarding state: %w", err) + } + if strings.TrimSpace(string(payload)) == "1" { + return nil + } + if err := p.writeFile(defaultIPForwardPath, []byte("1\n"), 0o644); err != nil { + return fmt.Errorf("enable ip forwarding: %w", err) + } + return nil +} + +func (p *IPTapProvisioner) ensureNFTTable(ctx context.Context) error { + if _, err := p.readCommandOutput(ctx, "nft", "list", "table", "ip", defaultNFTTableName); err == nil { + return nil + } + if err := p.runCommand(ctx, "nft", "add", "table", "ip", defaultNFTTableName); err != nil { + return fmt.Errorf("ensure nft table %q: %w", defaultNFTTableName, err) + } + return nil +} + +func (p *IPTapProvisioner) ensureNFTChain(ctx context.Context, chain string, definition string) error { + if _, err := p.readCommandOutput(ctx, "nft", "list", "chain", "ip", defaultNFTTableName, chain); err == nil { + return nil + } + if err := p.runCommand(ctx, "nft", "add", "chain", "ip", defaultNFTTableName, chain, definition); err != nil { + return fmt.Errorf("ensure nft chain %q: %w", chain, err) + } + return nil +} + +func (p *IPTapProvisioner) ensureNFTRule(ctx context.Context, chain string, comment string, format string, args ...any) error { + output, err := p.readCommandOutput(ctx, "nft", "list", "chain", "ip", defaultNFTTableName, chain) + if err != nil { + return fmt.Errorf("list nft chain %q: %w", chain, err) + } + if strings.Contains(output, fmt.Sprintf("comment \"%s\"", comment)) { + return nil + } + rule := fmt.Sprintf(format, args...) + if err := p.runCommand(ctx, "nft", "add", "rule", "ip", defaultNFTTableName, chain, rule); err != nil { + return fmt.Errorf("ensure nft rule %q: %w", comment, err) + } + return nil +} + // Remove deletes the tap device if it exists. func (p *IPTapProvisioner) Remove(ctx context.Context, network NetworkAllocation) error { if p == nil || p.runCommand == nil { diff --git a/internal/firecracker/runtime.go b/internal/firecracker/runtime.go index 3aff5b8..9c78c82 100644 --- a/internal/firecracker/runtime.go +++ b/internal/firecracker/runtime.go @@ -21,6 +21,7 @@ var ( type RuntimeConfig struct { RootDir string + EgressInterface string FirecrackerBinaryPath string JailerBinaryPath string } @@ -50,6 +51,10 @@ func NewRuntime(cfg RuntimeConfig) (*Runtime, error) { if jailerBinaryPath == "" { return nil, fmt.Errorf("jailer binary path is required") } + egressInterface := strings.TrimSpace(cfg.EgressInterface) + if egressInterface == "" { + return nil, fmt.Errorf("egress interface is required") + } if err := os.MkdirAll(rootDir, 0o755); err != nil { return nil, fmt.Errorf("create runtime root dir %q: %w", rootDir, err) @@ -65,7 +70,7 @@ func NewRuntime(cfg RuntimeConfig) (*Runtime, error) { firecrackerBinaryPath: firecrackerBinaryPath, jailerBinaryPath: jailerBinaryPath, networkAllocator: allocator, - networkProvisioner: NewIPTapProvisioner(), + networkProvisioner: NewIPTapProvisioner(defaultNetworkCIDR, egressInterface), }, nil }