mirror of
https://github.com/getcompanion-ai/computer-host.git
synced 2026-04-15 10:05:13 +00:00
194 lines
6.1 KiB
Go
194 lines
6.1 KiB
Go
package daemon
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/getcompanion-ai/computer-host/internal/model"
|
|
contracthost "github.com/getcompanion-ai/computer-host/contract"
|
|
)
|
|
|
|
type restoredSnapshotArtifact struct {
|
|
Artifact contracthost.SnapshotArtifact
|
|
LocalPath string
|
|
}
|
|
|
|
func buildSnapshotArtifacts(memoryPath, vmstatePath string, diskPaths []string) ([]model.SnapshotArtifactRecord, error) {
|
|
artifacts := make([]model.SnapshotArtifactRecord, 0, len(diskPaths)+2)
|
|
|
|
memoryArtifact, err := snapshotArtifactRecord("memory", contracthost.SnapshotArtifactKindMemory, filepath.Base(memoryPath), memoryPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
artifacts = append(artifacts, memoryArtifact)
|
|
|
|
vmstateArtifact, err := snapshotArtifactRecord("vmstate", contracthost.SnapshotArtifactKindVMState, filepath.Base(vmstatePath), vmstatePath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
artifacts = append(artifacts, vmstateArtifact)
|
|
|
|
for _, diskPath := range diskPaths {
|
|
base := filepath.Base(diskPath)
|
|
diskArtifact, err := snapshotArtifactRecord("disk-"+strings.TrimSuffix(base, filepath.Ext(base)), contracthost.SnapshotArtifactKindDisk, base, diskPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
artifacts = append(artifacts, diskArtifact)
|
|
}
|
|
|
|
sort.Slice(artifacts, func(i, j int) bool {
|
|
return artifacts[i].ID < artifacts[j].ID
|
|
})
|
|
return artifacts, nil
|
|
}
|
|
|
|
func snapshotArtifactRecord(id string, kind contracthost.SnapshotArtifactKind, name, path string) (model.SnapshotArtifactRecord, error) {
|
|
size, err := fileSize(path)
|
|
if err != nil {
|
|
return model.SnapshotArtifactRecord{}, err
|
|
}
|
|
sum, err := sha256File(path)
|
|
if err != nil {
|
|
return model.SnapshotArtifactRecord{}, err
|
|
}
|
|
return model.SnapshotArtifactRecord{
|
|
ID: id,
|
|
Kind: kind,
|
|
Name: name,
|
|
LocalPath: path,
|
|
SizeBytes: size,
|
|
SHA256Hex: sum,
|
|
}, nil
|
|
}
|
|
|
|
// sha256File streams the file at path through SHA-256 and returns the digest
// as a lowercase hexadecimal string.
//
// An error wrapping the underlying cause is returned when the file cannot be
// opened or read.
func sha256File(path string) (string, error) {
	f, openErr := os.Open(path)
	if openErr != nil {
		return "", fmt.Errorf("open %q for sha256: %w", path, openErr)
	}
	// The handle is read-only, so a failed Close cannot lose data and is ignored.
	defer func() { _ = f.Close() }()

	digest := sha256.New()
	_, copyErr := io.Copy(digest, f)
	if copyErr != nil {
		return "", fmt.Errorf("hash %q: %w", path, copyErr)
	}

	sum := digest.Sum(nil)
	return hex.EncodeToString(sum), nil
}
|
|
|
|
func uploadSnapshotArtifact(ctx context.Context, localPath string, parts []contracthost.SnapshotUploadPart) ([]contracthost.UploadedSnapshotPart, error) {
|
|
if len(parts) == 0 {
|
|
return nil, fmt.Errorf("upload session has no parts")
|
|
}
|
|
|
|
file, err := os.Open(localPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open artifact %q: %w", localPath, err)
|
|
}
|
|
defer func() { _ = file.Close() }()
|
|
|
|
client := &http.Client{}
|
|
completed := make([]contracthost.UploadedSnapshotPart, 0, len(parts))
|
|
for _, part := range parts {
|
|
reader := io.NewSectionReader(file, part.OffsetBytes, part.SizeBytes)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPut, part.UploadURL, io.NopCloser(reader))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("build upload part %d: %w", part.PartNumber, err)
|
|
}
|
|
req.ContentLength = part.SizeBytes
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("upload part %d: %w", part.PartNumber, err)
|
|
}
|
|
_ = resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("upload part %d returned %d", part.PartNumber, resp.StatusCode)
|
|
}
|
|
etag := strings.TrimSpace(resp.Header.Get("ETag"))
|
|
if etag == "" {
|
|
return nil, fmt.Errorf("upload part %d returned empty etag", part.PartNumber)
|
|
}
|
|
completed = append(completed, contracthost.UploadedSnapshotPart{
|
|
PartNumber: part.PartNumber,
|
|
ETag: etag,
|
|
})
|
|
}
|
|
sort.Slice(completed, func(i, j int) bool {
|
|
return completed[i].PartNumber < completed[j].PartNumber
|
|
})
|
|
return completed, nil
|
|
}
|
|
|
|
func downloadDurableSnapshotArtifacts(ctx context.Context, root string, artifacts []contracthost.SnapshotArtifact) (map[string]restoredSnapshotArtifact, error) {
|
|
if len(artifacts) == 0 {
|
|
return nil, fmt.Errorf("restore snapshot is missing artifacts")
|
|
}
|
|
if err := os.MkdirAll(root, 0o755); err != nil {
|
|
return nil, fmt.Errorf("create restore staging dir %q: %w", root, err)
|
|
}
|
|
|
|
client := &http.Client{}
|
|
restored := make(map[string]restoredSnapshotArtifact, len(artifacts))
|
|
for _, artifact := range artifacts {
|
|
if strings.TrimSpace(artifact.DownloadURL) == "" {
|
|
return nil, fmt.Errorf("artifact %q is missing download url", artifact.ID)
|
|
}
|
|
localPath := filepath.Join(root, artifact.Name)
|
|
if err := downloadSnapshotArtifact(ctx, client, artifact.DownloadURL, localPath); err != nil {
|
|
return nil, err
|
|
}
|
|
if expectedSHA := strings.TrimSpace(artifact.SHA256Hex); expectedSHA != "" {
|
|
actualSHA, err := sha256File(localPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !strings.EqualFold(actualSHA, expectedSHA) {
|
|
_ = os.Remove(localPath)
|
|
return nil, fmt.Errorf("restore artifact %q sha256 mismatch: got %s want %s", artifact.Name, actualSHA, expectedSHA)
|
|
}
|
|
}
|
|
restored[artifact.Name] = restoredSnapshotArtifact{
|
|
Artifact: artifact,
|
|
LocalPath: localPath,
|
|
}
|
|
}
|
|
return restored, nil
|
|
}
|
|
|
|
// downloadSnapshotArtifact fetches sourceURL with an HTTP GET issued through
// client and writes the response body to targetPath.
//
// The parent directory of targetPath is created if needed and an existing file
// at targetPath is truncated. The file is only left in place when the whole
// body was written and the file closed cleanly; after a failed or truncated
// transfer the partial file is removed.
//
// An error is returned when the request cannot be built or sent, the response
// status is outside the 2xx range, or the directory or file cannot be created,
// written or closed. Cancellation and deadlines are governed by ctx.
func downloadSnapshotArtifact(ctx context.Context, client *http.Client, sourceURL, targetPath string) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
	if err != nil {
		return fmt.Errorf("build restore download request: %w", err)
	}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("download durable snapshot artifact: %w", err)
	}
	// Nothing useful can be done with a Close error on a response body.
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("download durable snapshot artifact returned %d", resp.StatusCode)
	}
	if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
		return fmt.Errorf("create restore artifact dir %q: %w", filepath.Dir(targetPath), err)
	}
	out, err := os.Create(targetPath)
	if err != nil {
		return fmt.Errorf("create restore artifact %q: %w", targetPath, err)
	}

	if _, err := io.Copy(out, resp.Body); err != nil {
		// Best-effort cleanup so a truncated artifact is never mistaken for a
		// complete one; the copy error is the one worth reporting.
		_ = out.Close()
		_ = os.Remove(targetPath)
		return fmt.Errorf("write restore artifact %q: %w", targetPath, err)
	}
	// Close can surface a deferred write failure, so it must be checked before
	// the artifact is reported as written.
	if err := out.Close(); err != nil {
		_ = os.Remove(targetPath)
		return fmt.Errorf("close restore artifact %q: %w", targetPath, err)
	}
	return nil
}
|