feat: add guest config injection and host nat wiring

This commit is contained in:
Harivansh Rathi 2026-04-08 19:43:20 +00:00
parent 28ca0219d9
commit a12f54ba5d
8 changed files with 332 additions and 6 deletions

View file

@ -15,10 +15,21 @@ type Machine struct {
StartedAt *time.Time `json:"started_at,omitempty"` StartedAt *time.Time `json:"started_at,omitempty"`
} }
type GuestConfig struct {
AuthorizedKeys []string `json:"authorized_keys,omitempty"`
LoginWebhook *GuestLoginWebhook `json:"login_webhook,omitempty"`
}
type GuestLoginWebhook struct {
URL string `json:"url"`
BearerToken string `json:"bearer_token,omitempty"`
}
type CreateMachineRequest struct { type CreateMachineRequest struct {
MachineID MachineID `json:"machine_id"` MachineID MachineID `json:"machine_id"`
Artifact ArtifactRef `json:"artifact"` Artifact ArtifactRef `json:"artifact"`
UserVolumeIDs []VolumeID `json:"user_volume_ids,omitempty"` UserVolumeIDs []VolumeID `json:"user_volume_ids,omitempty"`
GuestConfig *GuestConfig `json:"guest_config,omitempty"`
} }
type CreateMachineResponse struct { type CreateMachineResponse struct {

View file

@ -20,6 +20,7 @@ type Config struct {
MachineDisksDir string MachineDisksDir string
RuntimeDir string RuntimeDir string
SocketPath string SocketPath string
EgressInterface string
FirecrackerBinaryPath string FirecrackerBinaryPath string
JailerBinaryPath string JailerBinaryPath string
} }
@ -35,6 +36,7 @@ func Load() (Config, error) {
MachineDisksDir: filepath.Join(rootDir, "machine-disks"), MachineDisksDir: filepath.Join(rootDir, "machine-disks"),
RuntimeDir: filepath.Join(rootDir, "runtime"), RuntimeDir: filepath.Join(rootDir, "runtime"),
SocketPath: filepath.Join(rootDir, defaultSocketName), SocketPath: filepath.Join(rootDir, defaultSocketName),
EgressInterface: strings.TrimSpace(os.Getenv("FIRECRACKER_HOST_EGRESS_INTERFACE")),
FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")), FirecrackerBinaryPath: strings.TrimSpace(os.Getenv("FIRECRACKER_BINARY_PATH")),
JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")), JailerBinaryPath: strings.TrimSpace(os.Getenv("JAILER_BINARY_PATH")),
} }
@ -73,6 +75,9 @@ func (c Config) Validate() error {
if strings.TrimSpace(c.SocketPath) == "" { if strings.TrimSpace(c.SocketPath) == "" {
return fmt.Errorf("socket path is required") return fmt.Errorf("socket path is required")
} }
if strings.TrimSpace(c.EgressInterface) == "" {
return fmt.Errorf("FIRECRACKER_HOST_EGRESS_INTERFACE is required")
}
return nil return nil
} }
@ -80,6 +85,7 @@ func (c Config) Validate() error {
func (c Config) FirecrackerRuntimeConfig() firecracker.RuntimeConfig { func (c Config) FirecrackerRuntimeConfig() firecracker.RuntimeConfig {
return firecracker.RuntimeConfig{ return firecracker.RuntimeConfig{
RootDir: c.RuntimeDir, RootDir: c.RuntimeDir,
EgressInterface: c.EgressInterface,
FirecrackerBinaryPath: c.FirecrackerBinaryPath, FirecrackerBinaryPath: c.FirecrackerBinaryPath,
JailerBinaryPath: c.JailerBinaryPath, JailerBinaryPath: c.JailerBinaryPath,
} }

View file

@ -20,6 +20,9 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
if err := validateArtifactRef(req.Artifact); err != nil { if err := validateArtifactRef(req.Artifact); err != nil {
return nil, err return nil, err
} }
if err := validateGuestConfig(req.GuestConfig); err != nil {
return nil, err
}
unlock := d.lockMachine(req.MachineID) unlock := d.lockMachine(req.MachineID)
defer unlock() defer unlock()
@ -62,6 +65,17 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
if err := cloneFile(artifact.RootFSPath, systemVolumePath); err != nil { if err := cloneFile(artifact.RootFSPath, systemVolumePath); err != nil {
return nil, err return nil, err
} }
removeSystemVolumeOnFailure := true
defer func() {
if !removeSystemVolumeOnFailure {
return
}
_ = os.Remove(systemVolumePath)
_ = os.RemoveAll(filepath.Dir(systemVolumePath))
}()
if err := injectGuestConfig(ctx, systemVolumePath, req.GuestConfig); err != nil {
return nil, err
}
spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath) spec, err := d.buildMachineSpec(req.MachineID, artifact, userVolumes, systemVolumePath)
if err != nil { if err != nil {
@ -140,6 +154,7 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
return nil, err return nil, err
} }
removeSystemVolumeOnFailure = false
clearOperation = true clearOperation = true
return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil return &contracthost.CreateMachineResponse{Machine: machineToContract(record)}, nil
} }

View file

@ -215,6 +215,7 @@ func testConfig(root string) appconfig.Config {
MachineDisksDir: filepath.Join(root, "machine-disks"), MachineDisksDir: filepath.Join(root, "machine-disks"),
RuntimeDir: filepath.Join(root, "runtime"), RuntimeDir: filepath.Join(root, "runtime"),
SocketPath: filepath.Join(root, "firecracker-host.sock"), SocketPath: filepath.Join(root, "firecracker-host.sock"),
EgressInterface: "eth0",
FirecrackerBinaryPath: "/usr/bin/firecracker", FirecrackerBinaryPath: "/usr/bin/firecracker",
JailerBinaryPath: "/usr/bin/jailer", JailerBinaryPath: "/usr/bin/jailer",
} }

View file

@ -4,11 +4,13 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
@ -179,6 +181,67 @@ func defaultMachinePorts() []contracthost.MachinePort {
} }
} }
func hasGuestConfig(config *contracthost.GuestConfig) bool {
if config == nil {
return false
}
return len(config.AuthorizedKeys) > 0 || config.LoginWebhook != nil
}
func injectGuestConfig(ctx context.Context, imagePath string, config *contracthost.GuestConfig) error {
if !hasGuestConfig(config) {
return nil
}
stagingDir, err := os.MkdirTemp(filepath.Dir(imagePath), "guest-config-*")
if err != nil {
return fmt.Errorf("create guest config staging dir: %w", err)
}
defer os.RemoveAll(stagingDir)
if len(config.AuthorizedKeys) > 0 {
authorizedKeysPath := filepath.Join(stagingDir, "authorized_keys")
payload := []byte(strings.Join(config.AuthorizedKeys, "\n") + "\n")
if err := os.WriteFile(authorizedKeysPath, payload, 0o600); err != nil {
return fmt.Errorf("write authorized_keys staging file: %w", err)
}
if err := replaceExt4File(ctx, imagePath, authorizedKeysPath, "/etc/microagent/authorized_keys"); err != nil {
return err
}
}
if config.LoginWebhook != nil {
guestConfigPath := filepath.Join(stagingDir, "guest-config.json")
payload, err := json.Marshal(config)
if err != nil {
return fmt.Errorf("marshal guest config: %w", err)
}
if err := os.WriteFile(guestConfigPath, append(payload, '\n'), 0o600); err != nil {
return fmt.Errorf("write guest config staging file: %w", err)
}
if err := replaceExt4File(ctx, imagePath, guestConfigPath, "/etc/microagent/guest-config.json"); err != nil {
return err
}
}
return nil
}
func replaceExt4File(ctx context.Context, imagePath string, sourcePath string, targetPath string) error {
_ = runDebugFS(ctx, imagePath, fmt.Sprintf("rm %s", targetPath))
if err := runDebugFS(ctx, imagePath, fmt.Sprintf("write %s %s", sourcePath, targetPath)); err != nil {
return fmt.Errorf("inject %q into %q: %w", targetPath, imagePath, err)
}
return nil
}
func runDebugFS(ctx context.Context, imagePath string, command string) error {
cmd := exec.CommandContext(ctx, "debugfs", "-w", "-R", command, imagePath)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("debugfs %q on %q: %w: %s", command, imagePath, err, strings.TrimSpace(string(output)))
}
return nil
}
func machineIDPtr(machineID contracthost.MachineID) *contracthost.MachineID { func machineIDPtr(machineID contracthost.MachineID) *contracthost.MachineID {
value := machineID value := machineID
return &value return &value
@ -229,6 +292,23 @@ func validateArtifactRef(ref contracthost.ArtifactRef) error {
return nil return nil
} }
func validateGuestConfig(config *contracthost.GuestConfig) error {
if config == nil {
return nil
}
for i, key := range config.AuthorizedKeys {
if strings.TrimSpace(key) == "" {
return fmt.Errorf("guest_config.authorized_keys[%d] is required", i)
}
}
if config.LoginWebhook != nil {
if err := validateDownloadURL("guest_config.login_webhook.url", config.LoginWebhook.URL); err != nil {
return err
}
}
return nil
}
func validateMachineID(machineID contracthost.MachineID) error { func validateMachineID(machineID contracthost.MachineID) error {
value := strings.TrimSpace(string(machineID)) value := strings.TrimSpace(string(machineID))
if value == "" { if value == "" {

View file

@ -0,0 +1,97 @@
package daemon
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
func TestInjectGuestConfigWritesAuthorizedKeysAndWebhook(t *testing.T) {
root := t.TempDir()
imagePath := filepath.Join(root, "rootfs.ext4")
if err := buildTestExt4Image(root, imagePath); err != nil {
t.Fatalf("build ext4 image: %v", err)
}
config := &contracthost.GuestConfig{
AuthorizedKeys: []string{
"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGuestKeyOne test-1",
"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGuestKeyTwo test-2",
},
LoginWebhook: &contracthost.GuestLoginWebhook{
URL: "https://example.com/login",
BearerToken: "secret-token",
},
}
if err := injectGuestConfig(context.Background(), imagePath, config); err != nil {
t.Fatalf("inject guest config: %v", err)
}
authorizedKeys, err := readExt4File(imagePath, "/etc/microagent/authorized_keys")
if err != nil {
t.Fatalf("read authorized_keys: %v", err)
}
wantKeys := strings.Join(config.AuthorizedKeys, "\n") + "\n"
if authorizedKeys != wantKeys {
t.Fatalf("authorized_keys mismatch: got %q want %q", authorizedKeys, wantKeys)
}
guestConfigPayload, err := readExt4File(imagePath, "/etc/microagent/guest-config.json")
if err != nil {
t.Fatalf("read guest-config.json: %v", err)
}
var guestConfig contracthost.GuestConfig
if err := json.Unmarshal([]byte(guestConfigPayload), &guestConfig); err != nil {
t.Fatalf("unmarshal guest-config.json: %v", err)
}
if guestConfig.LoginWebhook == nil || guestConfig.LoginWebhook.URL != config.LoginWebhook.URL || guestConfig.LoginWebhook.BearerToken != config.LoginWebhook.BearerToken {
t.Fatalf("login webhook mismatch: got %#v want %#v", guestConfig.LoginWebhook, config.LoginWebhook)
}
}
func buildTestExt4Image(root string, imagePath string) error {
sourceDir := filepath.Join(root, "source")
if err := os.MkdirAll(filepath.Join(sourceDir, "etc", "microagent"), 0o755); err != nil {
return err
}
if err := os.WriteFile(imagePath, nil, 0o644); err != nil {
return err
}
command := exec.Command("truncate", "-s", "16M", imagePath)
output, err := command.CombinedOutput()
if err != nil {
return fmt.Errorf("truncate: %w: %s", err, strings.TrimSpace(string(output)))
}
command = exec.Command("mkfs.ext4", "-q", "-d", sourceDir, "-L", "microagent-root", "-F", imagePath)
output, err = command.CombinedOutput()
if err != nil {
return fmt.Errorf("mkfs.ext4: %w: %s", err, strings.TrimSpace(string(output)))
}
return nil
}
func readExt4File(imagePath string, targetPath string) (string, error) {
command := exec.Command("debugfs", "-R", "cat "+targetPath, imagePath)
output, err := command.CombinedOutput()
if err != nil {
return "", fmt.Errorf("debugfs cat %q: %w: %s", targetPath, err, strings.TrimSpace(string(output)))
}
lines := strings.Split(string(output), "\n")
filtered := make([]string, 0, len(lines))
for _, line := range lines {
if strings.HasPrefix(line, "debugfs ") {
continue
}
filtered = append(filtered, line)
}
return strings.TrimPrefix(strings.Join(filtered, "\n"), "\n"), nil
}

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"os/exec" "os/exec"
"strings" "strings"
) )
@ -15,6 +16,10 @@ const (
defaultNetworkPrefixBits = 30 defaultNetworkPrefixBits = 30
defaultInterfaceID = "net0" defaultInterfaceID = "net0"
defaultTapPrefix = "fctap" defaultTapPrefix = "fctap"
defaultNFTTableName = "microagentcomputer"
defaultNFTPostrouting = "postrouting"
defaultNFTForward = "forward"
defaultIPForwardPath = "/proc/sys/net/ipv4/ip_forward"
) )
// NetworkAllocation describes the concrete host-local network values assigned to a machine // NetworkAllocation describes the concrete host-local network values assigned to a machine
@ -40,7 +45,12 @@ type NetworkProvisioner interface {
// IPTapProvisioner provisions tap devices through the `ip` CLI. // IPTapProvisioner provisions tap devices through the `ip` CLI.
type IPTapProvisioner struct { type IPTapProvisioner struct {
runCommand func(context.Context, string, ...string) error guestCIDR string
egressInterface string
runCommand func(context.Context, string, ...string) error
readCommandOutput func(context.Context, string, ...string) (string, error)
readFile func(string) ([]byte, error)
writeFile func(string, []byte, os.FileMode) error
} }
// GuestIP returns the guest IP address. // GuestIP returns the guest IP address.
@ -143,8 +153,10 @@ func (a *NetworkAllocator) networkForIndex(index int) (NetworkAllocation, error)
} }
// NewIPTapProvisioner returns a provisioner backed by `ip`. // NewIPTapProvisioner returns a provisioner backed by `ip`.
func NewIPTapProvisioner() *IPTapProvisioner { func NewIPTapProvisioner(guestCIDR string, egressInterface string) *IPTapProvisioner {
return &IPTapProvisioner{ return &IPTapProvisioner{
guestCIDR: strings.TrimSpace(guestCIDR),
egressInterface: strings.TrimSpace(egressInterface),
runCommand: func(ctx context.Context, name string, args ...string) error { runCommand: func(ctx context.Context, name string, args ...string) error {
cmd := exec.CommandContext(ctx, name, args...) cmd := exec.CommandContext(ctx, name, args...)
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
@ -153,6 +165,16 @@ func NewIPTapProvisioner() *IPTapProvisioner {
} }
return nil return nil
}, },
readCommandOutput: func(ctx context.Context, name string, args ...string) (string, error) {
cmd := exec.CommandContext(ctx, name, args...)
output, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("%s %s: %w: %s", name, strings.Join(args, " "), err, strings.TrimSpace(string(output)))
}
return string(output), nil
},
readFile: os.ReadFile,
writeFile: os.WriteFile,
} }
} }
@ -164,6 +186,9 @@ func (p *IPTapProvisioner) Ensure(ctx context.Context, network NetworkAllocation
if strings.TrimSpace(network.TapName) == "" { if strings.TrimSpace(network.TapName) == "" {
return fmt.Errorf("tap name is required") return fmt.Errorf("tap name is required")
} }
if err := p.ensureHostNetworking(ctx); err != nil {
return err
}
err := p.runCommand(ctx, "ip", "tuntap", "add", "dev", network.TapName, "mode", "tap") err := p.runCommand(ctx, "ip", "tuntap", "add", "dev", network.TapName, "mode", "tap")
if err != nil { if err != nil {
@ -188,6 +213,92 @@ func (p *IPTapProvisioner) Ensure(ctx context.Context, network NetworkAllocation
return nil return nil
} }
func (p *IPTapProvisioner) ensureHostNetworking(ctx context.Context) error {
if strings.TrimSpace(p.egressInterface) == "" {
return fmt.Errorf("egress interface is required")
}
if strings.TrimSpace(p.guestCIDR) == "" {
return fmt.Errorf("guest cidr is required")
}
if err := p.runCommand(ctx, "ip", "link", "show", "dev", p.egressInterface); err != nil {
return fmt.Errorf("validate egress interface %q: %w", p.egressInterface, err)
}
if err := p.ensureIPForwarding(); err != nil {
return err
}
if err := p.ensureNFTTable(ctx); err != nil {
return err
}
if err := p.ensureNFTChain(ctx, defaultNFTPostrouting, "{ type nat hook postrouting priority srcnat; policy accept; }"); err != nil {
return err
}
if err := p.ensureNFTChain(ctx, defaultNFTForward, "{ type filter hook forward priority filter; policy accept; }"); err != nil {
return err
}
if err := p.ensureNFTRule(ctx, defaultNFTPostrouting, "microagentcomputer-postrouting", "ip saddr %s oifname %q counter masquerade comment %q", p.guestCIDR, p.egressInterface, "microagentcomputer-postrouting"); err != nil {
return err
}
if err := p.ensureNFTRule(ctx, defaultNFTForward, "microagentcomputer-forward-out", "iifname %q oifname %q counter accept comment %q", defaultTapPrefix+"*", p.egressInterface, "microagentcomputer-forward-out"); err != nil {
return err
}
if err := p.ensureNFTRule(ctx, defaultNFTForward, "microagentcomputer-forward-in", "iifname %q oifname %q ct state related,established counter accept comment %q", p.egressInterface, defaultTapPrefix+"*", "microagentcomputer-forward-in"); err != nil {
return err
}
return nil
}
func (p *IPTapProvisioner) ensureIPForwarding() error {
if p.readFile == nil || p.writeFile == nil {
return fmt.Errorf("ip forwarding helpers are required")
}
payload, err := p.readFile(defaultIPForwardPath)
if err != nil {
return fmt.Errorf("read ip forwarding state: %w", err)
}
if strings.TrimSpace(string(payload)) == "1" {
return nil
}
if err := p.writeFile(defaultIPForwardPath, []byte("1\n"), 0o644); err != nil {
return fmt.Errorf("enable ip forwarding: %w", err)
}
return nil
}
func (p *IPTapProvisioner) ensureNFTTable(ctx context.Context) error {
if _, err := p.readCommandOutput(ctx, "nft", "list", "table", "ip", defaultNFTTableName); err == nil {
return nil
}
if err := p.runCommand(ctx, "nft", "add", "table", "ip", defaultNFTTableName); err != nil {
return fmt.Errorf("ensure nft table %q: %w", defaultNFTTableName, err)
}
return nil
}
func (p *IPTapProvisioner) ensureNFTChain(ctx context.Context, chain string, definition string) error {
if _, err := p.readCommandOutput(ctx, "nft", "list", "chain", "ip", defaultNFTTableName, chain); err == nil {
return nil
}
if err := p.runCommand(ctx, "nft", "add", "chain", "ip", defaultNFTTableName, chain, definition); err != nil {
return fmt.Errorf("ensure nft chain %q: %w", chain, err)
}
return nil
}
func (p *IPTapProvisioner) ensureNFTRule(ctx context.Context, chain string, comment string, format string, args ...any) error {
output, err := p.readCommandOutput(ctx, "nft", "list", "chain", "ip", defaultNFTTableName, chain)
if err != nil {
return fmt.Errorf("list nft chain %q: %w", chain, err)
}
if strings.Contains(output, fmt.Sprintf("comment \"%s\"", comment)) {
return nil
}
rule := fmt.Sprintf(format, args...)
if err := p.runCommand(ctx, "nft", "add", "rule", "ip", defaultNFTTableName, chain, rule); err != nil {
return fmt.Errorf("ensure nft rule %q: %w", comment, err)
}
return nil
}
// Remove deletes the tap device if it exists. // Remove deletes the tap device if it exists.
func (p *IPTapProvisioner) Remove(ctx context.Context, network NetworkAllocation) error { func (p *IPTapProvisioner) Remove(ctx context.Context, network NetworkAllocation) error {
if p == nil || p.runCommand == nil { if p == nil || p.runCommand == nil {

View file

@ -21,6 +21,7 @@ var (
type RuntimeConfig struct { type RuntimeConfig struct {
RootDir string RootDir string
EgressInterface string
FirecrackerBinaryPath string FirecrackerBinaryPath string
JailerBinaryPath string JailerBinaryPath string
} }
@ -50,6 +51,10 @@ func NewRuntime(cfg RuntimeConfig) (*Runtime, error) {
if jailerBinaryPath == "" { if jailerBinaryPath == "" {
return nil, fmt.Errorf("jailer binary path is required") return nil, fmt.Errorf("jailer binary path is required")
} }
egressInterface := strings.TrimSpace(cfg.EgressInterface)
if egressInterface == "" {
return nil, fmt.Errorf("egress interface is required")
}
if err := os.MkdirAll(rootDir, 0o755); err != nil { if err := os.MkdirAll(rootDir, 0o755); err != nil {
return nil, fmt.Errorf("create runtime root dir %q: %w", rootDir, err) return nil, fmt.Errorf("create runtime root dir %q: %w", rootDir, err)
@ -65,7 +70,7 @@ func NewRuntime(cfg RuntimeConfig) (*Runtime, error) {
firecrackerBinaryPath: firecrackerBinaryPath, firecrackerBinaryPath: firecrackerBinaryPath,
jailerBinaryPath: jailerBinaryPath, jailerBinaryPath: jailerBinaryPath,
networkAllocator: allocator, networkAllocator: allocator,
networkProvisioner: NewIPTapProvisioner(), networkProvisioner: NewIPTapProvisioner(defaultNetworkCIDR, egressInterface),
}, nil }, nil
} }