feat: snapshot upload async allows restore

This commit is contained in:
Harivansh Rathi 2026-04-10 22:48:45 +00:00
parent 30282928f5
commit 1e7829a974
3 changed files with 108 additions and 0 deletions

View file

@ -641,6 +641,55 @@ func TestRestoreSnapshotUsesLocalSnapshotArtifacts(t *testing.T) {
} }
} }
func TestGetSnapshotArtifactReturnsLocalArtifactPath(t *testing.T) {
root := t.TempDir()
cfg := testConfig(root)
fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath)
if err != nil {
t.Fatalf("create file store: %v", err)
}
hostDaemon, err := New(cfg, fileStore, &fakeRuntime{})
if err != nil {
t.Fatalf("create daemon: %v", err)
}
snapshotDir := filepath.Join(root, "snapshots", "snap-artifact")
if err := os.MkdirAll(snapshotDir, 0o755); err != nil {
t.Fatalf("create snapshot dir: %v", err)
}
memoryPath := filepath.Join(snapshotDir, "memory.bin")
if err := os.WriteFile(memoryPath, []byte("mem"), 0o644); err != nil {
t.Fatalf("write memory: %v", err)
}
if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{
ID: "snap-artifact",
MachineID: "source",
MemFilePath: memoryPath,
StateFilePath: filepath.Join(snapshotDir, "vmstate.bin"),
Artifacts: []model.SnapshotArtifactRecord{
{ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", LocalPath: memoryPath, SizeBytes: 3},
},
CreatedAt: time.Now().UTC(),
}); err != nil {
t.Fatalf("create snapshot: %v", err)
}
artifact, err := hostDaemon.GetSnapshotArtifact(context.Background(), "snap-artifact", "memory")
if err != nil {
t.Fatalf("GetSnapshotArtifact returned error: %v", err)
}
if artifact == nil {
t.Fatalf("GetSnapshotArtifact returned nil artifact")
}
if artifact.Name != "memory.bin" {
t.Fatalf("artifact name = %q, want memory.bin", artifact.Name)
}
if artifact.Path != memoryPath {
t.Fatalf("artifact path = %q, want %q", artifact.Path, memoryPath)
}
}
func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) {
root := t.TempDir() root := t.TempDir()
cfg := testConfig(root) cfg := testConfig(root)

View file

@ -14,6 +14,7 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/getcompanion-ai/computer-host/internal/firecracker" "github.com/getcompanion-ai/computer-host/internal/firecracker"
"github.com/getcompanion-ai/computer-host/internal/httpapi"
"github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/model"
"github.com/getcompanion-ai/computer-host/internal/store" "github.com/getcompanion-ai/computer-host/internal/store"
contracthost "github.com/getcompanion-ai/computer-host/contract" contracthost "github.com/getcompanion-ai/computer-host/contract"
@ -441,6 +442,37 @@ func (d *Daemon) GetSnapshot(ctx context.Context, snapshotID contracthost.Snapsh
return &contracthost.GetSnapshotResponse{Snapshot: snapshotToContract(*snap)}, nil return &contracthost.GetSnapshotResponse{Snapshot: snapshotToContract(*snap)}, nil
} }
func (d *Daemon) GetSnapshotArtifact(ctx context.Context, snapshotID contracthost.SnapshotID, artifactID string) (*httpapi.SnapshotArtifactContent, error) {
snapshot, err := d.store.GetSnapshot(ctx, snapshotID)
if err != nil {
return nil, err
}
artifactID = strings.TrimSpace(artifactID)
if artifactID == "" {
return nil, fmt.Errorf("snapshot artifact id is required")
}
for _, artifact := range snapshot.Artifacts {
if artifact.ID != artifactID {
continue
}
path := strings.TrimSpace(artifact.LocalPath)
if path == "" {
return nil, fmt.Errorf("snapshot %q artifact %q not found", snapshotID, artifactID)
}
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("snapshot %q artifact %q not found", snapshotID, artifactID)
}
return nil, fmt.Errorf("stat snapshot %q artifact %q: %w", snapshotID, artifactID, err)
}
return &httpapi.SnapshotArtifactContent{
Name: artifact.Name,
Path: path,
}, nil
}
return nil, fmt.Errorf("snapshot %q artifact %q not found", snapshotID, artifactID)
}
func (d *Daemon) ListSnapshots(ctx context.Context, machineID contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error) { func (d *Daemon) ListSnapshots(ctx context.Context, machineID contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error) {
records, err := d.store.ListSnapshotsByMachine(ctx, machineID) records, err := d.store.ListSnapshotsByMachine(ctx, machineID)
if err != nil { if err != nil {

View file

@ -23,6 +23,7 @@ type Service interface {
UploadSnapshot(context.Context, contracthost.SnapshotID, contracthost.UploadSnapshotRequest) (*contracthost.UploadSnapshotResponse, error) UploadSnapshot(context.Context, contracthost.SnapshotID, contracthost.UploadSnapshotRequest) (*contracthost.UploadSnapshotResponse, error)
ListSnapshots(context.Context, contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error) ListSnapshots(context.Context, contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error)
GetSnapshot(context.Context, contracthost.SnapshotID) (*contracthost.GetSnapshotResponse, error) GetSnapshot(context.Context, contracthost.SnapshotID) (*contracthost.GetSnapshotResponse, error)
GetSnapshotArtifact(context.Context, contracthost.SnapshotID, string) (*SnapshotArtifactContent, error)
DeleteSnapshotByID(context.Context, contracthost.SnapshotID) error DeleteSnapshotByID(context.Context, contracthost.SnapshotID) error
RestoreSnapshot(context.Context, contracthost.SnapshotID, contracthost.RestoreSnapshotRequest) (*contracthost.RestoreSnapshotResponse, error) RestoreSnapshot(context.Context, contracthost.SnapshotID, contracthost.RestoreSnapshotRequest) (*contracthost.RestoreSnapshotResponse, error)
CreatePublishedPort(context.Context, contracthost.MachineID, contracthost.CreatePublishedPortRequest) (*contracthost.CreatePublishedPortResponse, error) CreatePublishedPort(context.Context, contracthost.MachineID, contracthost.CreatePublishedPortRequest) (*contracthost.CreatePublishedPortResponse, error)
@ -30,6 +31,11 @@ type Service interface {
DeletePublishedPort(context.Context, contracthost.MachineID, contracthost.PublishedPortID) error DeletePublishedPort(context.Context, contracthost.MachineID, contracthost.PublishedPortID) error
} }
type SnapshotArtifactContent struct {
Name string
Path string
}
type Handler struct { type Handler struct {
service Service service Service
} }
@ -298,6 +304,27 @@ func (h *Handler) handleSnapshot(w http.ResponseWriter, r *http.Request) {
return return
} }
if len(parts) == 3 && parts[1] == "artifacts" {
if r.Method != http.MethodGet {
writeMethodNotAllowed(w)
return
}
artifact, err := h.service.GetSnapshotArtifact(r.Context(), snapshotID, parts[2])
if err != nil {
writeError(w, statusForError(err), err)
return
}
if artifact == nil {
writeError(w, http.StatusNotFound, fmt.Errorf("snapshot artifact %q not found", parts[2]))
return
}
if artifact.Name != "" {
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", artifact.Name))
}
http.ServeFile(w, r, artifact.Path)
return
}
writeError(w, http.StatusNotFound, fmt.Errorf("route not found")) writeError(w, http.StatusNotFound, fmt.Errorf("route not found"))
} }