mirror of
https://github.com/harivansh-afk/url-shortner.git
synced 2026-04-15 05:02:12 +00:00
nginx load balancer and api
This commit is contained in:
parent
1a980a7a70
commit
3469c7c83e
5 changed files with 502 additions and 0 deletions
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
52
app/encoding.py
Normal file
52
app/encoding.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
"""
|
||||
Base62 encoding for short URL generation.
|
||||
|
||||
- Uses 0-9, A-Z, a-z (62 characters)
|
||||
- URL-safe (no special characters)
|
||||
- More compact than hex (base16) or base64
|
||||
|
||||
Length vs Capacity:
|
||||
- 6 chars: 62^6 = 56.8 billion unique URLs
|
||||
- 7 chars: 62^7 = 3.5 trillion unique URLs
|
||||
"""
|
||||
|
||||
CHARSET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
BASE = len(CHARSET) # 62
|
||||
|
||||
|
||||
def base62_encode(num: int) -> str:
|
||||
"""
|
||||
Encode an integer to a base62 string.
|
||||
"""
|
||||
if num < 0:
|
||||
raise ValueError("Cannot encode negative numbers")
|
||||
if num == 0:
|
||||
return CHARSET[0]
|
||||
|
||||
result = []
|
||||
while num:
|
||||
result.append(CHARSET[num % BASE])
|
||||
num //= BASE
|
||||
|
||||
return "".join(reversed(result))
|
||||
|
||||
|
||||
def base62_decode(encoded: str) -> int:
|
||||
"""
|
||||
Decode a base62 string back to an integer.
|
||||
"""
|
||||
if not encoded:
|
||||
raise ValueError("Cannot decode empty string")
|
||||
|
||||
num = 0
|
||||
for char in encoded:
|
||||
if char not in CHARSET:
|
||||
raise ValueError(f"Invalid character: {char}")
|
||||
num = num * BASE + CHARSET.index(char)
|
||||
|
||||
return num
|
||||
|
||||
|
||||
def pad_to_length(encoded: str, length: int = 7) -> str:
|
||||
"""Pad encoded string to minimum length with leading zeros."""
|
||||
return encoded.zfill(length)[-length:] if len(encoded) < length else encoded
|
||||
253
app/main.py
Normal file
253
app/main.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
"""
|
||||
URL Shortener API
|
||||
|
||||
Endpoints:
|
||||
- POST /shorten - Create a short URL
|
||||
- GET /{code} - Redirect to original URL
|
||||
- GET /stats/{code} - Get click statistics
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import asyncpg
|
||||
import redis.asyncio as redis
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
from app.encoding import base62_encode
|
||||
from app.snowflake import init_generator, generate_id
|
||||
|
||||
|
||||
# Configuration from environment (defaults are for local development only).
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://urlshortner:localdev@localhost:5432/urlshortner")
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
# Snowflake machine/worker ID for this instance; must be unique per API
# replica (0-1023) or generated IDs can collide.
MACHINE_ID = int(os.getenv("MACHINE_ID", "1"))
# Public base used when building the short URLs returned to clients.
BASE_URL = os.getenv("BASE_URL", "http://localhost")

# Cache TTL in seconds (1 hour) for short_code -> original_url Redis entries.
CACHE_TTL = 3600

# Global connections; created and torn down by the lifespan handler.
db_pool: asyncpg.Pool | None = None
redis_client: redis.Redis | None = None
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage database and Redis connections.

    FastAPI lifespan hook: code before `yield` runs once at startup,
    code after it runs once at shutdown.
    """
    global db_pool, redis_client

    # Startup: seed the Snowflake ID generator, then open shared clients.
    init_generator(MACHINE_ID)
    db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=5, max_size=20)
    redis_client = redis.from_url(REDIS_URL, decode_responses=True)

    print(f"[Startup] Connected to PostgreSQL and Redis. Machine ID: {MACHINE_ID}")

    yield

    # Shutdown: release pooled connections.
    if db_pool:
        await db_pool.close()
    if redis_client:
        # NOTE(review): redis-py 5.x deprecates Redis.close() in favor of
        # aclose() — confirm against the pinned client version.
        await redis_client.close()

    print("[Shutdown] Connections closed.")
|
||||
|
||||
|
||||
# ASGI application; `lifespan` wires DB/Redis setup and teardown around it.
app = FastAPI(
    title="URL Shortener",
    description="Distributed URL shortening service",
    version="1.0.0",
    lifespan=lifespan,
)
|
||||
|
||||
|
||||
# Request/Response models
|
||||
class ShortenRequest(BaseModel):
    """Request body for POST /shorten."""

    url: HttpUrl  # target URL; pydantic validates scheme and host
    custom_code: str | None = None  # Optional custom short code (vanity URL)
|
||||
|
||||
|
||||
class ShortenResponse(BaseModel):
    """Response body for POST /shorten."""

    short_url: str  # full shortened URL: f"{BASE_URL}/{short_code}"
    short_code: str  # the base62 or custom code on its own
    original_url: str  # the URL that was shortened
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
    """Response body for GET /stats/{code}."""

    short_code: str
    original_url: str
    click_count: int  # total redirects recorded for this code
    created_at: str  # ISO-8601 string from the urls table timestamp
|
||||
|
||||
|
||||
# Endpoints
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check for load balancer."""
|
||||
return {"status": "healthy", "machine_id": MACHINE_ID}
|
||||
|
||||
|
||||
@app.post("/shorten", response_model=ShortenResponse)
|
||||
async def shorten_url(request: ShortenRequest, req: Request):
|
||||
"""
|
||||
Create a shortened URL.
|
||||
|
||||
Process:
|
||||
1. Generate unique ID using Snowflake
|
||||
2. Encode as base62 for short code
|
||||
3. Store in PostgreSQL
|
||||
4. Cache in Redis
|
||||
"""
|
||||
original_url = str(request.url)
|
||||
|
||||
# Validate URL has a valid domain
|
||||
parsed = urlparse(original_url)
|
||||
if not parsed.netloc:
|
||||
raise HTTPException(status_code=400, detail="Invalid URL")
|
||||
|
||||
# Generate short code
|
||||
if request.custom_code:
|
||||
short_code = request.custom_code
|
||||
# Check if custom code already exists
|
||||
existing = await redis_client.get(f"url:{short_code}")
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="Custom code already in use")
|
||||
else:
|
||||
# Generate using Snowflake + base62
|
||||
snowflake_id = generate_id()
|
||||
short_code = base62_encode(snowflake_id)
|
||||
|
||||
# Get client info
|
||||
client_ip = req.headers.get("X-Real-IP", req.client.host if req.client else "unknown")
|
||||
user_agent = req.headers.get("User-Agent", "")
|
||||
|
||||
# Store in database
|
||||
try:
|
||||
await db_pool.execute(
|
||||
"""
|
||||
INSERT INTO urls (short_code, original_url, ip_address, user_agent)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
""",
|
||||
short_code,
|
||||
original_url,
|
||||
client_ip,
|
||||
user_agent,
|
||||
)
|
||||
except asyncpg.UniqueViolationError:
|
||||
raise HTTPException(status_code=409, detail="Short code collision. Please retry.")
|
||||
|
||||
# Cache in Redis
|
||||
await redis_client.setex(f"url:{short_code}", CACHE_TTL, original_url)
|
||||
|
||||
return ShortenResponse(
|
||||
short_url=f"{BASE_URL}/{short_code}",
|
||||
short_code=short_code,
|
||||
original_url=original_url,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/{short_code}")
|
||||
async def redirect_to_url(short_code: str, req: Request):
|
||||
"""
|
||||
Redirect to the original URL.
|
||||
|
||||
Process:
|
||||
1. Check Redis cache first (fast path)
|
||||
2. If cache miss, query PostgreSQL
|
||||
3. Update cache on miss
|
||||
4. Track click asynchronously (fire and forget)
|
||||
"""
|
||||
# Try cache first
|
||||
original_url = await redis_client.get(f"url:{short_code}")
|
||||
|
||||
if not original_url:
|
||||
# Cache miss - query database
|
||||
row = await db_pool.fetchrow(
|
||||
"SELECT original_url FROM urls WHERE short_code = $1",
|
||||
short_code,
|
||||
)
|
||||
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Short URL not found")
|
||||
|
||||
original_url = row["original_url"]
|
||||
|
||||
# Populate cache
|
||||
await redis_client.setex(f"url:{short_code}", CACHE_TTL, original_url)
|
||||
|
||||
# Track click asynchronously (don't slow down redirect)
|
||||
asyncio.create_task(
|
||||
track_click(
|
||||
short_code,
|
||||
req.headers.get("X-Real-IP", req.client.host if req.client else None),
|
||||
req.headers.get("User-Agent"),
|
||||
req.headers.get("Referer"),
|
||||
)
|
||||
)
|
||||
|
||||
# 301 = permanent redirect (cacheable by browsers)
|
||||
# 302 = temporary redirect (not cached, better for analytics)
|
||||
return RedirectResponse(url=original_url, status_code=302)
|
||||
|
||||
|
||||
@app.get("/stats/{short_code}", response_model=StatsResponse)
|
||||
async def get_stats(short_code: str):
|
||||
"""Get statistics for a short URL."""
|
||||
row = await db_pool.fetchrow(
|
||||
"""
|
||||
SELECT short_code, original_url, click_count, created_at
|
||||
FROM urls WHERE short_code = $1
|
||||
""",
|
||||
short_code,
|
||||
)
|
||||
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Short URL not found")
|
||||
|
||||
return StatsResponse(
|
||||
short_code=row["short_code"],
|
||||
original_url=row["original_url"],
|
||||
click_count=row["click_count"],
|
||||
created_at=row["created_at"].isoformat(),
|
||||
)
|
||||
|
||||
|
||||
async def track_click(
    short_code: str,
    ip_address: str | None,
    user_agent: str | None,
    referer: str | None,
):
    """
    Track a click event asynchronously.

    This runs in the background after the redirect is sent,
    so it doesn't slow down the user experience.

    FIX: both writes now run on a single connection inside one transaction,
    so the aggregate click_count can no longer drift from the detailed
    clicks table when the second statement fails partway through.
    """
    try:
        async with db_pool.acquire() as conn:
            async with conn.transaction():
                # Increment click count
                await conn.execute(
                    "UPDATE urls SET click_count = click_count + 1 WHERE short_code = $1",
                    short_code,
                )

                # Store detailed click record
                await conn.execute(
                    """
                    INSERT INTO clicks (short_code, ip_address, user_agent, referer)
                    VALUES ($1, $2, $3, $4)
                    """,
                    short_code,
                    ip_address,
                    user_agent,
                    referer,
                )
    except Exception as e:
        # Log but don't fail - analytics shouldn't break redirects
        print(f"[Warning] Failed to track click: {e}")
|
||||
142
app/snowflake.py
Normal file
142
app/snowflake.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
"""
|
||||
Snowflake ID Generator for distributed unique ID generation.
|
||||
|
||||
Structure (64 bits total):
|
||||
- 1 bit: sign (always 0)
|
||||
- 41 bits: timestamp in milliseconds (69 years from epoch)
|
||||
- 10 bits: machine/worker ID (1024 unique machines)
|
||||
- 12 bits: sequence number (4096 IDs per millisecond per machine)
|
||||
|
||||
Benefits:
|
||||
- No coordination needed between machines
|
||||
- Time-sortable (IDs are roughly ordered by creation time)
|
||||
- Guaranteed unique across distributed system
|
||||
- High throughput: 4096 IDs/ms/machine = 4M IDs/second/machine
|
||||
|
||||
Used by: Twitter, Discord, Instagram (with variations)
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
|
||||
|
||||
class SnowflakeGenerator:
    """Thread-safe generator of 64-bit, roughly time-ordered unique IDs.

    Bit layout: 41-bit millisecond timestamp | 10-bit machine ID |
    12-bit per-millisecond sequence.
    """

    # Custom epoch: Jan 1, 2024 00:00:00 UTC (extends usable time range)
    EPOCH = 1704067200000  # milliseconds

    # Bit lengths
    TIMESTAMP_BITS = 41
    MACHINE_ID_BITS = 10
    SEQUENCE_BITS = 12

    # Max values
    MAX_MACHINE_ID = (1 << MACHINE_ID_BITS) - 1  # 1023
    MAX_SEQUENCE = (1 << SEQUENCE_BITS) - 1  # 4095

    # Bit shifts
    TIMESTAMP_SHIFT = MACHINE_ID_BITS + SEQUENCE_BITS  # 22
    MACHINE_ID_SHIFT = SEQUENCE_BITS  # 12

    def __init__(self, machine_id: int):
        """
        Initialize generator with a unique machine ID.

        Args:
            machine_id: Unique identifier for this machine/worker (0-1023).

        Raises:
            ValueError: If machine_id falls outside that range.
        """
        if machine_id < 0 or machine_id > self.MAX_MACHINE_ID:
            raise ValueError(f"machine_id must be between 0 and {self.MAX_MACHINE_ID}")

        self._lock = threading.Lock()
        self.machine_id = machine_id
        # State protected by the lock: last millisecond seen, and the
        # sequence counter within that millisecond.
        self.last_timestamp = -1
        self.sequence = 0

    def _current_timestamp(self) -> int:
        """Milliseconds elapsed since the custom EPOCH."""
        return int(time.time() * 1000) - self.EPOCH

    def _wait_next_millis(self, last_timestamp: int) -> int:
        """Spin (with 0.1 ms sleeps) until the clock moves past last_timestamp."""
        now = self._current_timestamp()
        while now <= last_timestamp:
            time.sleep(0.0001)  # 0.1ms
            now = self._current_timestamp()
        return now

    def generate(self) -> int:
        """
        Generate a unique Snowflake ID.

        Thread-safe: Can be called from multiple threads.

        Returns:
            64-bit unique ID.

        Raises:
            RuntimeError: If the system clock moved backwards.
        """
        with self._lock:
            now = self._current_timestamp()

            # A backwards clock would let us re-issue old timestamps and
            # therefore duplicate IDs, so refuse outright.
            if now < self.last_timestamp:
                raise RuntimeError(
                    f"Clock moved backwards. Refusing to generate ID. "
                    f"Last: {self.last_timestamp}, Current: {now}"
                )

            if now == self.last_timestamp:
                # Same millisecond as last call: bump the sequence counter.
                self.sequence = (self.sequence + 1) & self.MAX_SEQUENCE
                if self.sequence == 0:
                    # All 4096 sequence slots used up - spin to the next ms.
                    now = self._wait_next_millis(self.last_timestamp)
            else:
                # Fresh millisecond: sequence restarts at zero.
                self.sequence = 0

            self.last_timestamp = now

            # Pack timestamp | machine | sequence into one 64-bit integer.
            return (
                (now << self.TIMESTAMP_SHIFT)
                | (self.machine_id << self.MACHINE_ID_SHIFT)
                | self.sequence
            )

    def parse(self, snowflake_id: int) -> dict:
        """
        Parse a Snowflake ID into its components.

        Useful for debugging and understanding ID generation.
        """
        # Shift the custom-epoch offset back to a Unix-epoch millisecond value.
        unix_ms = (snowflake_id >> self.TIMESTAMP_SHIFT) + self.EPOCH
        return {
            "timestamp_ms": unix_ms,
            "timestamp_iso": time.strftime(
                "%Y-%m-%d %H:%M:%S", time.gmtime(unix_ms / 1000)
            ),
            "machine_id": (snowflake_id >> self.MACHINE_ID_SHIFT) & self.MAX_MACHINE_ID,
            "sequence": snowflake_id & self.MAX_SEQUENCE,
        }
|
||||
|
||||
|
||||
# Module-level generator (initialized in main.py)
_generator: SnowflakeGenerator | None = None


def init_generator(machine_id: int) -> None:
    """Initialize the global Snowflake generator."""
    global _generator
    _generator = SnowflakeGenerator(machine_id)


def generate_id() -> int:
    """Generate a unique Snowflake ID using the global generator."""
    gen = _generator
    if gen is None:
        raise RuntimeError("Snowflake generator not initialized. Call init_generator() first.")
    return gen.generate()
|
||||
55
nginx/nginx.conf
Normal file
55
nginx/nginx.conf
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
events {
    worker_connections 1024;
}

http {
    # Upstream API servers - Docker will resolve 'api' to all instances
    upstream api_servers {
        # Load balancing method: least_conn sends to least busy server
        # Other options: round_robin (default), ip_hash, random
        least_conn;

        # Docker Compose service discovery
        server api:8000;

        # When scaling manually, you'd list servers like:
        # server api_1:8000;
        # server api_2:8000;
        # server api_3:8000;
    }

    # Rate limiting zone: 10 requests per second per client IP, keyed on the
    # compact binary address; the 10m shared-memory zone bounds how many
    # distinct client states are tracked at once.
    limit_req_zone $binary_remote_addr zone=api_limit:10m rate=10r/s;

    server {
        listen 80;
        server_name localhost;

        # Health check endpoint, answered by nginx itself (never proxied),
        # so the balancer reports up even while backends restart.
        location /health {
            access_log off;
            return 200 "OK\n";
            add_header Content-Type text/plain;
        }

        # API endpoints
        location / {
            # Apply rate limiting with burst: up to 20 excess requests are
            # admitted immediately (nodelay) before clients get rejected.
            limit_req zone=api_limit burst=20 nodelay;

            proxy_pass http://api_servers;
            # NOTE(review): with HTTP/1.1 to the upstream, consider adding
            # 'proxy_set_header Connection ""' here and a 'keepalive'
            # directive in the upstream block to reuse backend connections
            # - confirm against the nginx upstream keepalive docs.
            proxy_http_version 1.1;

            # Pass client info to backend (the API reads X-Real-IP)
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;

            # Timeouts
            proxy_connect_timeout 5s;
            proxy_send_timeout 10s;
            proxy_read_timeout 10s;
        }
    }
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue