feat: firecracker mmds identity

This commit is contained in:
Harivansh Rathi 2026-04-10 00:53:47 +00:00
parent 500354cd9b
commit 3eb610b703
23 changed files with 1813 additions and 263 deletions

View file

@ -0,0 +1,194 @@
package daemon
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"github.com/getcompanion-ai/computer-host/internal/model"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
type restoredSnapshotArtifact struct {
Artifact contracthost.SnapshotArtifact
LocalPath string
}
func buildSnapshotArtifacts(memoryPath, vmstatePath string, diskPaths []string) ([]model.SnapshotArtifactRecord, error) {
artifacts := make([]model.SnapshotArtifactRecord, 0, len(diskPaths)+2)
memoryArtifact, err := snapshotArtifactRecord("memory", contracthost.SnapshotArtifactKindMemory, filepath.Base(memoryPath), memoryPath)
if err != nil {
return nil, err
}
artifacts = append(artifacts, memoryArtifact)
vmstateArtifact, err := snapshotArtifactRecord("vmstate", contracthost.SnapshotArtifactKindVMState, filepath.Base(vmstatePath), vmstatePath)
if err != nil {
return nil, err
}
artifacts = append(artifacts, vmstateArtifact)
for _, diskPath := range diskPaths {
base := filepath.Base(diskPath)
diskArtifact, err := snapshotArtifactRecord("disk-"+strings.TrimSuffix(base, filepath.Ext(base)), contracthost.SnapshotArtifactKindDisk, base, diskPath)
if err != nil {
return nil, err
}
artifacts = append(artifacts, diskArtifact)
}
sort.Slice(artifacts, func(i, j int) bool {
return artifacts[i].ID < artifacts[j].ID
})
return artifacts, nil
}
func snapshotArtifactRecord(id string, kind contracthost.SnapshotArtifactKind, name, path string) (model.SnapshotArtifactRecord, error) {
size, err := fileSize(path)
if err != nil {
return model.SnapshotArtifactRecord{}, err
}
sum, err := sha256File(path)
if err != nil {
return model.SnapshotArtifactRecord{}, err
}
return model.SnapshotArtifactRecord{
ID: id,
Kind: kind,
Name: name,
LocalPath: path,
SizeBytes: size,
SHA256Hex: sum,
}, nil
}
func sha256File(path string) (string, error) {
file, err := os.Open(path)
if err != nil {
return "", fmt.Errorf("open %q for sha256: %w", path, err)
}
defer func() { _ = file.Close() }()
hash := sha256.New()
if _, err := io.Copy(hash, file); err != nil {
return "", fmt.Errorf("hash %q: %w", path, err)
}
return hex.EncodeToString(hash.Sum(nil)), nil
}
func uploadSnapshotArtifact(ctx context.Context, localPath string, parts []contracthost.SnapshotUploadPart) ([]contracthost.UploadedSnapshotPart, error) {
if len(parts) == 0 {
return nil, fmt.Errorf("upload session has no parts")
}
file, err := os.Open(localPath)
if err != nil {
return nil, fmt.Errorf("open artifact %q: %w", localPath, err)
}
defer func() { _ = file.Close() }()
client := &http.Client{}
completed := make([]contracthost.UploadedSnapshotPart, 0, len(parts))
for _, part := range parts {
reader := io.NewSectionReader(file, part.OffsetBytes, part.SizeBytes)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, part.UploadURL, io.NopCloser(reader))
if err != nil {
return nil, fmt.Errorf("build upload part %d: %w", part.PartNumber, err)
}
req.ContentLength = part.SizeBytes
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("upload part %d: %w", part.PartNumber, err)
}
_ = resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("upload part %d returned %d", part.PartNumber, resp.StatusCode)
}
etag := strings.TrimSpace(resp.Header.Get("ETag"))
if etag == "" {
return nil, fmt.Errorf("upload part %d returned empty etag", part.PartNumber)
}
completed = append(completed, contracthost.UploadedSnapshotPart{
PartNumber: part.PartNumber,
ETag: etag,
})
}
sort.Slice(completed, func(i, j int) bool {
return completed[i].PartNumber < completed[j].PartNumber
})
return completed, nil
}
func downloadDurableSnapshotArtifacts(ctx context.Context, root string, artifacts []contracthost.SnapshotArtifact) (map[string]restoredSnapshotArtifact, error) {
if len(artifacts) == 0 {
return nil, fmt.Errorf("restore snapshot is missing artifacts")
}
if err := os.MkdirAll(root, 0o755); err != nil {
return nil, fmt.Errorf("create restore staging dir %q: %w", root, err)
}
client := &http.Client{}
restored := make(map[string]restoredSnapshotArtifact, len(artifacts))
for _, artifact := range artifacts {
if strings.TrimSpace(artifact.DownloadURL) == "" {
return nil, fmt.Errorf("artifact %q is missing download url", artifact.ID)
}
localPath := filepath.Join(root, artifact.Name)
if err := downloadSnapshotArtifact(ctx, client, artifact.DownloadURL, localPath); err != nil {
return nil, err
}
if expectedSHA := strings.TrimSpace(artifact.SHA256Hex); expectedSHA != "" {
actualSHA, err := sha256File(localPath)
if err != nil {
return nil, err
}
if !strings.EqualFold(actualSHA, expectedSHA) {
_ = os.Remove(localPath)
return nil, fmt.Errorf("restore artifact %q sha256 mismatch: got %s want %s", artifact.Name, actualSHA, expectedSHA)
}
}
restored[artifact.Name] = restoredSnapshotArtifact{
Artifact: artifact,
LocalPath: localPath,
}
}
return restored, nil
}
func downloadSnapshotArtifact(ctx context.Context, client *http.Client, sourceURL, targetPath string) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
if err != nil {
return fmt.Errorf("build restore download request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("download durable snapshot artifact: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("download durable snapshot artifact returned %d", resp.StatusCode)
}
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
return fmt.Errorf("create restore artifact dir %q: %w", filepath.Dir(targetPath), err)
}
out, err := os.Create(targetPath)
if err != nil {
return fmt.Errorf("create restore artifact %q: %w", targetPath, err)
}
defer func() { _ = out.Close() }()
if _, err := io.Copy(out, resp.Body); err != nil {
return fmt.Errorf("write restore artifact %q: %w", targetPath, err)
}
return nil
}