replay system
This commit is contained in:
435
COBY/api/replay_websocket.py
Normal file
435
COBY/api/replay_websocket.py
Normal file
@ -0,0 +1,435 @@
|
||||
"""
|
||||
WebSocket server for real-time replay data streaming.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Set, Optional, Any
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from datetime import datetime
|
||||
|
||||
from ..replay.replay_manager import HistoricalReplayManager
|
||||
from ..models.core import OrderBookSnapshot, TradeEvent, ReplayStatus
|
||||
from ..utils.logging import get_logger, set_correlation_id
|
||||
from ..utils.exceptions import ReplayError
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ReplayWebSocketManager:
|
||||
"""
|
||||
WebSocket manager for replay data streaming.
|
||||
|
||||
Provides:
|
||||
- Real-time replay data streaming
|
||||
- Session-based connections
|
||||
- Automatic cleanup on disconnect
|
||||
- Status updates
|
||||
"""
|
||||
|
||||
def __init__(self, replay_manager: HistoricalReplayManager):
|
||||
"""
|
||||
Initialize WebSocket manager.
|
||||
|
||||
Args:
|
||||
replay_manager: Replay manager instance
|
||||
"""
|
||||
self.replay_manager = replay_manager
|
||||
|
||||
# Connection management
|
||||
self.connections: Dict[str, Set[WebSocket]] = {} # session_id -> websockets
|
||||
self.websocket_sessions: Dict[WebSocket, str] = {} # websocket -> session_id
|
||||
|
||||
# Statistics
|
||||
self.stats = {
|
||||
'active_connections': 0,
|
||||
'total_connections': 0,
|
||||
'messages_sent': 0,
|
||||
'connection_errors': 0
|
||||
}
|
||||
|
||||
logger.info("Replay WebSocket manager initialized")
|
||||
|
||||
async def connect_to_session(self, websocket: WebSocket, session_id: str) -> bool:
|
||||
"""
|
||||
Connect WebSocket to a replay session.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
session_id: Replay session ID
|
||||
|
||||
Returns:
|
||||
bool: True if connected successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
set_correlation_id()
|
||||
|
||||
# Check if session exists
|
||||
session = self.replay_manager.get_replay_status(session_id)
|
||||
if not session:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": f"Session {session_id} not found"
|
||||
})
|
||||
return False
|
||||
|
||||
# Accept WebSocket connection
|
||||
await websocket.accept()
|
||||
|
||||
# Add to connection tracking
|
||||
if session_id not in self.connections:
|
||||
self.connections[session_id] = set()
|
||||
|
||||
self.connections[session_id].add(websocket)
|
||||
self.websocket_sessions[websocket] = session_id
|
||||
|
||||
# Update statistics
|
||||
self.stats['active_connections'] += 1
|
||||
self.stats['total_connections'] += 1
|
||||
|
||||
# Add callbacks to replay session
|
||||
self.replay_manager.add_data_callback(session_id, self._data_callback)
|
||||
self.replay_manager.add_status_callback(session_id, self._status_callback)
|
||||
|
||||
# Send initial session status
|
||||
await self._send_session_status(websocket, session)
|
||||
|
||||
logger.info(f"WebSocket connected to replay session {session_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect WebSocket to session {session_id}: {e}")
|
||||
self.stats['connection_errors'] += 1
|
||||
return False
|
||||
|
||||
async def disconnect(self, websocket: WebSocket) -> None:
|
||||
"""
|
||||
Disconnect WebSocket and cleanup.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection to disconnect
|
||||
"""
|
||||
try:
|
||||
session_id = self.websocket_sessions.get(websocket)
|
||||
|
||||
if session_id:
|
||||
# Remove from connection tracking
|
||||
if session_id in self.connections:
|
||||
self.connections[session_id].discard(websocket)
|
||||
|
||||
# Clean up empty session connections
|
||||
if not self.connections[session_id]:
|
||||
del self.connections[session_id]
|
||||
|
||||
del self.websocket_sessions[websocket]
|
||||
|
||||
# Update statistics
|
||||
self.stats['active_connections'] -= 1
|
||||
|
||||
logger.info(f"WebSocket disconnected from replay session {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during WebSocket disconnect: {e}")
|
||||
|
||||
async def handle_websocket_messages(self, websocket: WebSocket) -> None:
|
||||
"""
|
||||
Handle incoming WebSocket messages.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
# Receive message
|
||||
message = await websocket.receive_json()
|
||||
|
||||
# Process message
|
||||
await self._process_websocket_message(websocket, message)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket message handling error: {e}")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Message processing error"
|
||||
})
|
||||
|
||||
async def _process_websocket_message(self, websocket: WebSocket, message: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Process incoming WebSocket message.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
message: Received message
|
||||
"""
|
||||
try:
|
||||
message_type = message.get('type')
|
||||
session_id = self.websocket_sessions.get(websocket)
|
||||
|
||||
if not session_id:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Not connected to any session"
|
||||
})
|
||||
return
|
||||
|
||||
if message_type == "control":
|
||||
await self._handle_control_message(websocket, session_id, message)
|
||||
elif message_type == "seek":
|
||||
await self._handle_seek_message(websocket, session_id, message)
|
||||
elif message_type == "speed":
|
||||
await self._handle_speed_message(websocket, session_id, message)
|
||||
elif message_type == "status":
|
||||
await self._handle_status_request(websocket, session_id)
|
||||
else:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": f"Unknown message type: {message_type}"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing WebSocket message: {e}")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Message processing failed"
|
||||
})
|
||||
|
||||
async def _handle_control_message(self, websocket: WebSocket, session_id: str,
|
||||
message: Dict[str, Any]) -> None:
|
||||
"""Handle replay control messages."""
|
||||
try:
|
||||
action = message.get('action')
|
||||
|
||||
if action == "start":
|
||||
await self.replay_manager.start_replay(session_id)
|
||||
elif action == "pause":
|
||||
await self.replay_manager.pause_replay(session_id)
|
||||
elif action == "resume":
|
||||
await self.replay_manager.resume_replay(session_id)
|
||||
elif action == "stop":
|
||||
await self.replay_manager.stop_replay(session_id)
|
||||
else:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": f"Invalid control action: {action}"
|
||||
})
|
||||
return
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "control_response",
|
||||
"action": action,
|
||||
"status": "success"
|
||||
})
|
||||
|
||||
except ReplayError as e:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": str(e)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Control message error: {e}")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Control action failed"
|
||||
})
|
||||
|
||||
async def _handle_seek_message(self, websocket: WebSocket, session_id: str,
|
||||
message: Dict[str, Any]) -> None:
|
||||
"""Handle seek messages."""
|
||||
try:
|
||||
timestamp_str = message.get('timestamp')
|
||||
if not timestamp_str:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Timestamp required for seek"
|
||||
})
|
||||
return
|
||||
|
||||
timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
|
||||
success = self.replay_manager.seek_replay(session_id, timestamp)
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "seek_response",
|
||||
"timestamp": timestamp_str,
|
||||
"status": "success" if success else "failed"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Seek message error: {e}")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Seek failed"
|
||||
})
|
||||
|
||||
async def _handle_speed_message(self, websocket: WebSocket, session_id: str,
|
||||
message: Dict[str, Any]) -> None:
|
||||
"""Handle speed change messages."""
|
||||
try:
|
||||
speed = message.get('speed')
|
||||
if not speed or speed <= 0:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Valid speed required"
|
||||
})
|
||||
return
|
||||
|
||||
success = self.replay_manager.set_replay_speed(session_id, speed)
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "speed_response",
|
||||
"speed": speed,
|
||||
"status": "success" if success else "failed"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Speed message error: {e}")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Speed change failed"
|
||||
})
|
||||
|
||||
async def _handle_status_request(self, websocket: WebSocket, session_id: str) -> None:
|
||||
"""Handle status request messages."""
|
||||
try:
|
||||
session = self.replay_manager.get_replay_status(session_id)
|
||||
if session:
|
||||
await self._send_session_status(websocket, session)
|
||||
else:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Session not found"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Status request error: {e}")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Status request failed"
|
||||
})
|
||||
|
||||
async def _data_callback(self, data) -> None:
|
||||
"""Callback for replay data - broadcasts to all connected WebSockets."""
|
||||
try:
|
||||
# Determine which session this data belongs to
|
||||
# This is a simplified approach - in practice, you'd need to track
|
||||
# which session generated this callback
|
||||
|
||||
# Serialize data
|
||||
if isinstance(data, OrderBookSnapshot):
|
||||
message = {
|
||||
"type": "orderbook",
|
||||
"data": {
|
||||
"symbol": data.symbol,
|
||||
"exchange": data.exchange,
|
||||
"timestamp": data.timestamp.isoformat(),
|
||||
"bids": [{"price": b.price, "size": b.size} for b in data.bids[:10]],
|
||||
"asks": [{"price": a.price, "size": a.size} for a in data.asks[:10]],
|
||||
"sequence_id": data.sequence_id
|
||||
}
|
||||
}
|
||||
elif isinstance(data, TradeEvent):
|
||||
message = {
|
||||
"type": "trade",
|
||||
"data": {
|
||||
"symbol": data.symbol,
|
||||
"exchange": data.exchange,
|
||||
"timestamp": data.timestamp.isoformat(),
|
||||
"price": data.price,
|
||||
"size": data.size,
|
||||
"side": data.side,
|
||||
"trade_id": data.trade_id
|
||||
}
|
||||
}
|
||||
else:
|
||||
return
|
||||
|
||||
# Broadcast to all connections
|
||||
await self._broadcast_message(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data callback error: {e}")
|
||||
|
||||
async def _status_callback(self, session_id: str, status: ReplayStatus) -> None:
|
||||
"""Callback for replay status changes."""
|
||||
try:
|
||||
message = {
|
||||
"type": "status",
|
||||
"session_id": session_id,
|
||||
"status": status.value,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Send to connections for this session
|
||||
if session_id in self.connections:
|
||||
await self._broadcast_to_session(session_id, message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Status callback error: {e}")
|
||||
|
||||
async def _send_session_status(self, websocket: WebSocket, session) -> None:
|
||||
"""Send session status to WebSocket."""
|
||||
try:
|
||||
message = {
|
||||
"type": "session_status",
|
||||
"data": {
|
||||
"session_id": session.session_id,
|
||||
"status": session.status.value,
|
||||
"progress": session.progress,
|
||||
"current_time": session.current_time.isoformat(),
|
||||
"speed": session.speed,
|
||||
"events_replayed": session.events_replayed,
|
||||
"total_events": session.total_events
|
||||
}
|
||||
}
|
||||
|
||||
await websocket.send_json(message)
|
||||
self.stats['messages_sent'] += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending session status: {e}")
|
||||
|
||||
async def _broadcast_message(self, message: Dict[str, Any]) -> None:
|
||||
"""Broadcast message to all connected WebSockets."""
|
||||
disconnected = []
|
||||
|
||||
for session_id, websockets in self.connections.items():
|
||||
for websocket in websockets.copy():
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
self.stats['messages_sent'] += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send message to WebSocket: {e}")
|
||||
disconnected.append((session_id, websocket))
|
||||
|
||||
# Clean up disconnected WebSockets
|
||||
for session_id, websocket in disconnected:
|
||||
await self.disconnect(websocket)
|
||||
|
||||
async def _broadcast_to_session(self, session_id: str, message: Dict[str, Any]) -> None:
|
||||
"""Broadcast message to WebSockets connected to a specific session."""
|
||||
if session_id not in self.connections:
|
||||
return
|
||||
|
||||
disconnected = []
|
||||
|
||||
for websocket in self.connections[session_id].copy():
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
self.stats['messages_sent'] += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send message to WebSocket: {e}")
|
||||
disconnected.append(websocket)
|
||||
|
||||
# Clean up disconnected WebSockets
|
||||
for websocket in disconnected:
|
||||
await self.disconnect(websocket)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get WebSocket manager statistics."""
|
||||
return {
|
||||
**self.stats,
|
||||
'sessions_with_connections': len(self.connections),
|
||||
'total_websockets': sum(len(ws_set) for ws_set in self.connections.values())
|
||||
}
|
Reference in New Issue
Block a user