From 1e7829a974122c86f231da72a1b47a97cc8c9149 Mon Sep 17 00:00:00 2001 From: Harivansh Rathi Date: Fri, 10 Apr 2026 22:48:45 +0000 Subject: [PATCH] feat: snapshot upload async allows restore --- internal/daemon/daemon_test.go | 49 ++++++++++++++++++++++++++++++++++ internal/daemon/snapshot.go | 32 ++++++++++++++++++++++ internal/httpapi/handlers.go | 27 +++++++++++++++++++ 3 files changed, 108 insertions(+) diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 8e9446b..dde8bf6 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -641,6 +641,55 @@ func TestRestoreSnapshotUsesLocalSnapshotArtifacts(t *testing.T) { } } +func TestGetSnapshotArtifactReturnsLocalArtifactPath(t *testing.T) { + root := t.TempDir() + cfg := testConfig(root) + fileStore, err := store.NewFileStore(cfg.StatePath, cfg.OperationsPath) + if err != nil { + t.Fatalf("create file store: %v", err) + } + + hostDaemon, err := New(cfg, fileStore, &fakeRuntime{}) + if err != nil { + t.Fatalf("create daemon: %v", err) + } + + snapshotDir := filepath.Join(root, "snapshots", "snap-artifact") + if err := os.MkdirAll(snapshotDir, 0o755); err != nil { + t.Fatalf("create snapshot dir: %v", err) + } + memoryPath := filepath.Join(snapshotDir, "memory.bin") + if err := os.WriteFile(memoryPath, []byte("mem"), 0o644); err != nil { + t.Fatalf("write memory: %v", err) + } + if err := fileStore.CreateSnapshot(context.Background(), model.SnapshotRecord{ + ID: "snap-artifact", + MachineID: "source", + MemFilePath: memoryPath, + StateFilePath: filepath.Join(snapshotDir, "vmstate.bin"), + Artifacts: []model.SnapshotArtifactRecord{ + {ID: "memory", Kind: contracthost.SnapshotArtifactKindMemory, Name: "memory.bin", LocalPath: memoryPath, SizeBytes: 3}, + }, + CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create snapshot: %v", err) + } + + artifact, err := hostDaemon.GetSnapshotArtifact(context.Background(), "snap-artifact", "memory") + if err != nil { + t.Fatalf("GetSnapshotArtifact returned error: %v", err) + } + if artifact == nil { + t.Fatalf("GetSnapshotArtifact returned nil artifact") + } + if artifact.Name != "memory.bin" { + t.Fatalf("artifact name = %q, want memory.bin", artifact.Name) + } + if artifact.Path != memoryPath { + t.Fatalf("artifact path = %q, want %q", artifact.Path, memoryPath) + } +} + func TestRestoreSnapshotUsesDurableSnapshotSpec(t *testing.T) { root := t.TempDir() cfg := testConfig(root) diff --git a/internal/daemon/snapshot.go b/internal/daemon/snapshot.go index 0255d0e..ae72130 100644 --- a/internal/daemon/snapshot.go +++ b/internal/daemon/snapshot.go @@ -14,6 +14,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/getcompanion-ai/computer-host/internal/firecracker" + "github.com/getcompanion-ai/computer-host/internal/httpapi" "github.com/getcompanion-ai/computer-host/internal/model" "github.com/getcompanion-ai/computer-host/internal/store" contracthost "github.com/getcompanion-ai/computer-host/contract" @@ -441,6 +442,37 @@ func (d *Daemon) GetSnapshot(ctx context.Context, snapshotID contracthost.Snapsh return &contracthost.GetSnapshotResponse{Snapshot: snapshotToContract(*snap)}, nil } +func (d *Daemon) GetSnapshotArtifact(ctx context.Context, snapshotID contracthost.SnapshotID, artifactID string) (*httpapi.SnapshotArtifactContent, error) { + snapshot, err := d.store.GetSnapshot(ctx, snapshotID) + if err != nil { + return nil, err + } + artifactID = strings.TrimSpace(artifactID) + if artifactID == "" { + return nil, fmt.Errorf("snapshot artifact id is required") + } + for _, artifact := range snapshot.Artifacts { + if artifact.ID != artifactID { + continue + } + path := strings.TrimSpace(artifact.LocalPath) + if path == "" { + return nil, fmt.Errorf("snapshot %q artifact %q not found", snapshotID, artifactID) + } + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("snapshot %q artifact %q not found", snapshotID, artifactID) + } + return nil, fmt.Errorf("stat snapshot %q artifact %q: %w", snapshotID, artifactID, err) + } + return &httpapi.SnapshotArtifactContent{ + Name: artifact.Name, + Path: path, + }, nil + } + return nil, fmt.Errorf("snapshot %q artifact %q not found", snapshotID, artifactID) +} + func (d *Daemon) ListSnapshots(ctx context.Context, machineID contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error) { records, err := d.store.ListSnapshotsByMachine(ctx, machineID) if err != nil { diff --git a/internal/httpapi/handlers.go b/internal/httpapi/handlers.go index 8ac3498..de4a308 100644 --- a/internal/httpapi/handlers.go +++ b/internal/httpapi/handlers.go @@ -23,6 +23,7 @@ type Service interface { UploadSnapshot(context.Context, contracthost.SnapshotID, contracthost.UploadSnapshotRequest) (*contracthost.UploadSnapshotResponse, error) ListSnapshots(context.Context, contracthost.MachineID) (*contracthost.ListSnapshotsResponse, error) GetSnapshot(context.Context, contracthost.SnapshotID) (*contracthost.GetSnapshotResponse, error) + GetSnapshotArtifact(context.Context, contracthost.SnapshotID, string) (*SnapshotArtifactContent, error) DeleteSnapshotByID(context.Context, contracthost.SnapshotID) error RestoreSnapshot(context.Context, contracthost.SnapshotID, contracthost.RestoreSnapshotRequest) (*contracthost.RestoreSnapshotResponse, error) CreatePublishedPort(context.Context, contracthost.MachineID, contracthost.CreatePublishedPortRequest) (*contracthost.CreatePublishedPortResponse, error) @@ -30,6 +31,11 @@ type Service interface { DeletePublishedPort(context.Context, contracthost.MachineID, contracthost.PublishedPortID) error } +type SnapshotArtifactContent struct { + Name string + Path string +} + type Handler struct { service Service } @@ -298,6 +304,27 @@ func (h *Handler) handleSnapshot(w http.ResponseWriter, r *http.Request) { return } + if len(parts) == 3 && parts[1] == "artifacts" { + if r.Method != http.MethodGet { + writeMethodNotAllowed(w) + return + } + artifact, err := h.service.GetSnapshotArtifact(r.Context(), snapshotID, parts[2]) + if err != nil { + writeError(w, statusForError(err), err) + return + } + if artifact == nil { + writeError(w, http.StatusNotFound, fmt.Errorf("snapshot artifact %q not found", parts[2])) + return + } + if artifact.Name != "" { + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", artifact.Name)) + } + http.ServeFile(w, r, artifact.Path) + return + } + writeError(w, http.StatusNotFound, fmt.Errorf("route not found")) }