sep/app.py
2026-01-23 15:06:41 -05:00

241 lines
7.1 KiB
Python

"""
Audio Separator API
Simple FastAPI service for stem separation using audio-separator
"""
import os
import uuid
import shutil
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, UploadFile, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
# ASGI application object; the title, description and version below are the
# metadata passed to FastAPI for this service.
app = FastAPI(
    title="Audio Separator API",
    description="Separate audio into vocal and instrumental stems using ML models",
    version="1.0.0",
)
# Configuration: working directories for uploads, separated stems and model
# files. All three live under one scratch root and are created at import time
# so the request handlers can assume they exist.
_STORAGE_ROOT = Path("/tmp/audio-separator")
UPLOAD_DIR = _STORAGE_ROOT / "uploads"
OUTPUT_DIR = _STORAGE_ROOT / "outputs"
MODEL_DIR = _STORAGE_ROOT / "models"
for _directory in (UPLOAD_DIR, OUTPUT_DIR, MODEL_DIR):
    _directory.mkdir(parents=True, exist_ok=True)
# The separator is created on first use rather than at import time, to avoid
# import issues when CUDA is not available.
_separator = None


def get_separator():
    """Return the shared Separator instance, creating it on the first call.

    The instance is cached in the module-level ``_separator`` variable, so
    every later call returns the same object. On creation, CUDA is enabled
    when ``torch.cuda.is_available()`` reports it; any error while probing
    (including torch failing to import) is printed and the flag keeps
    whatever value it had at that point.
    """
    global _separator
    if _separator is not None:
        return _separator

    from audio_separator.separator import Separator

    # Probe for a GPU; the outcome is reported on stdout either way.
    cuda_enabled = False
    try:
        import torch

        cuda_enabled = torch.cuda.is_available()
        print(
            f"CUDA available: {torch.cuda.get_device_name(0)}"
            if cuda_enabled
            else "CUDA not available, using CPU"
        )
    except Exception as exc:
        print(f"Error checking CUDA: {exc}")

    _separator = Separator(
        output_dir=str(OUTPUT_DIR),
        model_file_dir=str(MODEL_DIR),
        use_cuda=cuda_enabled,
        output_format="mp3",
    )
    return _separator
class SeparationRequest(BaseModel):
    # Options for a separation job. Documented with comments rather than a
    # docstring so the generated schema stays unchanged.
    # NOTE(review): not referenced by any endpoint in the visible part of
    # this file; /separate takes the same options as plain parameters.

    # Audio format for the produced stems.
    output_format: Optional[str] = "mp3"
    # Separation model to load; None leaves the choice to the separator.
    model_name: Optional[str] = None
class SeparationResponse(BaseModel):
    # Response body of POST /separate. Documented with comments rather than
    # a docstring so the generated schema stays unchanged.

    # Short identifier generated for the request.
    job_id: str
    # Outcome of the job; the handler in this file sets "completed".
    status: str
    # Download path of the vocal stem, or None when none was identified.
    vocals_url: Optional[str] = None
    # Download path of the instrumental stem, or None when none was identified.
    instrumental_url: Optional[str] = None
    # Optional free-text note.
    message: Optional[str] = None
class HealthResponse(BaseModel):
    # Response body of GET /health. Documented with comments rather than a
    # docstring so the generated schema stays unchanged.

    # Service status string; the handler in this file reports "healthy".
    status: str
    # Whether torch reported a usable CUDA device.
    cuda_available: bool
    # Name of CUDA device 0 when available, otherwise None.
    cuda_device: Optional[str] = None
def cleanup_files(file_paths: list[str], delay_seconds: int = 300):
    """Delete temporary files after waiting ``delay_seconds``.

    Intended to run as a background task so that generated files stay
    downloadable for a while before being removed.

    Args:
        file_paths: Paths of the files to delete. Paths that no longer exist
            are skipped silently.
        delay_seconds: Seconds to wait before deleting. Zero or a negative
            value deletes immediately.

    The function never raises for a failed deletion: the error is printed and
    the remaining paths are still processed.
    """
    import time

    # NOTE: time.sleep blocks the calling thread for the whole delay.
    # Guarding on > 0 keeps a negative delay from raising ValueError before
    # any cleanup has happened.
    if delay_seconds > 0:
        time.sleep(delay_seconds)

    for path in file_paths:
        try:
            # EAFP: removing directly avoids the check-then-delete race of an
            # os.path.exists() test followed by os.remove().
            os.remove(path)
        except FileNotFoundError:
            # Already gone (deleted earlier or never created): nothing to do.
            continue
        except Exception as e:
            # Best-effort cleanup: report and carry on with the other paths.
            print(f"Error cleaning up {path}: {e}")
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Check API health and CUDA availability."""
    # Start from the "no GPU" answer and upgrade it only if torch imports and
    # reports a device. Any failure while probing is deliberately ignored so
    # the endpoint still answers "healthy" when the probe itself breaks.
    gpu_present = False
    gpu_name = None
    try:
        import torch

        gpu_present = torch.cuda.is_available()
        gpu_name = torch.cuda.get_device_name(0) if gpu_present else None
    except Exception:
        pass

    return HealthResponse(
        status="healthy",
        cuda_available=gpu_present,
        cuda_device=gpu_name,
    )
def _classify_stem(output_name: str, input_stem: str) -> Optional[str]:
    """Classify a separator output file as "vocals", "instrumental" or None.

    The decision is based on the parenthesised tags in the file name after
    the input file's own base name has been stripped from the front, so that
    neither the uploaded name nor a model name such as ``Kim_Vocal_2`` is
    mistaken for the stem tag. When the name has no parenthesised tag at all,
    the whole remainder is searched instead.
    """
    import re

    name = Path(output_name).name.lower()
    prefix = input_stem.lower()
    if prefix and name.startswith(prefix):
        name = name[len(prefix):]

    # presumably outputs look like "<input>_(Vocals)_<model>.<ext>" — verify
    # against the installed audio-separator version.
    candidates = re.findall(r"\(([^()]*)\)", name) or [name]
    for text in candidates:
        if "vocal" in text:
            return "vocals"
        if "instrum" in text:
            return "instrumental"
    return None


def _remove_quietly(path: Path) -> None:
    """Best-effort delete for error paths: report failures, never raise."""
    try:
        path.unlink(missing_ok=True)
    except OSError as exc:
        print(f"Error cleaning up {path}: {exc}")


@app.post("/separate", response_model=SeparationResponse)
async def separate_audio(
    file: UploadFile,
    background_tasks: BackgroundTasks,
    output_format: str = "mp3",
    model_name: Optional[str] = None,
):
    """
    Separate audio file into vocal and instrumental stems.

    - **file**: Audio file (mp3, wav, flac, m4a, etc.)
    - **output_format**: Output format (mp3, wav, flac) - default: mp3
    - **model_name**: Model to use (optional, uses default if not specified)

    Returns URLs to download the separated stems.
    \f
    Maintainer notes (everything after the form feed is kept out of the
    published API description).

    Args:
        file: Uploaded audio; only its final path component is used on disk.
        background_tasks: Used to schedule deletion of the upload and of all
            produced files 300 seconds after the response.
        output_format: Assigned to the shared separator before each run.
        model_name: Passed to ``load_model``; falsy loads the default model.

    Returns:
        SeparationResponse with status "completed" and a ``/download/<name>``
        URL for each stem that could be identified (None otherwise).

    Raises:
        HTTPException: 400 for a missing filename or unsupported extension;
            500 if the upload cannot be saved, the separator raises, or
            fewer than two output files are produced.
    """
    job_id = uuid.uuid4().hex[:8]

    # Validate file
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    # The filename is client-controlled: keep only its last path component so
    # it cannot point outside UPLOAD_DIR or into a non-existent subdirectory.
    safe_name = Path(file.filename.replace("\\", "/")).name

    allowed_extensions = {".mp3", ".wav", ".flac", ".m4a", ".ogg", ".wma", ".aac"}
    file_ext = Path(safe_name).suffix.lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {file_ext}. Allowed: {allowed_extensions}"
        )

    # Save uploaded file
    input_path = UPLOAD_DIR / f"{job_id}_{safe_name}"
    try:
        with open(input_path, "wb") as destination:
            shutil.copyfileobj(file.file, destination)
    except Exception as e:
        # Do not leave a partially written upload behind.
        _remove_quietly(input_path)
        raise HTTPException(status_code=500, detail=f"Failed to save file: {e}") from e

    # Run separation.
    # NOTE(review): these calls are synchronous and run inside an async
    # handler. The separator is a shared object mutated per request
    # (output_format, loaded model), so moving this work to a thread would
    # need a lock around it.
    try:
        separator = get_separator()
        separator.output_format = output_format
        if model_name:
            separator.load_model(model_name)
        else:
            separator.load_model()
        output_files = separator.separate(str(input_path))
    except Exception as e:
        _remove_quietly(input_path)
        raise HTTPException(status_code=500, detail=f"Separation failed: {e}") from e

    # This check sits outside the try block so its HTTPException is not
    # caught and re-wrapped by the generic handler above.
    if not output_files or len(output_files) < 2:
        _remove_quietly(input_path)
        for leftover in output_files or []:
            _remove_quietly(OUTPUT_DIR / Path(leftover).name)
        raise HTTPException(status_code=500, detail="Separation failed - no output files")

    # Find vocals and instrumental files
    vocals_name = None
    instrumental_name = None
    for produced in output_files:
        stem_kind = _classify_stem(str(produced), input_path.stem)
        if stem_kind == "vocals":
            vocals_name = Path(produced).name
        elif stem_kind == "instrumental":
            instrumental_name = Path(produced).name

    # Schedule cleanup of the upload and of every produced file (not only the
    # two recognised stems) after 5 minutes. Output paths are rebuilt under
    # OUTPUT_DIR, which works for full paths and bare names alike.
    # NOTE(review): separate() appears to return bare file names relative to
    # the output directory — confirm for the installed version.
    files_to_cleanup = [str(input_path)]
    files_to_cleanup.extend(
        str(OUTPUT_DIR / Path(produced).name) for produced in output_files
    )
    background_tasks.add_task(cleanup_files, files_to_cleanup, 300)

    return SeparationResponse(
        job_id=job_id,
        status="completed",
        vocals_url=f"/download/{vocals_name}" if vocals_name else None,
        instrumental_url=f"/download/{instrumental_name}" if instrumental_name else None,
    )
@app.get("/download/{filename}")
async def download_file(filename: str):
    """Download a separated stem file.
    \f
    Maintainer notes (everything after the form feed is kept out of the
    published API description).

    Args:
        filename: Name of a file directly inside OUTPUT_DIR.

    Returns:
        FileResponse streaming the file. The media type is chosen from the
        extension (mp3, wav, flac), else ``application/octet-stream``.

    Raises:
        HTTPException: 404 when the name is not a plain file name or does not
            refer to an existing regular file.
    """
    # Only plain file names are served: path separators and the special
    # entries "." / ".." must never be joined onto OUTPUT_DIR.
    if filename in {"", ".", ".."} or "/" in filename or "\\" in filename:
        raise HTTPException(status_code=404, detail="File not found")

    file_path = OUTPUT_DIR / filename
    # is_file() rather than exists(): a directory must yield 404 instead of
    # being handed to FileResponse.
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    # Determine media type
    media_types = {
        ".mp3": "audio/mpeg",
        ".wav": "audio/wav",
        ".flac": "audio/flac",
    }
    media_type = media_types.get(file_path.suffix.lower(), "application/octet-stream")

    return FileResponse(
        path=str(file_path),
        filename=filename,
        media_type=media_type,
    )
@app.get("/models")
async def list_models():
    """List available separation models."""
    # Static catalog as (display name, model id, description). The id of the
    # default entry is None.
    catalog = (
        ("BS-RoFormer (default)", None, "Best quality, slower"),
        ("UVR_MDXNET_KARA_2", "UVR_MDXNET_KARA_2", "Fast, good for karaoke"),
        ("UVR-MDX-NET-Inst_HQ_3", "UVR-MDX-NET-Inst_HQ_3", "High quality instrumentals"),
        ("Kim_Vocal_2", "Kim_Vocal_2", "Good vocal isolation"),
    )
    return {
        "models": [
            {"name": label, "id": model_id, "description": summary}
            for label, model_id, summary in catalog
        ]
    }
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    listen_host, listen_port = "0.0.0.0", 8000
    uvicorn.run(app, host=listen_host, port=listen_port)