""" Rate limiting for API endpoints. """ import time from typing import Dict, Optional from collections import defaultdict, deque from ..utils.logging import get_logger logger = get_logger(__name__) class RateLimiter: """ Token bucket rate limiter for API endpoints. Provides per-client rate limiting with configurable limits. """ def __init__(self, requests_per_minute: int = 100, burst_size: int = 20): """ Initialize rate limiter. Args: requests_per_minute: Maximum requests per minute burst_size: Maximum burst requests """ self.requests_per_minute = requests_per_minute self.burst_size = burst_size self.refill_rate = requests_per_minute / 60.0 # tokens per second # Client buckets: client_id -> {'tokens': float, 'last_refill': float} self.buckets: Dict[str, Dict] = defaultdict(lambda: { 'tokens': float(burst_size), 'last_refill': time.time() }) # Request history for monitoring self.request_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000)) logger.info(f"Rate limiter initialized: {requests_per_minute} req/min, burst: {burst_size}") def is_allowed(self, client_id: str, tokens_requested: int = 1) -> bool: """ Check if request is allowed for client. Args: client_id: Client identifier (IP, user ID, etc.) tokens_requested: Number of tokens requested Returns: bool: True if request is allowed, False otherwise """ current_time = time.time() bucket = self.buckets[client_id] # Refill tokens based on time elapsed time_elapsed = current_time - bucket['last_refill'] tokens_to_add = time_elapsed * self.refill_rate # Update bucket bucket['tokens'] = min(self.burst_size, bucket['tokens'] + tokens_to_add) bucket['last_refill'] = current_time # Check if enough tokens available if bucket['tokens'] >= tokens_requested: bucket['tokens'] -= tokens_requested # Record successful request self.request_history[client_id].append(current_time) return True else: logger.debug(f"Rate limit exceeded for client {client_id}") return False def get_remaining_tokens(self, client_id: str) -> float: """ Get remaining tokens for client. Args: client_id: Client identifier Returns: float: Number of remaining tokens """ current_time = time.time() bucket = self.buckets[client_id] # Calculate current tokens (with refill) time_elapsed = current_time - bucket['last_refill'] tokens_to_add = time_elapsed * self.refill_rate current_tokens = min(self.burst_size, bucket['tokens'] + tokens_to_add) return current_tokens def get_reset_time(self, client_id: str) -> float: """ Get time until bucket is fully refilled. Args: client_id: Client identifier Returns: float: Seconds until full refill """ remaining_tokens = self.get_remaining_tokens(client_id) tokens_needed = self.burst_size - remaining_tokens if tokens_needed <= 0: return 0.0 return tokens_needed / self.refill_rate def get_client_stats(self, client_id: str) -> Dict[str, float]: """ Get statistics for a client. Args: client_id: Client identifier Returns: Dict: Client statistics """ current_time = time.time() history = self.request_history[client_id] # Count requests in last minute minute_ago = current_time - 60 recent_requests = sum(1 for req_time in history if req_time > minute_ago) return { 'remaining_tokens': self.get_remaining_tokens(client_id), 'reset_time': self.get_reset_time(client_id), 'requests_last_minute': recent_requests, 'total_requests': len(history) } def cleanup_old_data(self, max_age_hours: int = 24) -> None: """ Clean up old client data. Args: max_age_hours: Maximum age of data to keep """ current_time = time.time() cutoff_time = current_time - (max_age_hours * 3600) # Clean up buckets for inactive clients inactive_clients = [] for client_id, bucket in self.buckets.items(): if bucket['last_refill'] < cutoff_time: inactive_clients.append(client_id) for client_id in inactive_clients: del self.buckets[client_id] if client_id in self.request_history: del self.request_history[client_id] logger.debug(f"Cleaned up {len(inactive_clients)} inactive clients") def get_global_stats(self) -> Dict[str, int]: """Get global rate limiter statistics""" current_time = time.time() minute_ago = current_time - 60 total_clients = len(self.buckets) active_clients = 0 total_requests_last_minute = 0 for client_id, history in self.request_history.items(): recent_requests = sum(1 for req_time in history if req_time > minute_ago) if recent_requests > 0: active_clients += 1 total_requests_last_minute += recent_requests return { 'total_clients': total_clients, 'active_clients': active_clients, 'requests_per_minute_limit': self.requests_per_minute, 'burst_size': self.burst_size, 'total_requests_last_minute': total_requests_last_minute }