435 lines
16 KiB
Python
435 lines
16 KiB
Python
"""
|
|
WebSocket server for real-time replay data streaming.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Dict, Set, Optional, Any
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
from datetime import datetime
|
|
|
|
from ..replay.replay_manager import HistoricalReplayManager
|
|
from ..models.core import OrderBookSnapshot, TradeEvent, ReplayStatus
|
|
from ..utils.logging import get_logger, set_correlation_id
|
|
from ..utils.exceptions import ReplayError
|
|
|
|
# Module-level logger, created via the project's get_logger helper with this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
class ReplayWebSocketManager:
    """
    WebSocket manager for replay data streaming.

    Provides:
    - Real-time replay data streaming
    - Session-based connections
    - Automatic cleanup on disconnect
    - Status updates
    """

    # Maximum number of price levels per side included in an "orderbook" message.
    ORDERBOOK_DEPTH = 10

    def __init__(self, replay_manager: HistoricalReplayManager):
        """
        Initialize WebSocket manager.

        Args:
            replay_manager: Replay manager instance
        """
        self.replay_manager = replay_manager

        # Connection management
        self.connections: Dict[str, Set[WebSocket]] = {}  # session_id -> websockets
        self.websocket_sessions: Dict[WebSocket, str] = {}  # websocket -> session_id

        # Sessions whose replay callbacks have already been registered by this
        # manager. Callbacks are registered once per session (not once per
        # WebSocket) so that every replay event is broadcast exactly once.
        self._callback_sessions: Set[str] = set()

        # Statistics
        self.stats = {
            'active_connections': 0,
            'total_connections': 0,
            'messages_sent': 0,
            'connection_errors': 0
        }

        logger.info("Replay WebSocket manager initialized")

    async def connect_to_session(self, websocket: WebSocket, session_id: str) -> bool:
        """
        Connect WebSocket to a replay session.

        The handshake is accepted before anything is sent, so an unknown
        session is reported to the client as an ``error`` message. In that case
        the socket is left open and untracked; closing it is the caller's job.

        Args:
            websocket: WebSocket connection (must not have been accepted yet)
            session_id: Replay session ID

        Returns:
            bool: True if connected successfully, False otherwise
        """
        tracked = False
        try:
            set_correlation_id()

            # Check if session exists
            session = self.replay_manager.get_replay_status(session_id)

            # Accept first: no frame (not even an error) can be sent while the
            # connection is still in the handshake state.
            await websocket.accept()

            if not session:
                await websocket.send_json({
                    "type": "error",
                    "message": f"Session {session_id} not found"
                })
                return False

            # Add to connection tracking
            self.connections.setdefault(session_id, set()).add(websocket)
            self.websocket_sessions[websocket] = session_id
            tracked = True

            # Update statistics
            self.stats['active_connections'] += 1
            self.stats['total_connections'] += 1

            # Add callbacks to replay session (first connection only)
            self._register_session_callbacks(session_id)

            # Send initial session status
            await self._send_session_status(websocket, session)

            logger.info(f"WebSocket connected to replay session {session_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to connect WebSocket to session {session_id}: {e}")
            self.stats['connection_errors'] += 1
            if tracked:
                # Roll back tracking so a failed connect does not leave a
                # half-registered socket behind.
                await self.disconnect(websocket)
            return False

    def _register_session_callbacks(self, session_id: str) -> None:
        """Register the data/status callbacks for a session exactly once."""
        if session_id in self._callback_sessions:
            return

        async def session_data_callback(data) -> None:
            # Bind the session so data is only delivered to its own clients.
            await self._data_callback(data, session_id)

        self.replay_manager.add_data_callback(session_id, session_data_callback)
        self.replay_manager.add_status_callback(session_id, self._status_callback)
        self._callback_sessions.add(session_id)

    async def disconnect(self, websocket: WebSocket) -> None:
        """
        Disconnect WebSocket and cleanup.

        Only removes the socket from this manager's tracking and updates the
        statistics; it does not close the socket. Calling it for a socket that
        is not tracked is a no-op.

        Args:
            websocket: WebSocket connection to disconnect
        """
        try:
            session_id = self.websocket_sessions.pop(websocket, None)
            if session_id is None:
                return

            # Remove from connection tracking
            peers = self.connections.get(session_id)
            if peers is not None:
                peers.discard(websocket)

                # Clean up empty session connections
                if not peers:
                    del self.connections[session_id]

            # Update statistics
            self.stats['active_connections'] -= 1

            logger.info(f"WebSocket disconnected from replay session {session_id}")

        except Exception as e:
            logger.error(f"Error during WebSocket disconnect: {e}")

    async def handle_websocket_messages(self, websocket: WebSocket) -> None:
        """
        Handle incoming WebSocket messages.

        Runs until the client disconnects or an unexpected error occurs, then
        removes the socket from connection tracking. A frame that is not valid
        JSON is answered with an error message and does not end the loop.

        Args:
            websocket: WebSocket connection
        """
        try:
            while True:
                # Receive message
                try:
                    message = await websocket.receive_json()
                except json.JSONDecodeError:
                    if not await self._send_error(websocket, "Invalid JSON message"):
                        break
                    continue

                # Process message
                await self._process_websocket_message(websocket, message)

        except WebSocketDisconnect:
            logger.info("WebSocket disconnected")
        except Exception as e:
            logger.error(f"WebSocket message handling error: {e}")
            await self._send_error(websocket, "Message processing error")
        finally:
            # Automatic cleanup on disconnect (no-op if already untracked).
            await self.disconnect(websocket)

    async def _send_error(self, websocket: WebSocket, message: str) -> bool:
        """
        Send an ``error`` message without raising.

        Args:
            websocket: WebSocket connection
            message: Human-readable error text

        Returns:
            bool: True if the message was sent, False if sending failed
        """
        try:
            await websocket.send_json({
                "type": "error",
                "message": message
            })
            return True
        except Exception as e:
            logger.warning(f"Failed to send error message to WebSocket: {e}")
            return False

    async def _process_websocket_message(self, websocket: WebSocket, message: Dict[str, Any]) -> None:
        """
        Process incoming WebSocket message.

        Dispatches on ``message["type"]``: ``control``, ``seek``, ``speed`` or
        ``status``. Anything else is answered with an error message.

        Args:
            websocket: WebSocket connection
            message: Received message
        """
        try:
            if not isinstance(message, dict):
                await self._send_error(websocket, "Message must be a JSON object")
                return

            message_type = message.get('type')
            session_id = self.websocket_sessions.get(websocket)

            if not session_id:
                await self._send_error(websocket, "Not connected to any session")
                return

            if message_type == "control":
                await self._handle_control_message(websocket, session_id, message)
            elif message_type == "seek":
                await self._handle_seek_message(websocket, session_id, message)
            elif message_type == "speed":
                await self._handle_speed_message(websocket, session_id, message)
            elif message_type == "status":
                await self._handle_status_request(websocket, session_id)
            else:
                await self._send_error(websocket, f"Unknown message type: {message_type}")

        except Exception as e:
            logger.error(f"Error processing WebSocket message: {e}")
            await self._send_error(websocket, "Message processing failed")

    async def _handle_control_message(self, websocket: WebSocket, session_id: str,
                                      message: Dict[str, Any]) -> None:
        """Handle replay control messages (start / pause / resume / stop)."""
        try:
            action = message.get('action')

            if action == "start":
                await self.replay_manager.start_replay(session_id)
            elif action == "pause":
                await self.replay_manager.pause_replay(session_id)
            elif action == "resume":
                await self.replay_manager.resume_replay(session_id)
            elif action == "stop":
                await self.replay_manager.stop_replay(session_id)
            else:
                await self._send_error(websocket, f"Invalid control action: {action}")
                return

            await websocket.send_json({
                "type": "control_response",
                "action": action,
                "status": "success"
            })

        except ReplayError as e:
            # Replay errors are reported to the client verbatim.
            await self._send_error(websocket, str(e))
        except Exception as e:
            logger.error(f"Control message error: {e}")
            await self._send_error(websocket, "Control action failed")

    async def _handle_seek_message(self, websocket: WebSocket, session_id: str,
                                   message: Dict[str, Any]) -> None:
        """Handle seek messages; ``timestamp`` must be an ISO-8601 string."""
        try:
            timestamp_str = message.get('timestamp')
            if not timestamp_str:
                await self._send_error(websocket, "Timestamp required for seek")
                return

            if not isinstance(timestamp_str, str):
                await self._send_error(websocket, "Invalid timestamp format")
                return

            try:
                # fromisoformat() does not accept a trailing "Z" on older
                # Python versions, so normalize it to an explicit UTC offset.
                timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
            except ValueError:
                await self._send_error(websocket, "Invalid timestamp format")
                return

            success = self.replay_manager.seek_replay(session_id, timestamp)

            await websocket.send_json({
                "type": "seek_response",
                "timestamp": timestamp_str,
                "status": "success" if success else "failed"
            })

        except Exception as e:
            logger.error(f"Seek message error: {e}")
            await self._send_error(websocket, "Seek failed")

    async def _handle_speed_message(self, websocket: WebSocket, session_id: str,
                                    message: Dict[str, Any]) -> None:
        """Handle speed change messages; ``speed`` must be a finite number > 0."""
        try:
            speed = message.get('speed')

            # bool is a subclass of int, so exclude it explicitly. The chained
            # comparison also rejects NaN and infinity.
            is_number = isinstance(speed, (int, float)) and not isinstance(speed, bool)
            if not is_number or not 0 < speed < float('inf'):
                await self._send_error(websocket, "Valid speed required")
                return

            success = self.replay_manager.set_replay_speed(session_id, speed)

            await websocket.send_json({
                "type": "speed_response",
                "speed": speed,
                "status": "success" if success else "failed"
            })

        except Exception as e:
            logger.error(f"Speed message error: {e}")
            await self._send_error(websocket, "Speed change failed")

    async def _handle_status_request(self, websocket: WebSocket, session_id: str) -> None:
        """Handle status request messages."""
        try:
            session = self.replay_manager.get_replay_status(session_id)
            if session:
                await self._send_session_status(websocket, session)
            else:
                await self._send_error(websocket, "Session not found")

        except Exception as e:
            logger.error(f"Status request error: {e}")
            await self._send_error(websocket, "Status request failed")

    def _serialize_replay_data(self, data) -> Optional[Dict[str, Any]]:
        """Convert a replay event into a JSON-ready message, or None if unsupported."""
        if isinstance(data, OrderBookSnapshot):
            depth = self.ORDERBOOK_DEPTH
            return {
                "type": "orderbook",
                "data": {
                    "symbol": data.symbol,
                    "exchange": data.exchange,
                    "timestamp": data.timestamp.isoformat(),
                    "bids": [{"price": b.price, "size": b.size} for b in data.bids[:depth]],
                    "asks": [{"price": a.price, "size": a.size} for a in data.asks[:depth]],
                    "sequence_id": data.sequence_id
                }
            }

        if isinstance(data, TradeEvent):
            return {
                "type": "trade",
                "data": {
                    "symbol": data.symbol,
                    "exchange": data.exchange,
                    "timestamp": data.timestamp.isoformat(),
                    "price": data.price,
                    "size": data.size,
                    "side": data.side,
                    "trade_id": data.trade_id
                }
            }

        return None

    async def _data_callback(self, data, session_id: Optional[str] = None) -> None:
        """
        Callback for replay data.

        Args:
            data: Replay event; order book snapshots and trade events are
                forwarded, anything else is ignored
            session_id: Session the event belongs to. When given, only that
                session's WebSockets receive it; when omitted, the event is
                broadcast to every connected WebSocket.
        """
        try:
            message = self._serialize_replay_data(data)
            if message is None:
                return

            if session_id is None:
                await self._broadcast_message(message)
            else:
                await self._broadcast_to_session(session_id, message)

        except Exception as e:
            logger.error(f"Data callback error: {e}")

    async def _status_callback(self, session_id: str, status: ReplayStatus) -> None:
        """Callback for replay status changes."""
        try:
            message = {
                "type": "status",
                "session_id": session_id,
                "status": status.value,
                # NOTE(review): naive UTC timestamp kept to preserve the wire
                # format; datetime.utcnow() is deprecated since Python 3.12.
                "timestamp": datetime.utcnow().isoformat()
            }

            # Send to connections for this session
            if session_id in self.connections:
                await self._broadcast_to_session(session_id, message)

        except Exception as e:
            logger.error(f"Status callback error: {e}")

    async def _send_session_status(self, websocket: WebSocket, session) -> None:
        """Send session status to WebSocket (failures are logged, not raised)."""
        try:
            current_time = session.current_time
            message = {
                "type": "session_status",
                "data": {
                    "session_id": session.session_id,
                    "status": session.status.value,
                    "progress": session.progress,
                    # A session without a current time is reported as null
                    # instead of dropping the whole status message.
                    "current_time": current_time.isoformat() if current_time is not None else None,
                    "speed": session.speed,
                    "events_replayed": session.events_replayed,
                    "total_events": session.total_events
                }
            }

            await websocket.send_json(message)
            self.stats['messages_sent'] += 1

        except Exception as e:
            logger.error(f"Error sending session status: {e}")

    async def _send_to_websockets(self, websockets: list, message: Dict[str, Any]) -> None:
        """Send a message to each WebSocket and untrack the ones that fail."""
        failed = []

        for websocket in websockets:
            try:
                await websocket.send_json(message)
                self.stats['messages_sent'] += 1
            except Exception as e:
                logger.warning(f"Failed to send message to WebSocket: {e}")
                failed.append(websocket)

        # Clean up disconnected WebSockets
        for websocket in failed:
            await self.disconnect(websocket)

    async def _broadcast_message(self, message: Dict[str, Any]) -> None:
        """Broadcast message to all connected WebSockets."""
        # Snapshot the targets before the first await: connects/disconnects
        # that run while a send is pending mutate self.connections, and
        # iterating the live dict would raise RuntimeError.
        targets = [ws for group in self.connections.values() for ws in group]
        await self._send_to_websockets(targets, message)

    async def _broadcast_to_session(self, session_id: str, message: Dict[str, Any]) -> None:
        """Broadcast message to WebSockets connected to a specific session."""
        targets = list(self.connections.get(session_id, ()))
        if not targets:
            return

        await self._send_to_websockets(targets, message)

    def get_stats(self) -> Dict[str, Any]:
        """Get WebSocket manager statistics, including live connection counts."""
        return {
            **self.stats,
            'sessions_with_connections': len(self.connections),
            'total_websockets': sum(len(ws_set) for ws_set in self.connections.values())
        }