api
This commit is contained in:
183
COBY/api/rate_limiter.py
Normal file
183
COBY/api/rate_limiter.py
Normal file
@ -0,0 +1,183 @@
|
||||
"""
|
||||
Rate limiting for API endpoints.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
from collections import defaultdict, deque
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
Token bucket rate limiter for API endpoints.
|
||||
|
||||
Provides per-client rate limiting with configurable limits.
|
||||
"""
|
||||
|
||||
def __init__(self, requests_per_minute: int = 100, burst_size: int = 20):
|
||||
"""
|
||||
Initialize rate limiter.
|
||||
|
||||
Args:
|
||||
requests_per_minute: Maximum requests per minute
|
||||
burst_size: Maximum burst requests
|
||||
"""
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.burst_size = burst_size
|
||||
self.refill_rate = requests_per_minute / 60.0 # tokens per second
|
||||
|
||||
# Client buckets: client_id -> {'tokens': float, 'last_refill': float}
|
||||
self.buckets: Dict[str, Dict] = defaultdict(lambda: {
|
||||
'tokens': float(burst_size),
|
||||
'last_refill': time.time()
|
||||
})
|
||||
|
||||
# Request history for monitoring
|
||||
self.request_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
|
||||
|
||||
logger.info(f"Rate limiter initialized: {requests_per_minute} req/min, burst: {burst_size}")
|
||||
|
||||
def is_allowed(self, client_id: str, tokens_requested: int = 1) -> bool:
|
||||
"""
|
||||
Check if request is allowed for client.
|
||||
|
||||
Args:
|
||||
client_id: Client identifier (IP, user ID, etc.)
|
||||
tokens_requested: Number of tokens requested
|
||||
|
||||
Returns:
|
||||
bool: True if request is allowed, False otherwise
|
||||
"""
|
||||
current_time = time.time()
|
||||
bucket = self.buckets[client_id]
|
||||
|
||||
# Refill tokens based on time elapsed
|
||||
time_elapsed = current_time - bucket['last_refill']
|
||||
tokens_to_add = time_elapsed * self.refill_rate
|
||||
|
||||
# Update bucket
|
||||
bucket['tokens'] = min(self.burst_size, bucket['tokens'] + tokens_to_add)
|
||||
bucket['last_refill'] = current_time
|
||||
|
||||
# Check if enough tokens available
|
||||
if bucket['tokens'] >= tokens_requested:
|
||||
bucket['tokens'] -= tokens_requested
|
||||
|
||||
# Record successful request
|
||||
self.request_history[client_id].append(current_time)
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.debug(f"Rate limit exceeded for client {client_id}")
|
||||
return False
|
||||
|
||||
def get_remaining_tokens(self, client_id: str) -> float:
|
||||
"""
|
||||
Get remaining tokens for client.
|
||||
|
||||
Args:
|
||||
client_id: Client identifier
|
||||
|
||||
Returns:
|
||||
float: Number of remaining tokens
|
||||
"""
|
||||
current_time = time.time()
|
||||
bucket = self.buckets[client_id]
|
||||
|
||||
# Calculate current tokens (with refill)
|
||||
time_elapsed = current_time - bucket['last_refill']
|
||||
tokens_to_add = time_elapsed * self.refill_rate
|
||||
current_tokens = min(self.burst_size, bucket['tokens'] + tokens_to_add)
|
||||
|
||||
return current_tokens
|
||||
|
||||
def get_reset_time(self, client_id: str) -> float:
|
||||
"""
|
||||
Get time until bucket is fully refilled.
|
||||
|
||||
Args:
|
||||
client_id: Client identifier
|
||||
|
||||
Returns:
|
||||
float: Seconds until full refill
|
||||
"""
|
||||
remaining_tokens = self.get_remaining_tokens(client_id)
|
||||
tokens_needed = self.burst_size - remaining_tokens
|
||||
|
||||
if tokens_needed <= 0:
|
||||
return 0.0
|
||||
|
||||
return tokens_needed / self.refill_rate
|
||||
|
||||
def get_client_stats(self, client_id: str) -> Dict[str, float]:
|
||||
"""
|
||||
Get statistics for a client.
|
||||
|
||||
Args:
|
||||
client_id: Client identifier
|
||||
|
||||
Returns:
|
||||
Dict: Client statistics
|
||||
"""
|
||||
current_time = time.time()
|
||||
history = self.request_history[client_id]
|
||||
|
||||
# Count requests in last minute
|
||||
minute_ago = current_time - 60
|
||||
recent_requests = sum(1 for req_time in history if req_time > minute_ago)
|
||||
|
||||
return {
|
||||
'remaining_tokens': self.get_remaining_tokens(client_id),
|
||||
'reset_time': self.get_reset_time(client_id),
|
||||
'requests_last_minute': recent_requests,
|
||||
'total_requests': len(history)
|
||||
}
|
||||
|
||||
def cleanup_old_data(self, max_age_hours: int = 24) -> None:
|
||||
"""
|
||||
Clean up old client data.
|
||||
|
||||
Args:
|
||||
max_age_hours: Maximum age of data to keep
|
||||
"""
|
||||
current_time = time.time()
|
||||
cutoff_time = current_time - (max_age_hours * 3600)
|
||||
|
||||
# Clean up buckets for inactive clients
|
||||
inactive_clients = []
|
||||
for client_id, bucket in self.buckets.items():
|
||||
if bucket['last_refill'] < cutoff_time:
|
||||
inactive_clients.append(client_id)
|
||||
|
||||
for client_id in inactive_clients:
|
||||
del self.buckets[client_id]
|
||||
if client_id in self.request_history:
|
||||
del self.request_history[client_id]
|
||||
|
||||
logger.debug(f"Cleaned up {len(inactive_clients)} inactive clients")
|
||||
|
||||
def get_global_stats(self) -> Dict[str, int]:
|
||||
"""Get global rate limiter statistics"""
|
||||
current_time = time.time()
|
||||
minute_ago = current_time - 60
|
||||
|
||||
total_clients = len(self.buckets)
|
||||
active_clients = 0
|
||||
total_requests_last_minute = 0
|
||||
|
||||
for client_id, history in self.request_history.items():
|
||||
recent_requests = sum(1 for req_time in history if req_time > minute_ago)
|
||||
if recent_requests > 0:
|
||||
active_clients += 1
|
||||
total_requests_last_minute += recent_requests
|
||||
|
||||
return {
|
||||
'total_clients': total_clients,
|
||||
'active_clients': active_clients,
|
||||
'requests_per_minute_limit': self.requests_per_minute,
|
||||
'burst_size': self.burst_size,
|
||||
'total_requests_last_minute': total_requests_last_minute
|
||||
}
|
Reference in New Issue
Block a user