122 lines
3.1 KiB
Python
122 lines
3.1 KiB
Python
"""
|
|
Rate limiting helpers backed by Redis, with an in-process fallback.
|
|
"""
|
|
|
|
from __future__ import annotations

import logging
import math
import threading
import time
from collections import defaultdict, deque

from fastapi import HTTPException, Request, status
from redis.asyncio import Redis

from app.core.config import get_settings
|
|
|
|
logger = logging.getLogger(__name__)

# Held around every read/write of ``_hits`` below.
_lock = threading.Lock()

# In-memory limiter log: counter key -> timestamps of the requests counted for it.
_hits: dict[str, deque[float]] = defaultdict(deque)

# Lazily created Redis client; stays None while Redis is not configured or not yet used.
_redis_client: Redis | None = None

# Flipped to True after a Redis error so the limiter stops trying Redis and uses
# the in-memory log; cleared again only by reset_rate_limit_state().
_redis_unavailable: bool = False
|
|
|
|
|
|
def _client_ip(request: Request) -> str:
    """Return the peer host recorded on *request*.

    Falls back to the literal ``"unknown"`` when the request carries no
    client information or the host is empty.
    """
    peer = request.client
    host = peer.host if peer else None
    return host or "unknown"
|
|
|
|
|
|
def _get_redis_client() -> Redis | None:
    """Return the shared Redis client, creating it on first use.

    Returns None when Redis has been marked unavailable after an error, or
    when ``REDIS_URL`` is empty/unset in the settings. Otherwise the client is
    built once via ``Redis.from_url(..., decode_responses=True)`` and cached
    in the module-level ``_redis_client``.
    """
    global _redis_client

    if _redis_unavailable:
        return None
    if _redis_client is not None:
        return _redis_client

    redis_url = get_settings().REDIS_URL
    if not redis_url:
        return None

    _redis_client = Redis.from_url(redis_url, decode_responses=True)
    return _redis_client
|
|
|
|
|
|
def _enforce_in_memory_rate_limit(
    *,
    key: str,
    scope: str,
    max_requests: int,
    window_seconds: int,
) -> None:
    """Enforce a sliding-window limit using the process-local hit log.

    Each accepted request appends its timestamp to ``_hits[key]``. Before
    counting, timestamps at or before ``now - window_seconds`` are dropped.

    Args:
        key: Counter key identifying the limited (scope, client) pair.
        scope: Name of the limited operation; used in the error message.
        max_requests: Requests allowed within any ``window_seconds`` span.
        window_seconds: Length of the sliding window, in seconds.

    Raises:
        HTTPException: 429 Too Many Requests. The ``Retry-After`` header gives
            the whole seconds until the oldest counted request leaves the window.
    """
    # Monotonic clock: interval math must not be skewed if the system wall
    # clock is stepped (NTP correction, manual change).
    now = time.monotonic()
    cutoff = now - window_seconds

    with _lock:
        # NOTE(review): keys are never evicted here (only reset_rate_limit_state
        # clears the map), so memory grows with the number of distinct keys seen.
        dq = _hits[key]
        while dq and dq[0] <= cutoff:
            dq.popleft()
        if len(dq) >= max_requests:
            if dq:
                # The oldest counted hit is the next one to fall out of the window.
                retry_after = max(1, math.ceil(dq[0] + window_seconds - now))
            else:
                # Only reachable when max_requests <= 0; mirror the Redis fallback.
                retry_after = window_seconds
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail=f"Too many requests for {scope}. Please try again later.",
                headers={"Retry-After": str(retry_after)},
            )
        dq.append(now)
|
|
|
|
|
|
async def enforce_rate_limit(
    request: Request,
    *,
    scope: str,
    max_requests: int,
    window_seconds: int,
) -> None:
    """Reject the request with HTTP 429 once its client exceeds the limit.

    Requests are counted per ``"{scope}:{client_ip}"`` key. When a Redis client
    is available, a fixed-window counter in Redis is used (INCR, with an
    expiry set on the window's first hit). Otherwise, or after any Redis error,
    the in-memory limiter is used instead.

    Args:
        request: Incoming request; only its client host is read.
        scope: Name of the limited operation; part of the key and error message.
        max_requests: Requests allowed per window.
        window_seconds: Window length in seconds.

    Raises:
        HTTPException: 429 Too Many Requests when the limit is exceeded.
    """
    global _redis_unavailable

    key = f"{scope}:{_client_ip(request)}"

    redis = _get_redis_client()
    if redis is not None:
        try:
            current = await redis.incr(key)
            if current == 1:
                # First hit of a new window starts its expiry clock.
                await redis.expire(key, window_seconds)
            if current > max_requests:
                ttl = await redis.ttl(key)
                if ttl == -1:
                    # INCR and EXPIRE are separate commands, so a failure between
                    # them can leave the counter without a TTL; without this repair
                    # the client would stay blocked forever.
                    await redis.expire(key, window_seconds)
                retry_after = ttl if ttl and ttl > 0 else window_seconds
                raise HTTPException(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    detail=f"Too many requests for {scope}. Please try again later.",
                    headers={"Retry-After": str(retry_after)},
                )
            return
        except HTTPException:
            raise
        except Exception as exc:
            # Deliberate best effort: any Redis failure switches this process to
            # the in-memory limiter until reset_rate_limit_state() is called.
            _redis_unavailable = True
            logger.warning("Redis rate limiter unavailable; falling back to memory: %s", exc)

    _enforce_in_memory_rate_limit(
        key=key,
        scope=scope,
        max_requests=max_requests,
        window_seconds=window_seconds,
    )
|
|
|
|
|
|
async def close_rate_limit() -> None:
    """Close the cached Redis client, if any, and drop the module reference.

    The reference is cleared even when ``aclose()`` raises; the exception
    still propagates to the caller.
    """
    global _redis_client

    client = _redis_client
    if client is not None:
        try:
            await client.aclose()
        finally:
            _redis_client = None
|
|
|
|
|
|
def reset_rate_limit_state() -> None:
    """Reset local limiter state for tests.

    Empties the in-memory hit log, clears the "Redis unavailable" flag, and
    forgets the cached Redis client without closing it.
    """
    global _redis_client, _redis_unavailable

    with _lock:
        _hits.clear()
    _redis_unavailable = False
    _redis_client = None
|