""" WebSocket server for real-time data streaming. """ import asyncio import json from typing import Dict, Set, Optional, Any from fastapi import WebSocket, WebSocketDisconnect from ..utils.logging import get_logger, set_correlation_id from ..utils.validation import validate_symbol from ..caching.redis_manager import redis_manager from .response_formatter import ResponseFormatter logger = get_logger(__name__) class WebSocketManager: """ Manages WebSocket connections and real-time data streaming. """ def __init__(self): """Initialize WebSocket manager""" # Active connections: connection_id -> WebSocket self.connections: Dict[str, WebSocket] = {} # Subscriptions: symbol -> set of connection_ids self.subscriptions: Dict[str, Set[str]] = {} # Connection metadata: connection_id -> metadata self.connection_metadata: Dict[str, Dict[str, Any]] = {} self.response_formatter = ResponseFormatter() self.connection_counter = 0 logger.info("WebSocket manager initialized") async def connect(self, websocket: WebSocket, client_ip: str) -> str: """ Accept new WebSocket connection. Args: websocket: WebSocket connection client_ip: Client IP address Returns: str: Connection ID """ await websocket.accept() # Generate connection ID self.connection_counter += 1 connection_id = f"ws_{self.connection_counter}_{client_ip}" # Store connection self.connections[connection_id] = websocket self.connection_metadata[connection_id] = { 'client_ip': client_ip, 'connected_at': asyncio.get_event_loop().time(), 'subscriptions': set(), 'messages_sent': 0 } logger.info(f"WebSocket connected: {connection_id}") # Send welcome message welcome_msg = self.response_formatter.success( data={'connection_id': connection_id}, message="WebSocket connected successfully" ) await self._send_to_connection(connection_id, welcome_msg) return connection_id async def disconnect(self, connection_id: str) -> None: """ Handle WebSocket disconnection. Args: connection_id: Connection ID to disconnect """ if connection_id in self.connections: # Remove from all subscriptions metadata = self.connection_metadata.get(connection_id, {}) for symbol in metadata.get('subscriptions', set()): await self._unsubscribe_connection(connection_id, symbol) # Remove connection del self.connections[connection_id] del self.connection_metadata[connection_id] logger.info(f"WebSocket disconnected: {connection_id}") async def subscribe(self, connection_id: str, symbol: str, data_type: str = "heatmap") -> bool: """ Subscribe connection to symbol updates. Args: connection_id: Connection ID symbol: Trading symbol data_type: Type of data to subscribe to Returns: bool: True if subscribed successfully """ try: # Validate symbol if not validate_symbol(symbol): error_msg = self.response_formatter.validation_error("symbol", "Invalid symbol format") await self._send_to_connection(connection_id, error_msg) return False symbol = symbol.upper() subscription_key = f"{symbol}:{data_type}" # Add to subscriptions if subscription_key not in self.subscriptions: self.subscriptions[subscription_key] = set() self.subscriptions[subscription_key].add(connection_id) # Update connection metadata if connection_id in self.connection_metadata: self.connection_metadata[connection_id]['subscriptions'].add(subscription_key) logger.info(f"WebSocket {connection_id} subscribed to {subscription_key}") # Send confirmation confirm_msg = self.response_formatter.success( data={'symbol': symbol, 'data_type': data_type}, message=f"Subscribed to {symbol} {data_type} updates" ) await self._send_to_connection(connection_id, confirm_msg) # Send initial data if available await self._send_initial_data(connection_id, symbol, data_type) return True except Exception as e: logger.error(f"Error subscribing {connection_id} to {symbol}: {e}") error_msg = self.response_formatter.error("Subscription failed", "SUBSCRIBE_ERROR") await self._send_to_connection(connection_id, error_msg) return False async def unsubscribe(self, connection_id: str, symbol: str, data_type: str = "heatmap") -> bool: """ Unsubscribe connection from symbol updates. Args: connection_id: Connection ID symbol: Trading symbol data_type: Type of data to unsubscribe from Returns: bool: True if unsubscribed successfully """ try: symbol = symbol.upper() subscription_key = f"{symbol}:{data_type}" await self._unsubscribe_connection(connection_id, subscription_key) # Send confirmation confirm_msg = self.response_formatter.success( data={'symbol': symbol, 'data_type': data_type}, message=f"Unsubscribed from {symbol} {data_type} updates" ) await self._send_to_connection(connection_id, confirm_msg) return True except Exception as e: logger.error(f"Error unsubscribing {connection_id} from {symbol}: {e}") return False async def broadcast_update(self, symbol: str, data_type: str, data: Any) -> int: """ Broadcast data update to all subscribers. Args: symbol: Trading symbol data_type: Type of data data: Data to broadcast Returns: int: Number of connections notified """ try: set_correlation_id() subscription_key = f"{symbol.upper()}:{data_type}" subscribers = self.subscriptions.get(subscription_key, set()) if not subscribers: return 0 # Format message based on data type if data_type == "heatmap": message = self.response_formatter.heatmap_response(data, symbol) elif data_type == "orderbook": message = self.response_formatter.orderbook_response(data, symbol, "consolidated") else: message = self.response_formatter.success(data, f"{data_type} update for {symbol}") # Add update type to message message['update_type'] = data_type message['symbol'] = symbol # Send to all subscribers sent_count = 0 for connection_id in subscribers.copy(): # Copy to avoid modification during iteration if await self._send_to_connection(connection_id, message): sent_count += 1 logger.debug(f"Broadcasted {data_type} update for {symbol} to {sent_count} connections") return sent_count except Exception as e: logger.error(f"Error broadcasting update for {symbol}: {e}") return 0 async def _send_to_connection(self, connection_id: str, message: Dict[str, Any]) -> bool: """ Send message to specific connection. Args: connection_id: Connection ID message: Message to send Returns: bool: True if sent successfully """ try: if connection_id not in self.connections: return False websocket = self.connections[connection_id] message_json = json.dumps(message, default=str) await websocket.send_text(message_json) # Update statistics if connection_id in self.connection_metadata: self.connection_metadata[connection_id]['messages_sent'] += 1 return True except Exception as e: logger.warning(f"Error sending message to {connection_id}: {e}") # Remove broken connection await self.disconnect(connection_id) return False async def _unsubscribe_connection(self, connection_id: str, subscription_key: str) -> None: """Remove connection from subscription""" if subscription_key in self.subscriptions: self.subscriptions[subscription_key].discard(connection_id) # Clean up empty subscriptions if not self.subscriptions[subscription_key]: del self.subscriptions[subscription_key] # Update connection metadata if connection_id in self.connection_metadata: self.connection_metadata[connection_id]['subscriptions'].discard(subscription_key) async def _send_initial_data(self, connection_id: str, symbol: str, data_type: str) -> None: """Send initial data to newly subscribed connection""" try: if data_type == "heatmap": # Get latest heatmap from cache heatmap_data = await redis_manager.get_heatmap(symbol) if heatmap_data: message = self.response_formatter.heatmap_response(heatmap_data, symbol) message['update_type'] = 'initial_data' await self._send_to_connection(connection_id, message) elif data_type == "orderbook": # Could get latest order book from cache # This would require knowing which exchange to get data from pass except Exception as e: logger.warning(f"Error sending initial data to {connection_id}: {e}") def get_stats(self) -> Dict[str, Any]: """Get WebSocket manager statistics""" total_subscriptions = sum(len(subs) for subs in self.subscriptions.values()) return { 'active_connections': len(self.connections), 'total_subscriptions': total_subscriptions, 'unique_symbols': len(set(key.split(':')[0] for key in self.subscriptions.keys())), 'connection_counter': self.connection_counter } # Global WebSocket manager instance websocket_manager = WebSocketManager() class WebSocketServer: """ WebSocket server for real-time data streaming. """ def __init__(self): """Initialize WebSocket server""" self.manager = websocket_manager logger.info("WebSocket server initialized") async def handle_connection(self, websocket: WebSocket, client_ip: str) -> None: """ Handle WebSocket connection lifecycle. Args: websocket: WebSocket connection client_ip: Client IP address """ connection_id = None try: # Accept connection connection_id = await self.manager.connect(websocket, client_ip) # Handle messages while True: try: # Receive message message = await websocket.receive_text() await self._handle_message(connection_id, message) except WebSocketDisconnect: logger.info(f"WebSocket client disconnected: {connection_id}") break except Exception as e: logger.error(f"WebSocket connection error: {e}") finally: # Clean up connection if connection_id: await self.manager.disconnect(connection_id) async def _handle_message(self, connection_id: str, message: str) -> None: """ Handle incoming WebSocket message. Args: connection_id: Connection ID message: Received message """ try: # Parse message data = json.loads(message) action = data.get('action') if action == 'subscribe': symbol = data.get('symbol') data_type = data.get('data_type', 'heatmap') await self.manager.subscribe(connection_id, symbol, data_type) elif action == 'unsubscribe': symbol = data.get('symbol') data_type = data.get('data_type', 'heatmap') await self.manager.unsubscribe(connection_id, symbol, data_type) elif action == 'ping': # Send pong response pong_msg = self.manager.response_formatter.success( data={'action': 'pong'}, message="Pong" ) await self.manager._send_to_connection(connection_id, pong_msg) else: # Unknown action error_msg = self.manager.response_formatter.error( f"Unknown action: {action}", "UNKNOWN_ACTION" ) await self.manager._send_to_connection(connection_id, error_msg) except json.JSONDecodeError: error_msg = self.manager.response_formatter.error( "Invalid JSON message", "INVALID_JSON" ) await self.manager._send_to_connection(connection_id, error_msg) except Exception as e: logger.error(f"Error handling WebSocket message: {e}") error_msg = self.manager.response_formatter.error( "Message processing failed", "MESSAGE_ERROR" ) await self.manager._send_to_connection(connection_id, error_msg)