400 lines
15 KiB
Python
400 lines
15 KiB
Python
"""
WebSocket server for real-time data streaming.
"""
|
|
|
|
import asyncio
|
|
import json
|
|
from typing import Dict, Set, Optional, Any
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
from ..utils.logging import get_logger, set_correlation_id
|
|
from ..utils.validation import validate_symbol
|
|
from ..caching.redis_manager import redis_manager
|
|
from .response_formatter import ResponseFormatter
|
|
|
|
# Module-level logger, obtained from the project's logging helper
# (get_logger is imported from ..utils.logging).
logger = get_logger(__name__)
|
|
|
|
|
|
class WebSocketManager:
    """
    Manages WebSocket connections and real-time data streaming.

    State kept per process:

    * ``connections`` -- connection_id -> accepted ``WebSocket``.
    * ``subscriptions`` -- ``"<SYMBOL>:<data_type>"`` key -> set of
      connection_ids subscribed to that stream.
    * ``connection_metadata`` -- connection_id -> dict with ``client_ip``,
      ``connected_at`` (event-loop clock), ``subscriptions`` (set of
      subscription keys held by that connection) and ``messages_sent``.
    """

    def __init__(self):
        """Initialize WebSocket manager"""
        # Active connections: connection_id -> WebSocket
        self.connections: Dict[str, WebSocket] = {}

        # Subscriptions: "SYMBOL:data_type" -> set of connection_ids
        self.subscriptions: Dict[str, Set[str]] = {}

        # Connection metadata: connection_id -> metadata
        self.connection_metadata: Dict[str, Dict[str, Any]] = {}

        self.response_formatter = ResponseFormatter()

        # Incremented on every connect; makes connection IDs unique.
        self.connection_counter = 0

        logger.info("WebSocket manager initialized")

    async def connect(self, websocket: WebSocket, client_ip: str) -> str:
        """
        Accept new WebSocket connection.

        Args:
            websocket: WebSocket connection
            client_ip: Client IP address

        Returns:
            str: Connection ID of the form ``ws_<counter>_<client_ip>``.
        """
        await websocket.accept()

        # Generate connection ID
        self.connection_counter += 1
        connection_id = f"ws_{self.connection_counter}_{client_ip}"

        # Store connection
        self.connections[connection_id] = websocket
        self.connection_metadata[connection_id] = {
            'client_ip': client_ip,
            # get_running_loop() is the supported way to reach the loop from
            # inside a coroutine; loop.time() is its monotonic clock, not
            # wall-clock time.
            'connected_at': asyncio.get_running_loop().time(),
            'subscriptions': set(),
            'messages_sent': 0,
        }

        logger.info(f"WebSocket connected: {connection_id}")

        # Send welcome message. If this send fails, _send_to_connection has
        # already torn the connection down; the ID is still returned so the
        # caller's own cleanup becomes a no-op.
        welcome_msg = self.response_formatter.success(
            data={'connection_id': connection_id},
            message="WebSocket connected successfully"
        )
        await self._send_to_connection(connection_id, welcome_msg)

        return connection_id

    async def disconnect(self, connection_id: str) -> None:
        """
        Handle WebSocket disconnection.

        Removes the connection from every subscription it holds and drops its
        bookkeeping. Calling it for an unknown/already-removed ID is a no-op,
        which matters because both a failed send and the connection handler's
        cleanup may call it for the same ID.

        Args:
            connection_id: Connection ID to disconnect
        """
        if connection_id not in self.connections:
            return

        # Remove from all subscriptions. Iterate over a snapshot:
        # _unsubscribe_connection discards from this very set, and mutating a
        # set while iterating it raises "RuntimeError: Set changed size during
        # iteration" (which previously left the connection registered).
        metadata = self.connection_metadata.get(connection_id, {})
        for subscription_key in list(metadata.get('subscriptions', ())):
            await self._unsubscribe_connection(connection_id, subscription_key)

        # Remove connection; pop() tolerates partially torn-down state.
        self.connections.pop(connection_id, None)
        self.connection_metadata.pop(connection_id, None)

        logger.info(f"WebSocket disconnected: {connection_id}")

    async def subscribe(self, connection_id: str, symbol: str,
                        data_type: str = "heatmap") -> bool:
        """
        Subscribe connection to symbol updates.

        On success the client receives a confirmation message followed, when
        available, by the latest cached data. On failure it receives an error
        message.

        Args:
            connection_id: Connection ID
            symbol: Trading symbol (case-insensitive; stored upper-cased)
            data_type: Type of data to subscribe to

        Returns:
            bool: True if subscribed successfully
        """
        try:
            # Validate symbol. It comes from client JSON, so it may be missing
            # (None) or not a string at all; report those as invalid too.
            if not isinstance(symbol, str) or not validate_symbol(symbol):
                error_msg = self.response_formatter.validation_error("symbol", "Invalid symbol format")
                await self._send_to_connection(connection_id, error_msg)
                return False

            symbol = symbol.upper()
            subscription_key = f"{symbol}:{data_type}"

            # Add to subscriptions
            self.subscriptions.setdefault(subscription_key, set()).add(connection_id)

            # Update connection metadata
            metadata = self.connection_metadata.get(connection_id)
            if metadata is not None:
                metadata['subscriptions'].add(subscription_key)

            logger.info(f"WebSocket {connection_id} subscribed to {subscription_key}")

            # Send confirmation
            confirm_msg = self.response_formatter.success(
                data={'symbol': symbol, 'data_type': data_type},
                message=f"Subscribed to {symbol} {data_type} updates"
            )
            await self._send_to_connection(connection_id, confirm_msg)

            # Send initial data if available
            await self._send_initial_data(connection_id, symbol, data_type)

            return True

        except Exception as e:
            logger.error(f"Error subscribing {connection_id} to {symbol}: {e}")
            error_msg = self.response_formatter.error("Subscription failed", "SUBSCRIBE_ERROR")
            await self._send_to_connection(connection_id, error_msg)
            return False

    async def unsubscribe(self, connection_id: str, symbol: str,
                          data_type: str = "heatmap") -> bool:
        """
        Unsubscribe connection from symbol updates.

        Unsubscribing from a stream the connection does not hold is not an
        error; a confirmation is still sent.

        Args:
            connection_id: Connection ID
            symbol: Trading symbol (case-insensitive)
            data_type: Type of data to unsubscribe from

        Returns:
            bool: True if unsubscribed successfully
        """
        try:
            symbol = symbol.upper()
            subscription_key = f"{symbol}:{data_type}"

            await self._unsubscribe_connection(connection_id, subscription_key)

            # Send confirmation
            confirm_msg = self.response_formatter.success(
                data={'symbol': symbol, 'data_type': data_type},
                message=f"Unsubscribed from {symbol} {data_type} updates"
            )
            await self._send_to_connection(connection_id, confirm_msg)

            return True

        except Exception as e:
            logger.error(f"Error unsubscribing {connection_id} from {symbol}: {e}")
            # Reply like subscribe() does, so a malformed request (e.g. a
            # missing symbol) is not met with silence.
            error_msg = self.response_formatter.error("Unsubscription failed", "UNSUBSCRIBE_ERROR")
            await self._send_to_connection(connection_id, error_msg)
            return False

    async def broadcast_update(self, symbol: str, data_type: str, data: Any) -> int:
        """
        Broadcast data update to all subscribers.

        Args:
            symbol: Trading symbol (matched case-insensitively)
            data_type: Type of data; "heatmap" and "orderbook" get dedicated
                formatting, anything else is wrapped in a generic success
                response.
            data: Data to broadcast

        Returns:
            int: Number of connections notified (0 on error)
        """
        try:
            set_correlation_id()

            subscription_key = f"{symbol.upper()}:{data_type}"
            subscribers = self.subscriptions.get(subscription_key, set())

            if not subscribers:
                return 0

            # Format message based on data type
            if data_type == "heatmap":
                message = self.response_formatter.heatmap_response(data, symbol)
            elif data_type == "orderbook":
                message = self.response_formatter.orderbook_response(data, symbol, "consolidated")
            else:
                message = self.response_formatter.success(data, f"{data_type} update for {symbol}")

            # Add update type to message
            message['update_type'] = data_type
            message['symbol'] = symbol

            # Send to all subscribers. Iterate a copy: a failed send
            # disconnects that client, which removes it from `subscribers`.
            sent_count = 0
            for connection_id in subscribers.copy():
                if await self._send_to_connection(connection_id, message):
                    sent_count += 1

            logger.debug(f"Broadcasted {data_type} update for {symbol} to {sent_count} connections")
            return sent_count

        except Exception as e:
            logger.error(f"Error broadcasting update for {symbol}: {e}")
            return 0

    async def _send_to_connection(self, connection_id: str, message: Dict[str, Any]) -> bool:
        """
        Send message to specific connection.

        The message is JSON-encoded (non-serializable values fall back to
        ``str``). Any failure is treated as a broken connection, which is then
        disconnected.

        Args:
            connection_id: Connection ID
            message: Message to send

        Returns:
            bool: True if sent successfully
        """
        try:
            websocket = self.connections.get(connection_id)
            if websocket is None:
                return False

            message_json = json.dumps(message, default=str)
            await websocket.send_text(message_json)

            # Update statistics (the connection may have been removed while
            # send_text was awaiting).
            metadata = self.connection_metadata.get(connection_id)
            if metadata is not None:
                metadata['messages_sent'] += 1

            return True

        except Exception as e:
            logger.warning(f"Error sending message to {connection_id}: {e}")
            # Remove broken connection
            await self.disconnect(connection_id)
            return False

    async def _unsubscribe_connection(self, connection_id: str, subscription_key: str) -> None:
        """
        Remove connection from subscription.

        Deletes the subscription entry once it has no subscribers left, and
        drops the key from the connection's own metadata.
        """
        subscribers = self.subscriptions.get(subscription_key)
        if subscribers is not None:
            subscribers.discard(connection_id)

            # Clean up empty subscriptions
            if not subscribers:
                del self.subscriptions[subscription_key]

        # Update connection metadata
        metadata = self.connection_metadata.get(connection_id)
        if metadata is not None:
            metadata['subscriptions'].discard(subscription_key)

    async def _send_initial_data(self, connection_id: str, symbol: str, data_type: str) -> None:
        """
        Send initial data to newly subscribed connection.

        Only "heatmap" is implemented: the cached heatmap (if any) is fetched
        from redis_manager and sent with ``update_type = 'initial_data'``.
        Errors are logged and swallowed, since initial data is best-effort.
        """
        try:
            if data_type == "heatmap":
                # Get latest heatmap from cache
                heatmap_data = await redis_manager.get_heatmap(symbol)
                if heatmap_data:
                    message = self.response_formatter.heatmap_response(heatmap_data, symbol)
                    message['update_type'] = 'initial_data'
                    await self._send_to_connection(connection_id, message)

            elif data_type == "orderbook":
                # Could get latest order book from cache.
                # This would require knowing which exchange to get data from.
                pass

        except Exception as e:
            logger.warning(f"Error sending initial data to {connection_id}: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get WebSocket manager statistics.

        Returns:
            dict with ``active_connections``, ``total_subscriptions`` (sum of
            subscribers over all streams), ``unique_symbols`` (distinct symbol
            part of the subscription keys) and ``connection_counter``.
        """
        total_subscriptions = sum(len(subs) for subs in self.subscriptions.values())
        # Keys are "SYMBOL:data_type"; the symbol is the part before the first ':'.
        unique_symbols = {key.partition(':')[0] for key in self.subscriptions}

        return {
            'active_connections': len(self.connections),
            'total_subscriptions': total_subscriptions,
            'unique_symbols': len(unique_symbols),
            'connection_counter': self.connection_counter
        }
|
|
|
|
|
|
# Global WebSocket manager instance
|
|
# Shared instance: WebSocketServer.__init__ binds self.manager to this object,
# so every server instance operates on the same connection/subscription state.
websocket_manager = WebSocketManager()
|
|
|
|
|
|
class WebSocketServer:
    """
    WebSocket server for real-time data streaming.

    Protocol layer over the module-level ``websocket_manager``: it runs the
    receive loop for each client and maps JSON messages of the form
    ``{"action": ..., "symbol": ..., "data_type": ...}`` onto manager calls.
    Supported actions are "subscribe", "unsubscribe" and "ping".
    """

    def __init__(self):
        """Initialize WebSocket server"""
        self.manager = websocket_manager
        logger.info("WebSocket server initialized")

    async def handle_connection(self, websocket: WebSocket, client_ip: str) -> None:
        """
        Handle WebSocket connection lifecycle.

        Accepts the connection, processes text messages until the client
        disconnects (or an unexpected error occurs), and always releases the
        connection's manager state on exit.

        Args:
            websocket: WebSocket connection
            client_ip: Client IP address
        """
        connection_id = None

        try:
            # Accept connection
            connection_id = await self.manager.connect(websocket, client_ip)

            # Handle messages
            while True:
                # Only receiving can signal a client disconnect;
                # _handle_message reports its own errors to the client.
                try:
                    message = await websocket.receive_text()
                except WebSocketDisconnect:
                    logger.info(f"WebSocket client disconnected: {connection_id}")
                    break

                await self._handle_message(connection_id, message)

        except Exception as e:
            logger.error(f"WebSocket connection error: {e}")

        finally:
            # Clean up connection (no-op if already disconnected by a failed send)
            if connection_id:
                await self.manager.disconnect(connection_id)

    async def _handle_message(self, connection_id: str, message: str) -> None:
        """
        Handle incoming WebSocket message.

        Every failure is reported back to the client as an error message;
        nothing is raised to the caller.

        Args:
            connection_id: Connection ID
            message: Received message (expected to be a JSON object)
        """
        formatter = self.manager.response_formatter

        try:
            # Parse message
            data = json.loads(message)

            # Valid JSON that is not an object (e.g. [], 42, "ping") has no
            # .get(); reject it as a client error instead of letting an
            # AttributeError be logged as a server-side failure.
            if not isinstance(data, dict):
                error_msg = formatter.error(
                    "Message must be a JSON object",
                    "INVALID_MESSAGE"
                )
                await self.manager._send_to_connection(connection_id, error_msg)
                return

            action = data.get('action')

            if action == 'subscribe':
                symbol = data.get('symbol')
                data_type = data.get('data_type', 'heatmap')
                await self.manager.subscribe(connection_id, symbol, data_type)

            elif action == 'unsubscribe':
                symbol = data.get('symbol')
                data_type = data.get('data_type', 'heatmap')
                await self.manager.unsubscribe(connection_id, symbol, data_type)

            elif action == 'ping':
                # Send pong response
                pong_msg = formatter.success(
                    data={'action': 'pong'},
                    message="Pong"
                )
                await self.manager._send_to_connection(connection_id, pong_msg)

            else:
                # Unknown action
                error_msg = formatter.error(
                    f"Unknown action: {action}",
                    "UNKNOWN_ACTION"
                )
                await self.manager._send_to_connection(connection_id, error_msg)

        except json.JSONDecodeError:
            error_msg = formatter.error(
                "Invalid JSON message",
                "INVALID_JSON"
            )
            await self.manager._send_to_connection(connection_id, error_msg)

        except Exception as e:
            logger.error(f"Error handling WebSocket message: {e}")
            error_msg = formatter.error(
                "Message processing failed",
                "MESSAGE_ERROR"
            )
            await self.manager._send_to_connection(connection_id, error_msg)