183 lines
6.0 KiB
Python
183 lines
6.0 KiB
Python
"""
|
|
Rate limiting for API endpoints.
|
|
"""
|
|
|
|
import time
|
|
from typing import Dict, Optional
|
|
from collections import defaultdict, deque
|
|
from ..utils.logging import get_logger
|
|
|
|
# Module-level logger used by RateLimiter for its initialization, rejection and cleanup messages.
logger = get_logger(__name__)
|
|
|
|
|
|
class RateLimiter:
    """
    Token bucket rate limiter for API endpoints.

    Provides per-client rate limiting with configurable limits. Each client
    owns a bucket holding at most ``burst_size`` tokens that refills
    continuously at ``requests_per_minute / 60`` tokens per second. A request
    is allowed when the bucket holds enough tokens to pay for it.

    All state lives in in-memory dicts and no locking is performed.
    """

    def __init__(self, requests_per_minute: int = 100, burst_size: int = 20):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Sustained refill rate, in tokens per minute.
            burst_size: Bucket capacity, i.e. the maximum number of tokens a
                client can accumulate and spend in one burst.
        """
        self.requests_per_minute = requests_per_minute
        self.burst_size = burst_size
        self.refill_rate = requests_per_minute / 60.0  # tokens per second

        # Client buckets: client_id -> {'tokens': float, 'last_refill': float}.
        # New clients start with a full bucket. The read-only query methods
        # use .get() so that looking up an unknown client does not create one.
        self.buckets: Dict[str, Dict] = defaultdict(lambda: {
            'tokens': float(burst_size),
            'last_refill': time.time(),
        })

        # time.time() timestamps of allowed requests, per client, for
        # monitoring. Each deque keeps only the most recent 1000 entries.
        self.request_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))

        logger.info(
            "Rate limiter initialized: %s req/min, burst: %s",
            requests_per_minute,
            burst_size,
        )

    def _refilled_tokens(self, bucket: Dict, now: float) -> float:
        """
        Return the token count ``bucket`` holds at time ``now`` after refill.

        Elapsed time is clamped at zero because ``time.time()`` is a wall
        clock that may jump backwards; a negative interval would otherwise
        drain the bucket. The result is capped at ``burst_size``.
        """
        elapsed = max(0.0, now - bucket['last_refill'])
        return min(self.burst_size, bucket['tokens'] + elapsed * self.refill_rate)

    @staticmethod
    def _count_since(timestamps, cutoff: float) -> int:
        """Return how many entries of ``timestamps`` are strictly later than ``cutoff``."""
        return sum(1 for ts in timestamps if ts > cutoff)

    def is_allowed(self, client_id: str, tokens_requested: int = 1) -> bool:
        """
        Check if request is allowed for client and, if so, consume its tokens.

        The client's bucket is first refilled for the time elapsed since its
        last update (whether or not the request is then allowed). An allowed
        request is recorded in ``request_history``.

        Args:
            client_id: Client identifier (IP, user ID, etc.)
            tokens_requested: Number of tokens the request costs; must be >= 0.

        Returns:
            bool: True if request is allowed, False otherwise

        Raises:
            ValueError: If ``tokens_requested`` is negative. A negative cost
                would otherwise credit tokens beyond ``burst_size``.
        """
        if tokens_requested < 0:
            raise ValueError(
                f"tokens_requested must be non-negative, got {tokens_requested}"
            )

        now = time.time()
        bucket = self.buckets[client_id]  # creates a full bucket for new clients

        bucket['tokens'] = self._refilled_tokens(bucket, now)
        bucket['last_refill'] = now

        if bucket['tokens'] < tokens_requested:
            logger.debug("Rate limit exceeded for client %s", client_id)
            return False

        bucket['tokens'] -= tokens_requested
        self.request_history[client_id].append(now)
        return True

    def get_remaining_tokens(self, client_id: str) -> float:
        """
        Get remaining tokens for client, including refill accrued so far.

        Does not modify the bucket and does not register unknown clients.

        Args:
            client_id: Client identifier

        Returns:
            float: Number of remaining tokens; ``burst_size`` for a client
            that has no bucket yet (a new bucket would start full).
        """
        bucket = self.buckets.get(client_id)  # .get() bypasses the defaultdict factory
        if bucket is None:
            return float(self.burst_size)
        return self._refilled_tokens(bucket, time.time())

    def get_reset_time(self, client_id: str) -> float:
        """
        Get time until bucket is fully refilled.

        Args:
            client_id: Client identifier

        Returns:
            float: Seconds until full refill; 0.0 if already full, and
            ``float('inf')`` if the refill rate is not positive.
        """
        tokens_needed = self.burst_size - self.get_remaining_tokens(client_id)
        if tokens_needed <= 0:
            return 0.0
        if self.refill_rate <= 0:
            # A bucket that never refills never becomes full again.
            return float('inf')
        return tokens_needed / self.refill_rate

    def get_client_stats(self, client_id: str) -> Dict[str, float]:
        """
        Get statistics for a client without registering unknown clients.

        Args:
            client_id: Client identifier

        Returns:
            Dict: Client statistics with keys ``remaining_tokens``,
            ``reset_time``, ``requests_last_minute`` and ``total_requests``.
            ``total_requests`` counts recorded allowed requests, and the
            history keeps at most the last 1000.
        """
        history = self.request_history.get(client_id, ())
        recent_requests = self._count_since(history, time.time() - 60)

        return {
            'remaining_tokens': self.get_remaining_tokens(client_id),
            'reset_time': self.get_reset_time(client_id),
            'requests_last_minute': recent_requests,
            'total_requests': len(history),
        }

    def cleanup_old_data(self, max_age_hours: int = 24) -> None:
        """
        Clean up old client data.

        A client is dropped (bucket and request history) when its bucket was
        last refilled, i.e. when ``is_allowed`` was last called for it, more
        than ``max_age_hours`` ago.

        Args:
            max_age_hours: Maximum age of data to keep
        """
        cutoff_time = time.time() - max_age_hours * 3600

        # Collect the ids first: deleting while iterating a dict raises RuntimeError.
        inactive_clients = [
            client_id
            for client_id, bucket in self.buckets.items()
            if bucket['last_refill'] < cutoff_time
        ]

        for client_id in inactive_clients:
            del self.buckets[client_id]
            self.request_history.pop(client_id, None)

        logger.debug("Cleaned up %d inactive clients", len(inactive_clients))

    def get_global_stats(self) -> Dict[str, int]:
        """
        Get global rate limiter statistics.

        Returns:
            Dict: ``total_clients`` (clients with a bucket), ``active_clients``
            (clients with at least one allowed request in the last 60 s), the
            configured ``requests_per_minute_limit`` and ``burst_size``, and
            ``total_requests_last_minute`` across all clients.
        """
        minute_ago = time.time() - 60

        active_clients = 0
        total_requests_last_minute = 0
        for history in self.request_history.values():
            recent_requests = self._count_since(history, minute_ago)
            if recent_requests > 0:
                active_clients += 1
                total_requests_last_minute += recent_requests

        return {
            'total_clients': len(self.buckets),
            'active_clients': active_clients,
            'requests_per_minute_limit': self.requests_per_minute,
            'burst_size': self.burst_size,
            'total_requests_last_minute': total_requests_last_minute,
        }