feat: local-first snapshot implementation end to end

This commit is contained in:
Harivansh Rathi 2026-04-10 21:32:09 +00:00
parent fc21e897ea
commit 30282928f5
4 changed files with 279 additions and 60 deletions

View file

@ -11,12 +11,16 @@ import (
"strings"
"time"
"golang.org/x/sync/errgroup"
"github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract"
)
const localSnapshotRestoreUnavailablePrefix = "local snapshot restore unavailable"
func (d *Daemon) CreateSnapshot(ctx context.Context, machineID contracthost.MachineID, req contracthost.CreateSnapshotRequest) (*contracthost.CreateSnapshotResponse, error) {
unlock := d.lockMachine(machineID)
defer unlock()
@ -193,20 +197,31 @@ func (d *Daemon) UploadSnapshot(ctx context.Context, snapshotID contracthost.Sna
response := &contracthost.UploadSnapshotResponse{
Artifacts: make([]contracthost.UploadedSnapshotArtifact, 0, len(req.Artifacts)),
}
for _, upload := range req.Artifacts {
artifact, ok := artifactIndex[upload.ArtifactID]
if !ok {
return nil, fmt.Errorf("snapshot %q artifact %q not found", snapshotID, upload.ArtifactID)
}
completedParts, err := uploadSnapshotArtifact(ctx, artifact.LocalPath, upload.Parts)
if err != nil {
return nil, fmt.Errorf("upload snapshot artifact %q: %w", upload.ArtifactID, err)
}
response.Artifacts = append(response.Artifacts, contracthost.UploadedSnapshotArtifact{
ArtifactID: upload.ArtifactID,
CompletedParts: completedParts,
uploads := make([]contracthost.UploadedSnapshotArtifact, len(req.Artifacts))
group, groupCtx := errgroup.WithContext(ctx)
for i, upload := range req.Artifacts {
i := i
upload := upload
group.Go(func() error {
artifact, ok := artifactIndex[upload.ArtifactID]
if !ok {
return fmt.Errorf("snapshot %q artifact %q not found", snapshotID, upload.ArtifactID)
}
completedParts, err := uploadSnapshotArtifact(groupCtx, artifact.LocalPath, upload.Parts)
if err != nil {
return fmt.Errorf("upload snapshot artifact %q: %w", upload.ArtifactID, err)
}
uploads[i] = contracthost.UploadedSnapshotArtifact{
ArtifactID: upload.ArtifactID,
CompletedParts: completedParts,
}
return nil
})
}
if err := group.Wait(); err != nil {
return nil, err
}
response.Artifacts = append(response.Artifacts, uploads...)
return response, nil
}
@ -215,12 +230,18 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
if err := validateMachineID(req.MachineID); err != nil {
return nil, err
}
if req.Snapshot.SnapshotID != "" && req.Snapshot.SnapshotID != snapshotID {
return nil, fmt.Errorf("snapshot id mismatch: path=%q payload=%q", snapshotID, req.Snapshot.SnapshotID)
}
if err := validateArtifactRef(req.Artifact); err != nil {
return nil, err
}
if req.LocalSnapshot == nil && req.Snapshot == nil {
return nil, fmt.Errorf("restore request must include local_snapshot or snapshot")
}
if req.LocalSnapshot != nil && req.LocalSnapshot.SnapshotID != "" && req.LocalSnapshot.SnapshotID != snapshotID {
return nil, fmt.Errorf("local snapshot id mismatch: path=%q payload=%q", snapshotID, req.LocalSnapshot.SnapshotID)
}
if req.Snapshot != nil && req.Snapshot.SnapshotID != "" && req.Snapshot.SnapshotID != snapshotID {
return nil, fmt.Errorf("snapshot id mismatch: path=%q payload=%q", snapshotID, req.Snapshot.SnapshotID)
}
if err := validateGuestConfig(req.GuestConfig); err != nil {
return nil, err
}
@ -258,30 +279,18 @@ func (d *Daemon) RestoreSnapshot(ctx context.Context, snapshotID contracthost.Sn
if err != nil {
return nil, err
}
restoreNetwork, err := d.resolveRestoreNetwork(ctx, snapshotID, req.Snapshot)
restoredArtifacts, restoreNetwork, cleanupRestoreArtifacts, err := d.prepareRestoreArtifacts(ctx, snapshotID, req, usedNetworks)
if err != nil {
clearOperation = true
return nil, err
}
if networkAllocationInUse(restoreNetwork, usedNetworks) {
clearOperation = true
return nil, fmt.Errorf("restore network for snapshot %q is still in use on this host (runtime_host=%s tap_device=%s)", snapshotID, restoreNetwork.GuestIP(), restoreNetwork.TapName)
}
defer cleanupRestoreArtifacts()
artifact, err := d.ensureArtifact(ctx, req.Artifact)
if err != nil {
clearOperation = true
return nil, fmt.Errorf("ensure artifact for restore: %w", err)
}
stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID))
restoredArtifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts)
if err != nil {
_ = os.RemoveAll(stagingDir)
clearOperation = true
return nil, fmt.Errorf("download durable snapshot artifacts: %w", err)
}
defer func() { _ = os.RemoveAll(stagingDir) }()
// COW-copy system disk from snapshot to new machine's disk dir.
newSystemDiskPath := d.systemVolumePath(req.MachineID)
if err := os.MkdirAll(filepath.Dir(newSystemDiskPath), 0o755); err != nil {
@ -515,19 +524,59 @@ func restoredUserDiskIndex(name string) (int, bool) {
return index, true
}
func (d *Daemon) resolveRestoreNetwork(ctx context.Context, snapshotID contracthost.SnapshotID, spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) {
if network, err := restoreNetworkFromDurableSpec(spec); err == nil {
return network, nil
// prepareRestoreArtifacts resolves everything a restore needs before the new
// machine is assembled: the snapshot artifacts keyed by name, the network
// allocation the snapshot was taken with, and a cleanup function for any
// temporary files it created.
//
// Two sources are supported, chosen by the request:
//
//   - req.LocalSnapshot set: the snapshot record is read from the host store
//     and its artifacts are used in place from their recorded local paths.
//     Nothing is staged, so the returned cleanup is a no-op. A missing record,
//     unusable network metadata, a network that is still in use, or unusable
//     local artifacts are reported through localSnapshotRestoreUnavailable;
//     an ID mismatch and other store errors are returned as plain errors.
//   - otherwise req.Snapshot is required: the network comes from the durable
//     spec, falling back to the host's stored snapshot record when the spec
//     cannot provide one, and the artifacts are downloaded into a per-machine
//     staging directory. After a successful download the returned cleanup
//     removes that directory.
//
// On error the artifact map is nil, the network allocation is the zero value,
// and the cleanup function is a no-op.
func (d *Daemon) prepareRestoreArtifacts(ctx context.Context, snapshotID contracthost.SnapshotID, req contracthost.RestoreSnapshotRequest, usedNetworks []firecracker.NetworkAllocation) (map[string]restoredSnapshotArtifact, firecracker.NetworkAllocation, func(), error) {
// Local-first path: restore from artifacts that are already on this host.
if req.LocalSnapshot != nil {
// A payload snapshot ID, when given, must agree with the ID being restored.
if req.LocalSnapshot.SnapshotID != "" && req.LocalSnapshot.SnapshotID != snapshotID {
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("local snapshot id mismatch: path=%q payload=%q", snapshotID, req.LocalSnapshot.SnapshotID)
}
snapshot, err := d.store.GetSnapshot(ctx, snapshotID)
if err != nil {
// Only "not found" is reported as local-restore-unavailable; any other
// store failure is returned unchanged.
if err == store.ErrNotFound {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, "snapshot is not present on this host")
}
return nil, firecracker.NetworkAllocation{}, func() {}, err
}
restoreNetwork, err := restoreNetworkFromSnapshot(snapshot)
if err != nil {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error())
}
// The snapshot's network must not collide with one already in use on this host.
if networkAllocationInUse(restoreNetwork, usedNetworks) {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, fmt.Sprintf("restore network is still in use on this host (runtime_host=%s tap_device=%s)", restoreNetwork.GuestIP(), restoreNetwork.TapName))
}
artifacts, err := localSnapshotArtifacts(snapshot)
if err != nil {
return nil, firecracker.NetworkAllocation{}, func() {}, localSnapshotRestoreUnavailable(snapshotID, err.Error())
}
// Artifacts are used from their recorded local paths, so there is nothing to clean up.
return artifacts, restoreNetwork, func() {}, nil
}
snapshot, err := d.store.GetSnapshot(ctx, snapshotID)
if err == nil {
return restoreNetworkFromSnapshot(snapshot)
// Durable path: the request must carry a durable snapshot spec.
if req.Snapshot == nil {
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("durable snapshot spec is required")
}
if err != store.ErrNotFound {
return firecracker.NetworkAllocation{}, err
restoreNetwork, err := restoreNetworkFromDurableSpec(*req.Snapshot)
if err != nil {
// The spec could not supply a network; fall back to the host's stored record.
snapshot, lookupErr := d.store.GetSnapshot(ctx, snapshotID)
if lookupErr == nil {
restoreNetwork, err = restoreNetworkFromSnapshot(snapshot)
} else if lookupErr != store.ErrNotFound {
return nil, firecracker.NetworkAllocation{}, func() {}, lookupErr
}
// err is still the durable-spec error when no record exists, or the
// error from the stored-record fallback when that failed too.
if err != nil {
return nil, firecracker.NetworkAllocation{}, func() {}, err
}
}
return firecracker.NetworkAllocation{}, fmt.Errorf("snapshot %q is missing restore network metadata", snapshotID)
if networkAllocationInUse(restoreNetwork, usedNetworks) {
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("restore network for snapshot %q is still in use on this host (runtime_host=%s tap_device=%s)", snapshotID, restoreNetwork.GuestIP(), restoreNetwork.TapName)
}
// Stage downloads under <SnapshotsDir>/<snapshotID>/restores/<machineID>.
stagingDir := filepath.Join(d.config.SnapshotsDir, string(snapshotID), "restores", string(req.MachineID))
artifacts, err := downloadDurableSnapshotArtifacts(ctx, stagingDir, req.Snapshot.Artifacts)
if err != nil {
// Best-effort removal of whatever was staged before the failure.
_ = os.RemoveAll(stagingDir)
return nil, firecracker.NetworkAllocation{}, func() {}, fmt.Errorf("download durable snapshot artifacts: %w", err)
}
// The caller owns the staged files until it invokes the returned cleanup.
return artifacts, restoreNetwork, func() {
_ = os.RemoveAll(stagingDir)
}, nil
}
func restoreNetworkFromDurableSpec(spec contracthost.DurableSnapshotSpec) (firecracker.NetworkAllocation, error) {
@ -555,6 +604,32 @@ func restoreNetworkFromSnapshot(snap *model.SnapshotRecord) (firecracker.Network
return network, nil
}
// localSnapshotArtifacts converts the artifacts recorded on a stored snapshot
// into the restore artifact map, keyed by artifact name.
//
// Every artifact must have a non-blank local path that os.Stat can resolve;
// the first artifact that fails either check aborts the conversion with an
// error naming the snapshot and artifact. A nil snapshot is rejected.
func localSnapshotArtifacts(snapshot *model.SnapshotRecord) (map[string]restoredSnapshotArtifact, error) {
	if snapshot == nil {
		return nil, fmt.Errorf("snapshot is required")
	}

	byName := make(map[string]restoredSnapshotArtifact, len(snapshot.Artifacts))
	for _, record := range snapshot.Artifacts {
		path := record.LocalPath
		if strings.TrimSpace(path) == "" {
			return nil, fmt.Errorf("snapshot %q artifact %q is missing a local path", snapshot.ID, record.ID)
		}

		// Confirm the file is still reachable before promising it to the restore.
		_, statErr := os.Stat(path)
		if statErr != nil {
			return nil, fmt.Errorf("snapshot %q artifact %q is unavailable at %q: %w", snapshot.ID, record.ID, path, statErr)
		}

		described := contracthost.SnapshotArtifact{
			ID:        record.ID,
			Kind:      record.Kind,
			Name:      record.Name,
			SizeBytes: record.SizeBytes,
			SHA256Hex: record.SHA256Hex,
		}
		byName[record.Name] = restoredSnapshotArtifact{
			Artifact:  described,
			LocalPath: path,
		}
	}

	return byName, nil
}
func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecracker.NetworkAllocation) bool {
targetTap := strings.TrimSpace(target.TapName)
for _, network := range used {
@ -568,6 +643,14 @@ func networkAllocationInUse(target firecracker.NetworkAllocation, used []firecra
return false
}
// localSnapshotRestoreUnavailable builds the error returned when a snapshot
// cannot be restored from this host's local copy. The text always begins with
// localSnapshotRestoreUnavailablePrefix, followed by the quoted snapshot ID
// and the reason. A blank reason is replaced with a generic one.
func localSnapshotRestoreUnavailable(snapshotID contracthost.SnapshotID, message string) error {
	reason := strings.TrimSpace(message)
	if len(reason) == 0 {
		reason = "local restore is unavailable"
	}
	return fmt.Errorf("%s: snapshot %q %s", localSnapshotRestoreUnavailablePrefix, snapshotID, reason)
}
// moveFile copies src to dst then removes src. Works across filesystem boundaries
// unlike os.Rename, which is needed when moving files out of /proc/<pid>/root/.
func moveFile(src, dst string) error {