""" WebSocket server for real-time replay data streaming. """ import asyncio import json import logging from typing import Dict, Set, Optional, Any from fastapi import WebSocket, WebSocketDisconnect from datetime import datetime from ..replay.replay_manager import HistoricalReplayManager from ..models.core import OrderBookSnapshot, TradeEvent, ReplayStatus from ..utils.logging import get_logger, set_correlation_id from ..utils.exceptions import ReplayError logger = get_logger(__name__) class ReplayWebSocketManager: """ WebSocket manager for replay data streaming. Provides: - Real-time replay data streaming - Session-based connections - Automatic cleanup on disconnect - Status updates """ def __init__(self, replay_manager: HistoricalReplayManager): """ Initialize WebSocket manager. Args: replay_manager: Replay manager instance """ self.replay_manager = replay_manager # Connection management self.connections: Dict[str, Set[WebSocket]] = {} # session_id -> websockets self.websocket_sessions: Dict[WebSocket, str] = {} # websocket -> session_id # Statistics self.stats = { 'active_connections': 0, 'total_connections': 0, 'messages_sent': 0, 'connection_errors': 0 } logger.info("Replay WebSocket manager initialized") async def connect_to_session(self, websocket: WebSocket, session_id: str) -> bool: """ Connect WebSocket to a replay session. Args: websocket: WebSocket connection session_id: Replay session ID Returns: bool: True if connected successfully, False otherwise """ try: set_correlation_id() # Check if session exists session = self.replay_manager.get_replay_status(session_id) if not session: await websocket.send_json({ "type": "error", "message": f"Session {session_id} not found" }) return False # Accept WebSocket connection await websocket.accept() # Add to connection tracking if session_id not in self.connections: self.connections[session_id] = set() self.connections[session_id].add(websocket) self.websocket_sessions[websocket] = session_id # Update statistics self.stats['active_connections'] += 1 self.stats['total_connections'] += 1 # Add callbacks to replay session self.replay_manager.add_data_callback(session_id, self._data_callback) self.replay_manager.add_status_callback(session_id, self._status_callback) # Send initial session status await self._send_session_status(websocket, session) logger.info(f"WebSocket connected to replay session {session_id}") return True except Exception as e: logger.error(f"Failed to connect WebSocket to session {session_id}: {e}") self.stats['connection_errors'] += 1 return False async def disconnect(self, websocket: WebSocket) -> None: """ Disconnect WebSocket and cleanup. Args: websocket: WebSocket connection to disconnect """ try: session_id = self.websocket_sessions.get(websocket) if session_id: # Remove from connection tracking if session_id in self.connections: self.connections[session_id].discard(websocket) # Clean up empty session connections if not self.connections[session_id]: del self.connections[session_id] del self.websocket_sessions[websocket] # Update statistics self.stats['active_connections'] -= 1 logger.info(f"WebSocket disconnected from replay session {session_id}") except Exception as e: logger.error(f"Error during WebSocket disconnect: {e}") async def handle_websocket_messages(self, websocket: WebSocket) -> None: """ Handle incoming WebSocket messages. Args: websocket: WebSocket connection """ try: while True: # Receive message message = await websocket.receive_json() # Process message await self._process_websocket_message(websocket, message) except WebSocketDisconnect: logger.info("WebSocket disconnected") except Exception as e: logger.error(f"WebSocket message handling error: {e}") await websocket.send_json({ "type": "error", "message": "Message processing error" }) async def _process_websocket_message(self, websocket: WebSocket, message: Dict[str, Any]) -> None: """ Process incoming WebSocket message. Args: websocket: WebSocket connection message: Received message """ try: message_type = message.get('type') session_id = self.websocket_sessions.get(websocket) if not session_id: await websocket.send_json({ "type": "error", "message": "Not connected to any session" }) return if message_type == "control": await self._handle_control_message(websocket, session_id, message) elif message_type == "seek": await self._handle_seek_message(websocket, session_id, message) elif message_type == "speed": await self._handle_speed_message(websocket, session_id, message) elif message_type == "status": await self._handle_status_request(websocket, session_id) else: await websocket.send_json({ "type": "error", "message": f"Unknown message type: {message_type}" }) except Exception as e: logger.error(f"Error processing WebSocket message: {e}") await websocket.send_json({ "type": "error", "message": "Message processing failed" }) async def _handle_control_message(self, websocket: WebSocket, session_id: str, message: Dict[str, Any]) -> None: """Handle replay control messages.""" try: action = message.get('action') if action == "start": await self.replay_manager.start_replay(session_id) elif action == "pause": await self.replay_manager.pause_replay(session_id) elif action == "resume": await self.replay_manager.resume_replay(session_id) elif action == "stop": await self.replay_manager.stop_replay(session_id) else: await websocket.send_json({ "type": "error", "message": f"Invalid control action: {action}" }) return await websocket.send_json({ "type": "control_response", "action": action, "status": "success" }) except ReplayError as e: await websocket.send_json({ "type": "error", "message": str(e) }) except Exception as e: logger.error(f"Control message error: {e}") await websocket.send_json({ "type": "error", "message": "Control action failed" }) async def _handle_seek_message(self, websocket: WebSocket, session_id: str, message: Dict[str, Any]) -> None: """Handle seek messages.""" try: timestamp_str = message.get('timestamp') if not timestamp_str: await websocket.send_json({ "type": "error", "message": "Timestamp required for seek" }) return timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00')) success = self.replay_manager.seek_replay(session_id, timestamp) await websocket.send_json({ "type": "seek_response", "timestamp": timestamp_str, "status": "success" if success else "failed" }) except Exception as e: logger.error(f"Seek message error: {e}") await websocket.send_json({ "type": "error", "message": "Seek failed" }) async def _handle_speed_message(self, websocket: WebSocket, session_id: str, message: Dict[str, Any]) -> None: """Handle speed change messages.""" try: speed = message.get('speed') if not speed or speed <= 0: await websocket.send_json({ "type": "error", "message": "Valid speed required" }) return success = self.replay_manager.set_replay_speed(session_id, speed) await websocket.send_json({ "type": "speed_response", "speed": speed, "status": "success" if success else "failed" }) except Exception as e: logger.error(f"Speed message error: {e}") await websocket.send_json({ "type": "error", "message": "Speed change failed" }) async def _handle_status_request(self, websocket: WebSocket, session_id: str) -> None: """Handle status request messages.""" try: session = self.replay_manager.get_replay_status(session_id) if session: await self._send_session_status(websocket, session) else: await websocket.send_json({ "type": "error", "message": "Session not found" }) except Exception as e: logger.error(f"Status request error: {e}") await websocket.send_json({ "type": "error", "message": "Status request failed" }) async def _data_callback(self, data) -> None: """Callback for replay data - broadcasts to all connected WebSockets.""" try: # Determine which session this data belongs to # This is a simplified approach - in practice, you'd need to track # which session generated this callback # Serialize data if isinstance(data, OrderBookSnapshot): message = { "type": "orderbook", "data": { "symbol": data.symbol, "exchange": data.exchange, "timestamp": data.timestamp.isoformat(), "bids": [{"price": b.price, "size": b.size} for b in data.bids[:10]], "asks": [{"price": a.price, "size": a.size} for a in data.asks[:10]], "sequence_id": data.sequence_id } } elif isinstance(data, TradeEvent): message = { "type": "trade", "data": { "symbol": data.symbol, "exchange": data.exchange, "timestamp": data.timestamp.isoformat(), "price": data.price, "size": data.size, "side": data.side, "trade_id": data.trade_id } } else: return # Broadcast to all connections await self._broadcast_message(message) except Exception as e: logger.error(f"Data callback error: {e}") async def _status_callback(self, session_id: str, status: ReplayStatus) -> None: """Callback for replay status changes.""" try: message = { "type": "status", "session_id": session_id, "status": status.value, "timestamp": datetime.utcnow().isoformat() } # Send to connections for this session if session_id in self.connections: await self._broadcast_to_session(session_id, message) except Exception as e: logger.error(f"Status callback error: {e}") async def _send_session_status(self, websocket: WebSocket, session) -> None: """Send session status to WebSocket.""" try: message = { "type": "session_status", "data": { "session_id": session.session_id, "status": session.status.value, "progress": session.progress, "current_time": session.current_time.isoformat(), "speed": session.speed, "events_replayed": session.events_replayed, "total_events": session.total_events } } await websocket.send_json(message) self.stats['messages_sent'] += 1 except Exception as e: logger.error(f"Error sending session status: {e}") async def _broadcast_message(self, message: Dict[str, Any]) -> None: """Broadcast message to all connected WebSockets.""" disconnected = [] for session_id, websockets in self.connections.items(): for websocket in websockets.copy(): try: await websocket.send_json(message) self.stats['messages_sent'] += 1 except Exception as e: logger.warning(f"Failed to send message to WebSocket: {e}") disconnected.append((session_id, websocket)) # Clean up disconnected WebSockets for session_id, websocket in disconnected: await self.disconnect(websocket) async def _broadcast_to_session(self, session_id: str, message: Dict[str, Any]) -> None: """Broadcast message to WebSockets connected to a specific session.""" if session_id not in self.connections: return disconnected = [] for websocket in self.connections[session_id].copy(): try: await websocket.send_json(message) self.stats['messages_sent'] += 1 except Exception as e: logger.warning(f"Failed to send message to WebSocket: {e}") disconnected.append(websocket) # Clean up disconnected WebSockets for websocket in disconnected: await self.disconnect(websocket) def get_stats(self) -> Dict[str, Any]: """Get WebSocket manager statistics.""" return { **self.stats, 'sessions_with_connections': len(self.connections), 'total_websockets': sum(len(ws_set) for ws_set in self.connections.values()) }