mirror of
https://github.com/getcompanion-ai/computer-host.git
synced 2026-04-15 03:00:42 +00:00
feat: add host-managed backend ssh key injection
This commit is contained in:
parent
a12f54ba5d
commit
9acbf232eb
4 changed files with 138 additions and 4 deletions
|
|
@ -23,6 +23,10 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
if err := validateGuestConfig(req.GuestConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
guestConfig, err := d.mergedGuestConfig(req.GuestConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
unlock := d.lockMachine(req.MachineID)
|
||||
defer unlock()
|
||||
|
|
@ -73,7 +77,7 @@ func (d *Daemon) CreateMachine(ctx context.Context, req contracthost.CreateMachi
|
|||
_ = os.Remove(systemVolumePath)
|
||||
_ = os.RemoveAll(filepath.Dir(systemVolumePath))
|
||||
}()
|
||||
if err := injectGuestConfig(ctx, systemVolumePath, req.GuestConfig); err != nil {
|
||||
if err := injectGuestConfig(ctx, systemVolumePath, guestConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -56,13 +56,17 @@ func New(cfg appconfig.Config, store store.Store, runtime Runtime) (*Daemon, err
|
|||
return nil, fmt.Errorf("create daemon dir %q: %w", dir, err)
|
||||
}
|
||||
}
|
||||
return &Daemon{
|
||||
daemon := &Daemon{
|
||||
config: cfg,
|
||||
store: store,
|
||||
runtime: runtime,
|
||||
machineLocks: make(map[contracthost.MachineID]*sync.Mutex),
|
||||
artifactLocks: make(map[string]*sync.Mutex),
|
||||
}, nil
|
||||
}
|
||||
if err := daemon.ensureBackendSSHKeyPair(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return daemon, nil
|
||||
}
|
||||
|
||||
func (d *Daemon) Health(ctx context.Context) (*contracthost.HealthResponse, error) {
|
||||
|
|
|
|||
|
|
@ -75,7 +75,14 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
|||
}
|
||||
|
||||
kernelPayload := []byte("kernel-image")
|
||||
rootFSPayload := []byte("rootfs-image")
|
||||
rootFSImagePath := filepath.Join(root, "guest-rootfs.ext4")
|
||||
if err := buildTestExt4Image(root, rootFSImagePath); err != nil {
|
||||
t.Fatalf("build ext4 image: %v", err)
|
||||
}
|
||||
rootFSPayload, err := os.ReadFile(rootFSImagePath)
|
||||
if err != nil {
|
||||
t.Fatalf("read ext4 image: %v", err)
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/kernel":
|
||||
|
|
@ -94,6 +101,15 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
|||
KernelImageURL: server.URL + "/kernel",
|
||||
RootFSURL: server.URL + "/rootfs",
|
||||
},
|
||||
GuestConfig: &contracthost.GuestConfig{
|
||||
AuthorizedKeys: []string{
|
||||
"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAITestOverrideKey daemon-test",
|
||||
},
|
||||
LoginWebhook: &contracthost.GuestLoginWebhook{
|
||||
URL: "https://example.com/login",
|
||||
BearerToken: "token",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create machine: %v", err)
|
||||
|
|
@ -123,6 +139,20 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
|||
if _, err := os.Stat(runtime.lastSpec.RootFSPath); err != nil {
|
||||
t.Fatalf("system disk not staged: %v", err)
|
||||
}
|
||||
hostAuthorizedKeyBytes, err := os.ReadFile(hostDaemon.backendSSHPublicKeyPath())
|
||||
if err != nil {
|
||||
t.Fatalf("read backend ssh public key: %v", err)
|
||||
}
|
||||
authorizedKeys, err := readExt4File(runtime.lastSpec.RootFSPath, "/etc/microagent/authorized_keys")
|
||||
if err != nil {
|
||||
t.Fatalf("read injected authorized_keys: %v", err)
|
||||
}
|
||||
if !strings.Contains(authorizedKeys, strings.TrimSpace(string(hostAuthorizedKeyBytes))) {
|
||||
t.Fatalf("authorized_keys missing backend ssh key: %q", authorizedKeys)
|
||||
}
|
||||
if !strings.Contains(authorizedKeys, "daemon-test") {
|
||||
t.Fatalf("authorized_keys missing request override key: %q", authorizedKeys)
|
||||
}
|
||||
|
||||
artifact, err := fileStore.GetArtifact(context.Background(), response.Machine.Artifact)
|
||||
if err != nil {
|
||||
|
|
@ -154,6 +184,31 @@ func TestCreateMachineStagesArtifactsAndPersistsState(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNewEnsuresBackendSSHKeyPair(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("create file store: %v", err)
|
||||
}
|
||||
|
||||
hostDaemon, err := New(cfg, fileStore, &fakeRuntime{})
|
||||
if err != nil {
|
||||
t.Fatalf("create daemon: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(hostDaemon.backendSSHPrivateKeyPath()); err != nil {
|
||||
t.Fatalf("stat backend ssh private key: %v", err)
|
||||
}
|
||||
publicKeyPayload, err := os.ReadFile(hostDaemon.backendSSHPublicKeyPath())
|
||||
if err != nil {
|
||||
t.Fatalf("read backend ssh public key: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(string(publicKeyPayload), "ssh-ed25519 ") {
|
||||
t.Fatalf("unexpected backend ssh public key: %q", string(publicKeyPayload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateMachineRejectsNonHTTPArtifactURLs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,14 @@ func (d *Daemon) machineRuntimeBaseDir(machineID contracthost.MachineID) string
|
|||
return filepath.Join(d.config.RuntimeDir, "machines", string(machineID))
|
||||
}
|
||||
|
||||
func (d *Daemon) backendSSHPrivateKeyPath() string {
|
||||
return filepath.Join(d.config.RootDir, "state", "ssh", "backend_ed25519")
|
||||
}
|
||||
|
||||
func (d *Daemon) backendSSHPublicKeyPath() string {
|
||||
return d.backendSSHPrivateKeyPath() + ".pub"
|
||||
}
|
||||
|
||||
func artifactKey(ref contracthost.ArtifactRef) string {
|
||||
sum := sha256.Sum256([]byte(ref.KernelImageURL + "\n" + ref.RootFSURL))
|
||||
return hex.EncodeToString(sum[:])
|
||||
|
|
@ -181,6 +189,69 @@ func defaultMachinePorts() []contracthost.MachinePort {
|
|||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) ensureBackendSSHKeyPair() error {
|
||||
privateKeyPath := d.backendSSHPrivateKeyPath()
|
||||
publicKeyPath := d.backendSSHPublicKeyPath()
|
||||
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0o700); err != nil {
|
||||
return fmt.Errorf("create backend ssh dir: %w", err)
|
||||
}
|
||||
privateExists := fileExists(privateKeyPath)
|
||||
publicExists := fileExists(publicKeyPath)
|
||||
switch {
|
||||
case privateExists && publicExists:
|
||||
return nil
|
||||
case privateExists && !publicExists:
|
||||
return d.writeBackendSSHPublicKey(privateKeyPath, publicKeyPath)
|
||||
case !privateExists && publicExists:
|
||||
return fmt.Errorf("backend ssh private key %q is missing while public key exists", privateKeyPath)
|
||||
}
|
||||
|
||||
command := exec.Command("ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", privateKeyPath)
|
||||
output, err := command.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate backend ssh keypair: %w: %s", err, strings.TrimSpace(string(output)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Daemon) writeBackendSSHPublicKey(privateKeyPath string, publicKeyPath string) error {
|
||||
command := exec.Command("ssh-keygen", "-y", "-f", privateKeyPath)
|
||||
output, err := command.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("derive backend ssh public key: %w: %s", err, strings.TrimSpace(string(output)))
|
||||
}
|
||||
payload := strings.TrimSpace(string(output)) + "\n"
|
||||
if err := os.WriteFile(publicKeyPath, []byte(payload), 0o644); err != nil {
|
||||
return fmt.Errorf("write backend ssh public key %q: %w", publicKeyPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (d *Daemon) mergedGuestConfig(config *contracthost.GuestConfig) (*contracthost.GuestConfig, error) {
|
||||
hostAuthorizedKey, err := os.ReadFile(d.backendSSHPublicKeyPath())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read backend ssh public key: %w", err)
|
||||
}
|
||||
authorizedKeys := []string{strings.TrimSpace(string(hostAuthorizedKey))}
|
||||
if config != nil {
|
||||
authorizedKeys = append(authorizedKeys, config.AuthorizedKeys...)
|
||||
}
|
||||
|
||||
merged := &contracthost.GuestConfig{
|
||||
AuthorizedKeys: authorizedKeys,
|
||||
}
|
||||
if config != nil && config.LoginWebhook != nil {
|
||||
loginWebhook := *config.LoginWebhook
|
||||
merged.LoginWebhook = &loginWebhook
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func hasGuestConfig(config *contracthost.GuestConfig) bool {
|
||||
if config == nil {
|
||||
return false
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue