Dobromir Popov
2025-08-04 18:38:51 +03:00
parent ff75af566c
commit fd6ec4eb40
6 changed files with 1318 additions and 0 deletions


@@ -53,6 +53,9 @@
- Implement data processor for normalizing raw exchange data
- Create validation logic for order book and trade data
- Implement data quality checks and filtering
- Add metrics calculation for order book statistics
- Write comprehensive unit tests for data processing logic
- _Requirements: 1.4, 6.3, 8.1_

COBY/api/__init__.py (new file, 15 lines)

@@ -0,0 +1,15 @@
"""
API layer for the COBY system.
"""
from .rest_api import create_app
from .websocket_server import WebSocketServer
from .rate_limiter import RateLimiter
from .response_formatter import ResponseFormatter
__all__ = [
'create_app',
'WebSocketServer',
'RateLimiter',
'ResponseFormatter'
]

COBY/api/rate_limiter.py (new file, 183 lines)

@@ -0,0 +1,183 @@
"""
Rate limiting for API endpoints.
"""
import time
from typing import Dict, Optional
from collections import defaultdict, deque
from ..utils.logging import get_logger
logger = get_logger(__name__)
class RateLimiter:
"""
Token bucket rate limiter for API endpoints.
Provides per-client rate limiting with configurable limits.
"""
def __init__(self, requests_per_minute: int = 100, burst_size: int = 20):
"""
Initialize rate limiter.
Args:
requests_per_minute: Maximum requests per minute
burst_size: Maximum burst requests
"""
self.requests_per_minute = requests_per_minute
self.burst_size = burst_size
self.refill_rate = requests_per_minute / 60.0 # tokens per second
# Client buckets: client_id -> {'tokens': float, 'last_refill': float}
self.buckets: Dict[str, Dict] = defaultdict(lambda: {
'tokens': float(burst_size),
'last_refill': time.time()
})
# Request history for monitoring
self.request_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
logger.info(f"Rate limiter initialized: {requests_per_minute} req/min, burst: {burst_size}")
def is_allowed(self, client_id: str, tokens_requested: int = 1) -> bool:
"""
Check if request is allowed for client.
Args:
client_id: Client identifier (IP, user ID, etc.)
tokens_requested: Number of tokens requested
Returns:
bool: True if request is allowed, False otherwise
"""
current_time = time.time()
bucket = self.buckets[client_id]
# Refill tokens based on time elapsed
time_elapsed = current_time - bucket['last_refill']
tokens_to_add = time_elapsed * self.refill_rate
# Update bucket
bucket['tokens'] = min(self.burst_size, bucket['tokens'] + tokens_to_add)
bucket['last_refill'] = current_time
# Check if enough tokens available
if bucket['tokens'] >= tokens_requested:
bucket['tokens'] -= tokens_requested
# Record successful request
self.request_history[client_id].append(current_time)
return True
else:
logger.debug(f"Rate limit exceeded for client {client_id}")
return False
def get_remaining_tokens(self, client_id: str) -> float:
"""
Get remaining tokens for client.
Args:
client_id: Client identifier
Returns:
float: Number of remaining tokens
"""
current_time = time.time()
bucket = self.buckets[client_id]
# Calculate current tokens (with refill)
time_elapsed = current_time - bucket['last_refill']
tokens_to_add = time_elapsed * self.refill_rate
current_tokens = min(self.burst_size, bucket['tokens'] + tokens_to_add)
return current_tokens
def get_reset_time(self, client_id: str) -> float:
"""
Get time until bucket is fully refilled.
Args:
client_id: Client identifier
Returns:
float: Seconds until full refill
"""
remaining_tokens = self.get_remaining_tokens(client_id)
tokens_needed = self.burst_size - remaining_tokens
if tokens_needed <= 0:
return 0.0
return tokens_needed / self.refill_rate
def get_client_stats(self, client_id: str) -> Dict[str, float]:
"""
Get statistics for a client.
Args:
client_id: Client identifier
Returns:
Dict: Client statistics
"""
current_time = time.time()
history = self.request_history[client_id]
# Count requests in last minute
minute_ago = current_time - 60
recent_requests = sum(1 for req_time in history if req_time > minute_ago)
return {
'remaining_tokens': self.get_remaining_tokens(client_id),
'reset_time': self.get_reset_time(client_id),
'requests_last_minute': recent_requests,
'total_requests': len(history)
}
def cleanup_old_data(self, max_age_hours: int = 24) -> None:
"""
Clean up old client data.
Args:
max_age_hours: Maximum age of data to keep
"""
current_time = time.time()
cutoff_time = current_time - (max_age_hours * 3600)
# Clean up buckets for inactive clients
inactive_clients = []
for client_id, bucket in self.buckets.items():
if bucket['last_refill'] < cutoff_time:
inactive_clients.append(client_id)
for client_id in inactive_clients:
del self.buckets[client_id]
if client_id in self.request_history:
del self.request_history[client_id]
logger.debug(f"Cleaned up {len(inactive_clients)} inactive clients")
def get_global_stats(self) -> Dict[str, int]:
"""Get global rate limiter statistics"""
current_time = time.time()
minute_ago = current_time - 60
total_clients = len(self.buckets)
active_clients = 0
total_requests_last_minute = 0
for client_id, history in self.request_history.items():
recent_requests = sum(1 for req_time in history if req_time > minute_ago)
if recent_requests > 0:
active_clients += 1
total_requests_last_minute += recent_requests
return {
'total_clients': total_clients,
'active_clients': active_clients,
'requests_per_minute_limit': self.requests_per_minute,
'burst_size': self.burst_size,
'total_requests_last_minute': total_requests_last_minute
}
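The limiter above has no external dependencies beyond the package's logging helper, so it can be exercised on its own. A minimal usage sketch, assuming the package is importable as COBY (the client ID below is just an illustrative string):

```python
# Minimal sketch for the token-bucket RateLimiter above.
# Assumption: the COBY package is on the Python path.
from COBY.api.rate_limiter import RateLimiter

limiter = RateLimiter(requests_per_minute=60, burst_size=5)

client_id = "203.0.113.7"  # illustrative client key, e.g. an IP address
for i in range(7):
    allowed = limiter.is_allowed(client_id)
    print(f"request {i + 1}: {'allowed' if allowed else 'rejected'}")

# The first 5 requests consume the burst; further requests are rejected until
# the bucket refills at requests_per_minute / 60 = 1 token per second.
print(limiter.get_client_stats(client_id))
```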

COBY/api/response_formatter.py (new file, 326 lines)

@@ -0,0 +1,326 @@
"""
Response formatting for API endpoints.
"""
import json
from typing import Any, Dict, Optional, List
from datetime import datetime
from ..utils.logging import get_logger
from ..utils.timing import get_current_timestamp
logger = get_logger(__name__)
class ResponseFormatter:
"""
Formats API responses with consistent structure and metadata.
"""
def __init__(self):
"""Initialize response formatter"""
self.responses_formatted = 0
logger.info("Response formatter initialized")
def success(self, data: Any, message: str = "Success",
metadata: Optional[Dict] = None) -> Dict[str, Any]:
"""
Format successful response.
Args:
data: Response data
message: Success message
metadata: Additional metadata
Returns:
Dict: Formatted response
"""
response = {
'success': True,
'message': message,
'data': data,
'timestamp': get_current_timestamp().isoformat(),
'metadata': metadata or {}
}
self.responses_formatted += 1
return response
def error(self, message: str, error_code: str = "UNKNOWN_ERROR",
details: Optional[Dict] = None, status_code: int = 400) -> Dict[str, Any]:
"""
Format error response.
Args:
message: Error message
error_code: Error code
details: Error details
status_code: HTTP status code
Returns:
Dict: Formatted error response
"""
response = {
'success': False,
'error': {
'message': message,
'code': error_code,
'details': details or {},
'status_code': status_code
},
'timestamp': get_current_timestamp().isoformat()
}
self.responses_formatted += 1
return response
def paginated(self, data: List[Any], page: int, page_size: int,
total_count: int, message: str = "Success") -> Dict[str, Any]:
"""
Format paginated response.
Args:
data: Page data
page: Current page number
page_size: Items per page
total_count: Total number of items
message: Success message
Returns:
Dict: Formatted paginated response
"""
total_pages = (total_count + page_size - 1) // page_size
pagination = {
'page': page,
'page_size': page_size,
'total_count': total_count,
'total_pages': total_pages,
'has_next': page < total_pages,
'has_previous': page > 1
}
return self.success(
data=data,
message=message,
metadata={'pagination': pagination}
)
def heatmap_response(self, heatmap_data, symbol: str,
exchange: Optional[str] = None) -> Dict[str, Any]:
"""
Format heatmap data response.
Args:
heatmap_data: Heatmap data
symbol: Trading symbol
exchange: Exchange name (None for consolidated)
Returns:
Dict: Formatted heatmap response
"""
if not heatmap_data:
return self.error("Heatmap data not found", "HEATMAP_NOT_FOUND", status_code=404)
# Convert heatmap to API format
formatted_data = {
'symbol': heatmap_data.symbol,
'timestamp': heatmap_data.timestamp.isoformat(),
'bucket_size': heatmap_data.bucket_size,
'exchange': exchange,
'points': [
{
'price': point.price,
'volume': point.volume,
'intensity': point.intensity,
'side': point.side
}
for point in heatmap_data.data
]
}
metadata = {
'total_points': len(heatmap_data.data),
'bid_points': len([p for p in heatmap_data.data if p.side == 'bid']),
'ask_points': len([p for p in heatmap_data.data if p.side == 'ask']),
'data_type': 'consolidated' if not exchange else 'exchange_specific'
}
return self.success(
data=formatted_data,
message=f"Heatmap data for {symbol}",
metadata=metadata
)
def orderbook_response(self, orderbook_data, symbol: str, exchange: str) -> Dict[str, Any]:
"""
Format order book response.
Args:
orderbook_data: Order book data
symbol: Trading symbol
exchange: Exchange name
Returns:
Dict: Formatted order book response
"""
if not orderbook_data:
return self.error("Order book not found", "ORDERBOOK_NOT_FOUND", status_code=404)
# Convert order book to API format
formatted_data = {
'symbol': orderbook_data.symbol,
'exchange': orderbook_data.exchange,
'timestamp': orderbook_data.timestamp.isoformat(),
'sequence_id': orderbook_data.sequence_id,
'bids': [
{
'price': bid.price,
'size': bid.size,
'count': bid.count
}
for bid in orderbook_data.bids
],
'asks': [
{
'price': ask.price,
'size': ask.size,
'count': ask.count
}
for ask in orderbook_data.asks
],
'mid_price': orderbook_data.mid_price,
'spread': orderbook_data.spread,
'bid_volume': orderbook_data.bid_volume,
'ask_volume': orderbook_data.ask_volume
}
metadata = {
'bid_levels': len(orderbook_data.bids),
'ask_levels': len(orderbook_data.asks),
'total_bid_volume': orderbook_data.bid_volume,
'total_ask_volume': orderbook_data.ask_volume
}
return self.success(
data=formatted_data,
message=f"Order book for {symbol}@{exchange}",
metadata=metadata
)
def metrics_response(self, metrics_data, symbol: str, exchange: str) -> Dict[str, Any]:
"""
Format metrics response.
Args:
metrics_data: Metrics data
symbol: Trading symbol
exchange: Exchange name
Returns:
Dict: Formatted metrics response
"""
if not metrics_data:
return self.error("Metrics not found", "METRICS_NOT_FOUND", status_code=404)
# Convert metrics to API format
formatted_data = {
'symbol': metrics_data.symbol,
'exchange': metrics_data.exchange,
'timestamp': metrics_data.timestamp.isoformat(),
'mid_price': metrics_data.mid_price,
'spread': metrics_data.spread,
'spread_percentage': metrics_data.spread_percentage,
'bid_volume': metrics_data.bid_volume,
'ask_volume': metrics_data.ask_volume,
'volume_imbalance': metrics_data.volume_imbalance,
'depth_10': metrics_data.depth_10,
'depth_50': metrics_data.depth_50
}
return self.success(
data=formatted_data,
message=f"Metrics for {symbol}@{exchange}"
)
def status_response(self, status_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Format system status response.
Args:
status_data: System status data
Returns:
Dict: Formatted status response
"""
return self.success(
data=status_data,
message="System status",
metadata={'response_count': self.responses_formatted}
)
def rate_limit_error(self, client_stats: Dict[str, float]) -> Dict[str, Any]:
"""
Format rate limit error response.
Args:
client_stats: Client rate limit statistics
Returns:
Dict: Formatted rate limit error
"""
return self.error(
message="Rate limit exceeded",
error_code="RATE_LIMIT_EXCEEDED",
details={
'remaining_tokens': client_stats['remaining_tokens'],
'reset_time': client_stats['reset_time'],
'requests_last_minute': client_stats['requests_last_minute']
},
status_code=429
)
def validation_error(self, field: str, message: str) -> Dict[str, Any]:
"""
Format validation error response.
Args:
field: Field that failed validation
message: Validation error message
Returns:
Dict: Formatted validation error
"""
return self.error(
message=f"Validation error: {message}",
error_code="VALIDATION_ERROR",
details={'field': field, 'message': message},
status_code=400
)
def to_json(self, response: Dict[str, Any], indent: Optional[int] = None) -> str:
"""
Convert response to JSON string.
Args:
response: Response dictionary
indent: JSON indentation (None for compact)
Returns:
str: JSON string
"""
try:
return json.dumps(response, indent=indent, ensure_ascii=False, default=str)
except Exception as e:
logger.error(f"Error converting response to JSON: {e}")
return json.dumps(self.error("JSON serialization failed", "JSON_ERROR"))
def get_stats(self) -> Dict[str, int]:
"""Get formatter statistics"""
return {
'responses_formatted': self.responses_formatted
}
def reset_stats(self) -> None:
"""Reset formatter statistics"""
self.responses_formatted = 0
logger.info("Response formatter statistics reset")
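For reference, the formatter can also be used outside the API layer. A quick sketch (field values are made up for illustration, and the import assumes the COBY package is on the Python path):

```python
# Sketch of standalone ResponseFormatter usage.
from COBY.api.response_formatter import ResponseFormatter

fmt = ResponseFormatter()

ok = fmt.success(data={'mid_price': 64000.5}, message="Ticker snapshot")
err = fmt.error("Symbol not found", error_code="SYMBOL_NOT_FOUND", status_code=404)
page = fmt.paginated(data=[1, 2, 3], page=1, page_size=3, total_count=10)

print(fmt.to_json(ok, indent=2))        # success envelope with timestamp/metadata
print(err['error']['code'])             # SYMBOL_NOT_FOUND
print(page['metadata']['pagination'])   # page 1 of 4, has_next=True
```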

COBY/api/rest_api.py (new file, 391 lines)

@@ -0,0 +1,391 @@
"""
REST API server for COBY system.
"""
from fastapi import FastAPI, HTTPException, Request, Query, Path
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from typing import Optional, List
import asyncio
from ..config import config
from ..caching.redis_manager import redis_manager
from ..utils.logging import get_logger, set_correlation_id
from ..utils.validation import validate_symbol
from .rate_limiter import RateLimiter
from .response_formatter import ResponseFormatter
logger = get_logger(__name__)
def create_app() -> FastAPI:
"""Create and configure FastAPI application"""
app = FastAPI(
title="COBY Market Data API",
description="Real-time cryptocurrency market data aggregation API",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc"
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=config.api.cors_origins,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["*"],
)
# Initialize components
rate_limiter = RateLimiter(
requests_per_minute=config.api.rate_limit,
burst_size=20
)
response_formatter = ResponseFormatter()
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
"""Rate limiting middleware"""
client_ip = request.client.host if request.client else "unknown"
if not rate_limiter.is_allowed(client_ip):
client_stats = rate_limiter.get_client_stats(client_ip)
error_response = response_formatter.rate_limit_error(client_stats)
return JSONResponse(
status_code=429,
content=error_response,
headers={
"X-RateLimit-Remaining": str(int(client_stats['remaining_tokens'])),
"X-RateLimit-Reset": str(int(client_stats['reset_time']))
}
)
response = await call_next(request)
# Add rate limit headers
client_stats = rate_limiter.get_client_stats(client_ip)
response.headers["X-RateLimit-Remaining"] = str(int(client_stats['remaining_tokens']))
response.headers["X-RateLimit-Reset"] = str(int(client_stats['reset_time']))
return response
@app.middleware("http")
async def correlation_middleware(request: Request, call_next):
"""Add correlation ID to requests"""
set_correlation_id()
response = await call_next(request)
return response
@app.on_event("startup")
async def startup_event():
"""Initialize services on startup"""
try:
await redis_manager.initialize()
logger.info("API server startup completed")
except Exception as e:
logger.error(f"API server startup failed: {e}")
raise
@app.on_event("shutdown")
async def shutdown_event():
"""Cleanup on shutdown"""
try:
await redis_manager.close()
logger.info("API server shutdown completed")
except Exception as e:
logger.error(f"API server shutdown error: {e}")
# Health check endpoint
@app.get("/health")
async def health_check():
"""Health check endpoint"""
try:
# Check Redis connection
redis_healthy = await redis_manager.ping()
health_data = {
'status': 'healthy' if redis_healthy else 'degraded',
'redis': 'connected' if redis_healthy else 'disconnected',
'version': '1.0.0'
}
return response_formatter.status_response(health_data)
except Exception as e:
logger.error(f"Health check failed: {e}")
return JSONResponse(
status_code=503,
content=response_formatter.error("Service unavailable", "HEALTH_CHECK_FAILED")
)
# Heatmap endpoints
@app.get("/api/v1/heatmap/{symbol}")
async def get_heatmap(
symbol: str = Path(..., description="Trading symbol (e.g., BTCUSDT)"),
exchange: Optional[str] = Query(None, description="Exchange name (None for consolidated)")
):
"""Get heatmap data for a symbol"""
try:
# Validate symbol
if not validate_symbol(symbol):
return JSONResponse(
status_code=400,
content=response_formatter.validation_error("symbol", "Invalid symbol format")
)
# Get heatmap from cache
heatmap_data = await redis_manager.get_heatmap(symbol.upper(), exchange)
return response_formatter.heatmap_response(heatmap_data, symbol.upper(), exchange)
except Exception as e:
logger.error(f"Error getting heatmap for {symbol}: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "HEATMAP_ERROR")
)
# Order book endpoints
@app.get("/api/v1/orderbook/{symbol}/{exchange}")
async def get_orderbook(
symbol: str = Path(..., description="Trading symbol"),
exchange: str = Path(..., description="Exchange name")
):
"""Get order book data for a symbol on an exchange"""
try:
# Validate symbol
if not validate_symbol(symbol):
return JSONResponse(
status_code=400,
content=response_formatter.validation_error("symbol", "Invalid symbol format")
)
# Get order book from cache
orderbook_data = await redis_manager.get_orderbook(symbol.upper(), exchange.lower())
return response_formatter.orderbook_response(orderbook_data, symbol.upper(), exchange.lower())
except Exception as e:
logger.error(f"Error getting order book for {symbol}@{exchange}: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "ORDERBOOK_ERROR")
)
# Metrics endpoints
@app.get("/api/v1/metrics/{symbol}/{exchange}")
async def get_metrics(
symbol: str = Path(..., description="Trading symbol"),
exchange: str = Path(..., description="Exchange name")
):
"""Get metrics data for a symbol on an exchange"""
try:
# Validate symbol
if not validate_symbol(symbol):
return JSONResponse(
status_code=400,
content=response_formatter.validation_error("symbol", "Invalid symbol format")
)
# Get metrics from cache
metrics_data = await redis_manager.get_metrics(symbol.upper(), exchange.lower())
return response_formatter.metrics_response(metrics_data, symbol.upper(), exchange.lower())
except Exception as e:
logger.error(f"Error getting metrics for {symbol}@{exchange}: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "METRICS_ERROR")
)
# Exchange status endpoints
@app.get("/api/v1/status/{exchange}")
async def get_exchange_status(
exchange: str = Path(..., description="Exchange name")
):
"""Get status for an exchange"""
try:
# Get status from cache
status_data = await redis_manager.get_exchange_status(exchange.lower())
if not status_data:
return JSONResponse(
status_code=404,
content=response_formatter.error("Exchange status not found", "STATUS_NOT_FOUND")
)
return response_formatter.success(
data=status_data,
message=f"Status for {exchange}"
)
except Exception as e:
logger.error(f"Error getting status for {exchange}: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "STATUS_ERROR")
)
# List endpoints
@app.get("/api/v1/symbols")
async def list_symbols():
"""List available trading symbols"""
try:
# Get symbols from cache (this would be populated by exchange connectors)
symbols_pattern = "symbols:*"
symbol_keys = await redis_manager.keys(symbols_pattern)
all_symbols = set()
for key in symbol_keys:
symbols_data = await redis_manager.get(key)
if symbols_data and isinstance(symbols_data, list):
all_symbols.update(symbols_data)
return response_formatter.success(
data=sorted(list(all_symbols)),
message="Available trading symbols",
metadata={'total_symbols': len(all_symbols)}
)
except Exception as e:
logger.error(f"Error listing symbols: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "SYMBOLS_ERROR")
)
@app.get("/api/v1/exchanges")
async def list_exchanges():
"""List available exchanges"""
try:
# Get exchange status keys
status_pattern = "st:*"
status_keys = await redis_manager.keys(status_pattern)
exchanges = []
for key in status_keys:
# Extract exchange name from key (st:exchange_name)
exchange_name = key.split(':', 1)[1] if ':' in key else key
exchanges.append(exchange_name)
return response_formatter.success(
data=sorted(exchanges),
message="Available exchanges",
metadata={'total_exchanges': len(exchanges)}
)
except Exception as e:
logger.error(f"Error listing exchanges: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "EXCHANGES_ERROR")
)
# Statistics endpoints
@app.get("/api/v1/stats/cache")
async def get_cache_stats():
"""Get cache statistics"""
try:
cache_stats = redis_manager.get_stats()
redis_health = await redis_manager.health_check()
stats_data = {
'cache_performance': cache_stats,
'redis_health': redis_health
}
return response_formatter.success(
data=stats_data,
message="Cache statistics"
)
except Exception as e:
logger.error(f"Error getting cache stats: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "STATS_ERROR")
)
@app.get("/api/v1/stats/api")
async def get_api_stats():
"""Get API statistics"""
try:
api_stats = {
'rate_limiter': rate_limiter.get_global_stats(),
'response_formatter': response_formatter.get_stats()
}
return response_formatter.success(
data=api_stats,
message="API statistics"
)
except Exception as e:
logger.error(f"Error getting API stats: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "API_STATS_ERROR")
)
# Batch endpoints for efficiency
@app.get("/api/v1/batch/heatmaps")
async def get_batch_heatmaps(
symbols: str = Query(..., description="Comma-separated list of symbols"),
exchange: Optional[str] = Query(None, description="Exchange name (None for consolidated)")
):
"""Get heatmaps for multiple symbols"""
try:
symbol_list = [s.strip().upper() for s in symbols.split(',')]
# Validate all symbols
for symbol in symbol_list:
if not validate_symbol(symbol):
return JSONResponse(
status_code=400,
content=response_formatter.validation_error("symbols", f"Invalid symbol: {symbol}")
)
# Get heatmaps in batch
heatmaps = {}
for symbol in symbol_list:
heatmap_data = await redis_manager.get_heatmap(symbol, exchange)
if heatmap_data:
heatmaps[symbol] = {
'symbol': heatmap_data.symbol,
'timestamp': heatmap_data.timestamp.isoformat(),
'bucket_size': heatmap_data.bucket_size,
'points': [
{
'price': point.price,
'volume': point.volume,
'intensity': point.intensity,
'side': point.side
}
for point in heatmap_data.data
]
}
return response_formatter.success(
data=heatmaps,
message=f"Batch heatmaps for {len(symbol_list)} symbols",
metadata={
'requested_symbols': len(symbol_list),
'found_heatmaps': len(heatmaps),
'exchange': exchange or 'consolidated'
}
)
except Exception as e:
logger.error(f"Error getting batch heatmaps: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "BATCH_HEATMAPS_ERROR")
)
return app
# Create the FastAPI app instance
app = create_app()
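A client-side sketch against the endpoints above. The base URL is an assumption (the real host and port come from config.api, which is not part of this commit); the app can be served, for example, with uvicorn COBY.api.rest_api:app.

```python
# Client sketch for the REST endpoints defined above.
# Assumption: the API is reachable at http://localhost:8000.
import requests

BASE = "http://localhost:8000"  # assumed; adjust to your deployment

# Health check: status is nested under the standard success envelope
health = requests.get(f"{BASE}/health").json()
print(health['data']['status'])

# Consolidated heatmap for BTCUSDT (omit ?exchange=... for the consolidated view)
r = requests.get(f"{BASE}/api/v1/heatmap/BTCUSDT")
print(r.status_code, r.headers.get("X-RateLimit-Remaining"))

# Batch heatmaps for several symbols on one exchange
r = requests.get(
    f"{BASE}/api/v1/batch/heatmaps",
    params={"symbols": "BTCUSDT,ETHUSDT", "exchange": "binance"},
)
print(r.json()['metadata'])
```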

COBY/api/websocket_server.py (new file, 400 lines)

@@ -0,0 +1,400 @@
"""
WebSocket server for real-time data streaming.
"""
import asyncio
import json
from typing import Dict, Set, Optional, Any
from fastapi import WebSocket, WebSocketDisconnect
from ..utils.logging import get_logger, set_correlation_id
from ..utils.validation import validate_symbol
from ..caching.redis_manager import redis_manager
from .response_formatter import ResponseFormatter
logger = get_logger(__name__)
class WebSocketManager:
"""
Manages WebSocket connections and real-time data streaming.
"""
def __init__(self):
"""Initialize WebSocket manager"""
# Active connections: connection_id -> WebSocket
self.connections: Dict[str, WebSocket] = {}
# Subscriptions: symbol -> set of connection_ids
self.subscriptions: Dict[str, Set[str]] = {}
# Connection metadata: connection_id -> metadata
self.connection_metadata: Dict[str, Dict[str, Any]] = {}
self.response_formatter = ResponseFormatter()
self.connection_counter = 0
logger.info("WebSocket manager initialized")
async def connect(self, websocket: WebSocket, client_ip: str) -> str:
"""
Accept new WebSocket connection.
Args:
websocket: WebSocket connection
client_ip: Client IP address
Returns:
str: Connection ID
"""
await websocket.accept()
# Generate connection ID
self.connection_counter += 1
connection_id = f"ws_{self.connection_counter}_{client_ip}"
# Store connection
self.connections[connection_id] = websocket
self.connection_metadata[connection_id] = {
'client_ip': client_ip,
'connected_at': asyncio.get_event_loop().time(),
'subscriptions': set(),
'messages_sent': 0
}
logger.info(f"WebSocket connected: {connection_id}")
# Send welcome message
welcome_msg = self.response_formatter.success(
data={'connection_id': connection_id},
message="WebSocket connected successfully"
)
await self._send_to_connection(connection_id, welcome_msg)
return connection_id
async def disconnect(self, connection_id: str) -> None:
"""
Handle WebSocket disconnection.
Args:
connection_id: Connection ID to disconnect
"""
if connection_id in self.connections:
# Remove from all subscriptions
metadata = self.connection_metadata.get(connection_id, {})
for symbol in list(metadata.get('subscriptions', set())):  # iterate a copy; _unsubscribe_connection mutates this set
await self._unsubscribe_connection(connection_id, symbol)
# Remove connection
del self.connections[connection_id]
del self.connection_metadata[connection_id]
logger.info(f"WebSocket disconnected: {connection_id}")
async def subscribe(self, connection_id: str, symbol: str,
data_type: str = "heatmap") -> bool:
"""
Subscribe connection to symbol updates.
Args:
connection_id: Connection ID
symbol: Trading symbol
data_type: Type of data to subscribe to
Returns:
bool: True if subscribed successfully
"""
try:
# Validate symbol
if not validate_symbol(symbol):
error_msg = self.response_formatter.validation_error("symbol", "Invalid symbol format")
await self._send_to_connection(connection_id, error_msg)
return False
symbol = symbol.upper()
subscription_key = f"{symbol}:{data_type}"
# Add to subscriptions
if subscription_key not in self.subscriptions:
self.subscriptions[subscription_key] = set()
self.subscriptions[subscription_key].add(connection_id)
# Update connection metadata
if connection_id in self.connection_metadata:
self.connection_metadata[connection_id]['subscriptions'].add(subscription_key)
logger.info(f"WebSocket {connection_id} subscribed to {subscription_key}")
# Send confirmation
confirm_msg = self.response_formatter.success(
data={'symbol': symbol, 'data_type': data_type},
message=f"Subscribed to {symbol} {data_type} updates"
)
await self._send_to_connection(connection_id, confirm_msg)
# Send initial data if available
await self._send_initial_data(connection_id, symbol, data_type)
return True
except Exception as e:
logger.error(f"Error subscribing {connection_id} to {symbol}: {e}")
error_msg = self.response_formatter.error("Subscription failed", "SUBSCRIBE_ERROR")
await self._send_to_connection(connection_id, error_msg)
return False
async def unsubscribe(self, connection_id: str, symbol: str,
data_type: str = "heatmap") -> bool:
"""
Unsubscribe connection from symbol updates.
Args:
connection_id: Connection ID
symbol: Trading symbol
data_type: Type of data to unsubscribe from
Returns:
bool: True if unsubscribed successfully
"""
try:
symbol = symbol.upper()
subscription_key = f"{symbol}:{data_type}"
await self._unsubscribe_connection(connection_id, subscription_key)
# Send confirmation
confirm_msg = self.response_formatter.success(
data={'symbol': symbol, 'data_type': data_type},
message=f"Unsubscribed from {symbol} {data_type} updates"
)
await self._send_to_connection(connection_id, confirm_msg)
return True
except Exception as e:
logger.error(f"Error unsubscribing {connection_id} from {symbol}: {e}")
return False
async def broadcast_update(self, symbol: str, data_type: str, data: Any) -> int:
"""
Broadcast data update to all subscribers.
Args:
symbol: Trading symbol
data_type: Type of data
data: Data to broadcast
Returns:
int: Number of connections notified
"""
try:
set_correlation_id()
subscription_key = f"{symbol.upper()}:{data_type}"
subscribers = self.subscriptions.get(subscription_key, set())
if not subscribers:
return 0
# Format message based on data type
if data_type == "heatmap":
message = self.response_formatter.heatmap_response(data, symbol)
elif data_type == "orderbook":
message = self.response_formatter.orderbook_response(data, symbol, "consolidated")
else:
message = self.response_formatter.success(data, f"{data_type} update for {symbol}")
# Add update type to message
message['update_type'] = data_type
message['symbol'] = symbol
# Send to all subscribers
sent_count = 0
for connection_id in subscribers.copy(): # Copy to avoid modification during iteration
if await self._send_to_connection(connection_id, message):
sent_count += 1
logger.debug(f"Broadcasted {data_type} update for {symbol} to {sent_count} connections")
return sent_count
except Exception as e:
logger.error(f"Error broadcasting update for {symbol}: {e}")
return 0
async def _send_to_connection(self, connection_id: str, message: Dict[str, Any]) -> bool:
"""
Send message to specific connection.
Args:
connection_id: Connection ID
message: Message to send
Returns:
bool: True if sent successfully
"""
try:
if connection_id not in self.connections:
return False
websocket = self.connections[connection_id]
message_json = json.dumps(message, default=str)
await websocket.send_text(message_json)
# Update statistics
if connection_id in self.connection_metadata:
self.connection_metadata[connection_id]['messages_sent'] += 1
return True
except Exception as e:
logger.warning(f"Error sending message to {connection_id}: {e}")
# Remove broken connection
await self.disconnect(connection_id)
return False
async def _unsubscribe_connection(self, connection_id: str, subscription_key: str) -> None:
"""Remove connection from subscription"""
if subscription_key in self.subscriptions:
self.subscriptions[subscription_key].discard(connection_id)
# Clean up empty subscriptions
if not self.subscriptions[subscription_key]:
del self.subscriptions[subscription_key]
# Update connection metadata
if connection_id in self.connection_metadata:
self.connection_metadata[connection_id]['subscriptions'].discard(subscription_key)
async def _send_initial_data(self, connection_id: str, symbol: str, data_type: str) -> None:
"""Send initial data to newly subscribed connection"""
try:
if data_type == "heatmap":
# Get latest heatmap from cache
heatmap_data = await redis_manager.get_heatmap(symbol)
if heatmap_data:
message = self.response_formatter.heatmap_response(heatmap_data, symbol)
message['update_type'] = 'initial_data'
await self._send_to_connection(connection_id, message)
elif data_type == "orderbook":
# Could get latest order book from cache
# This would require knowing which exchange to get data from
pass
except Exception as e:
logger.warning(f"Error sending initial data to {connection_id}: {e}")
def get_stats(self) -> Dict[str, Any]:
"""Get WebSocket manager statistics"""
total_subscriptions = sum(len(subs) for subs in self.subscriptions.values())
return {
'active_connections': len(self.connections),
'total_subscriptions': total_subscriptions,
'unique_symbols': len(set(key.split(':')[0] for key in self.subscriptions.keys())),
'connection_counter': self.connection_counter
}
# Global WebSocket manager instance
websocket_manager = WebSocketManager()
class WebSocketServer:
"""
WebSocket server for real-time data streaming.
"""
def __init__(self):
"""Initialize WebSocket server"""
self.manager = websocket_manager
logger.info("WebSocket server initialized")
async def handle_connection(self, websocket: WebSocket, client_ip: str) -> None:
"""
Handle WebSocket connection lifecycle.
Args:
websocket: WebSocket connection
client_ip: Client IP address
"""
connection_id = None
try:
# Accept connection
connection_id = await self.manager.connect(websocket, client_ip)
# Handle messages
while True:
try:
# Receive message
message = await websocket.receive_text()
await self._handle_message(connection_id, message)
except WebSocketDisconnect:
logger.info(f"WebSocket client disconnected: {connection_id}")
break
except Exception as e:
logger.error(f"WebSocket connection error: {e}")
finally:
# Clean up connection
if connection_id:
await self.manager.disconnect(connection_id)
async def _handle_message(self, connection_id: str, message: str) -> None:
"""
Handle incoming WebSocket message.
Args:
connection_id: Connection ID
message: Received message
"""
try:
# Parse message
data = json.loads(message)
action = data.get('action')
if action == 'subscribe':
symbol = data.get('symbol')
data_type = data.get('data_type', 'heatmap')
await self.manager.subscribe(connection_id, symbol, data_type)
elif action == 'unsubscribe':
symbol = data.get('symbol')
data_type = data.get('data_type', 'heatmap')
await self.manager.unsubscribe(connection_id, symbol, data_type)
elif action == 'ping':
# Send pong response
pong_msg = self.manager.response_formatter.success(
data={'action': 'pong'},
message="Pong"
)
await self.manager._send_to_connection(connection_id, pong_msg)
else:
# Unknown action
error_msg = self.manager.response_formatter.error(
f"Unknown action: {action}",
"UNKNOWN_ACTION"
)
await self.manager._send_to_connection(connection_id, error_msg)
except json.JSONDecodeError:
error_msg = self.manager.response_formatter.error(
"Invalid JSON message",
"INVALID_JSON"
)
await self.manager._send_to_connection(connection_id, error_msg)
except Exception as e:
logger.error(f"Error handling WebSocket message: {e}")
error_msg = self.manager.response_formatter.error(
"Message processing failed",
"MESSAGE_ERROR"
)
await self.manager._send_to_connection(connection_id, error_msg)
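A client sketch for the subscribe/ping message protocol handled by _handle_message, using the third-party websockets package. The ws://localhost:8000/ws URL is an assumption: this commit defines WebSocketServer.handle_connection but not the route that mounts it.

```python
# Client sketch for the JSON action protocol above (subscribe / unsubscribe / ping).
# Assumption: the server exposes handle_connection at ws://localhost:8000/ws.
import asyncio
import json
import websockets

async def main():
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        print(json.loads(await ws.recv()))  # welcome message with connection_id

        # Subscribe to consolidated heatmap updates for BTCUSDT
        await ws.send(json.dumps({"action": "subscribe",
                                  "symbol": "BTCUSDT",
                                  "data_type": "heatmap"}))
        print(json.loads(await ws.recv()))  # subscription confirmation

        await ws.send(json.dumps({"action": "ping"}))
        # Read a few more messages (pong, initial data, broadcasts);
        # this blocks until the server has something to send.
        for _ in range(3):
            print(json.loads(await ws.recv()))

asyncio.run(main())
```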