Files
gogo2/COBY/api/websocket_server.py
Dobromir Popov fd6ec4eb40 api
2025-08-04 18:38:51 +03:00

400 lines
15 KiB
Python

"""
WebSocket server for real-time data streaming.
"""
import asyncio
import json
from typing import Dict, Set, Optional, Any
from fastapi import WebSocket, WebSocketDisconnect
from ..utils.logging import get_logger, set_correlation_id
from ..utils.validation import validate_symbol
from ..caching.redis_manager import redis_manager
from .response_formatter import ResponseFormatter
# Module-level logger obtained from the project's logging utilities.
logger = get_logger(__name__)
class WebSocketManager:
    """
    Manages WebSocket connections and real-time data streaming.

    Tracks active connections and their metadata, keeps subscriptions keyed by
    ``"<SYMBOL>:<data_type>"``, and fans out updates to every subscribed
    connection.
    """

    def __init__(self):
        """Initialize WebSocket manager"""
        # Active connections: connection_id -> WebSocket
        self.connections: Dict[str, WebSocket] = {}

        # Subscriptions: "SYMBOL:data_type" -> set of connection_ids
        self.subscriptions: Dict[str, Set[str]] = {}

        # Connection metadata: connection_id -> metadata dict with keys
        # 'client_ip', 'connected_at', 'subscriptions' (set of subscription
        # keys) and 'messages_sent'
        self.connection_metadata: Dict[str, Dict[str, Any]] = {}

        self.response_formatter = ResponseFormatter()
        # Incremented per accepted connection; used to make connection IDs unique
        self.connection_counter = 0

        logger.info("WebSocket manager initialized")

    async def connect(self, websocket: WebSocket, client_ip: str) -> str:
        """
        Accept new WebSocket connection.

        Registers the connection, records its metadata and sends a welcome
        message containing the assigned connection ID.

        Args:
            websocket: WebSocket connection
            client_ip: Client IP address

        Returns:
            str: Connection ID
        """
        await websocket.accept()

        # The counter prefix keeps IDs unique even for clients sharing an IP
        self.connection_counter += 1
        connection_id = f"ws_{self.connection_counter}_{client_ip}"

        self.connections[connection_id] = websocket
        self.connection_metadata[connection_id] = {
            'client_ip': client_ip,
            # get_running_loop() is the supported way to reach the loop from
            # inside a coroutine (get_event_loop() is discouraged there)
            'connected_at': asyncio.get_running_loop().time(),
            'subscriptions': set(),
            'messages_sent': 0
        }

        logger.info(f"WebSocket connected: {connection_id}")

        welcome_msg = self.response_formatter.success(
            data={'connection_id': connection_id},
            message="WebSocket connected successfully"
        )
        await self._send_to_connection(connection_id, welcome_msg)

        return connection_id

    async def disconnect(self, connection_id: str) -> None:
        """
        Handle WebSocket disconnection.

        Removes the connection from every subscription and forgets its
        metadata. Calling it for an unknown/already-removed ID is a no-op.

        Args:
            connection_id: Connection ID to disconnect
        """
        if connection_id not in self.connections:
            return

        metadata = self.connection_metadata.get(connection_id, {})
        # Iterate over a snapshot: _unsubscribe_connection() discards from this
        # very set, and mutating a set while iterating it raises RuntimeError.
        for subscription_key in list(metadata.get('subscriptions', ())):
            await self._unsubscribe_connection(connection_id, subscription_key)

        self.connections.pop(connection_id, None)
        self.connection_metadata.pop(connection_id, None)

        logger.info(f"WebSocket disconnected: {connection_id}")

    async def subscribe(self, connection_id: str, symbol: str,
                        data_type: str = "heatmap") -> bool:
        """
        Subscribe connection to symbol updates.

        On success a confirmation is sent, followed by cached initial data when
        available. Validation and internal failures are reported to the client
        as error messages.

        Args:
            connection_id: Connection ID
            symbol: Trading symbol (normalized to upper case)
            data_type: Type of data to subscribe to

        Returns:
            bool: True if subscribed successfully
        """
        if connection_id not in self.connections:
            # The connection was already dropped (e.g. after a failed send);
            # registering it would leave an orphaned subscription behind.
            return False

        try:
            # Reject non-string symbols (e.g. None from a missing field) up
            # front so validate_symbol() and str.upper() only ever see strings.
            if not isinstance(symbol, str) or not validate_symbol(symbol):
                error_msg = self.response_formatter.validation_error("symbol", "Invalid symbol format")
                await self._send_to_connection(connection_id, error_msg)
                return False

            symbol = symbol.upper()
            subscription_key = f"{symbol}:{data_type}"

            self.subscriptions.setdefault(subscription_key, set()).add(connection_id)

            metadata = self.connection_metadata.get(connection_id)
            if metadata is not None:
                metadata['subscriptions'].add(subscription_key)

            logger.info(f"WebSocket {connection_id} subscribed to {subscription_key}")

            confirm_msg = self.response_formatter.success(
                data={'symbol': symbol, 'data_type': data_type},
                message=f"Subscribed to {symbol} {data_type} updates"
            )
            await self._send_to_connection(connection_id, confirm_msg)

            await self._send_initial_data(connection_id, symbol, data_type)

            return True

        except Exception as e:
            logger.error(f"Error subscribing {connection_id} to {symbol}: {e}")
            error_msg = self.response_formatter.error("Subscription failed", "SUBSCRIBE_ERROR")
            await self._send_to_connection(connection_id, error_msg)
            return False

    async def unsubscribe(self, connection_id: str, symbol: str,
                          data_type: str = "heatmap") -> bool:
        """
        Unsubscribe connection from symbol updates.

        Unsubscribing from a key the connection is not subscribed to still
        succeeds and sends a confirmation.

        Args:
            connection_id: Connection ID
            symbol: Trading symbol (normalized to upper case)
            data_type: Type of data to unsubscribe from

        Returns:
            bool: True if unsubscribed successfully
        """
        try:
            # Mirror subscribe(): a non-string symbol is a client error, not
            # an internal failure.
            if not isinstance(symbol, str):
                error_msg = self.response_formatter.validation_error("symbol", "Invalid symbol format")
                await self._send_to_connection(connection_id, error_msg)
                return False

            symbol = symbol.upper()
            subscription_key = f"{symbol}:{data_type}"

            await self._unsubscribe_connection(connection_id, subscription_key)

            confirm_msg = self.response_formatter.success(
                data={'symbol': symbol, 'data_type': data_type},
                message=f"Unsubscribed from {symbol} {data_type} updates"
            )
            await self._send_to_connection(connection_id, confirm_msg)

            return True

        except Exception as e:
            logger.error(f"Error unsubscribing {connection_id} from {symbol}: {e}")
            # Tell the client, consistently with subscribe()'s failure path
            error_msg = self.response_formatter.error("Unsubscription failed", "UNSUBSCRIBE_ERROR")
            await self._send_to_connection(connection_id, error_msg)
            return False

    async def broadcast_update(self, symbol: str, data_type: str, data: Any) -> int:
        """
        Broadcast data update to all subscribers.

        Subscribers are looked up by the upper-cased symbol; the outgoing
        message's 'symbol' field carries ``symbol`` exactly as passed in.
        Connections whose send fails are disconnected along the way.

        Args:
            symbol: Trading symbol
            data_type: Type of data
            data: Data to broadcast

        Returns:
            int: Number of connections notified
        """
        try:
            set_correlation_id()

            subscription_key = f"{symbol.upper()}:{data_type}"
            subscribers = self.subscriptions.get(subscription_key, set())

            if not subscribers:
                return 0

            if data_type == "heatmap":
                message = self.response_formatter.heatmap_response(data, symbol)
            elif data_type == "orderbook":
                message = self.response_formatter.orderbook_response(data, symbol, "consolidated")
            else:
                message = self.response_formatter.success(data, f"{data_type} update for {symbol}")

            message['update_type'] = data_type
            message['symbol'] = symbol

            sent_count = 0
            # Snapshot: a failed send disconnects the client, which mutates
            # (and may delete) the live subscriber set.
            for connection_id in subscribers.copy():
                if await self._send_to_connection(connection_id, message):
                    sent_count += 1

            logger.debug(f"Broadcasted {data_type} update for {symbol} to {sent_count} connections")
            return sent_count

        except Exception as e:
            logger.error(f"Error broadcasting update for {symbol}: {e}")
            return 0

    async def _send_to_connection(self, connection_id: str, message: Dict[str, Any]) -> bool:
        """
        Send message to specific connection.

        Never raises. A failed socket send disconnects the connection; a
        message that cannot be serialized is logged but leaves the client
        connected, since the socket itself is not at fault.

        Args:
            connection_id: Connection ID
            message: Message to send

        Returns:
            bool: True if sent successfully
        """
        websocket = self.connections.get(connection_id)
        if websocket is None:
            return False

        try:
            message_json = json.dumps(message, default=str)
        except Exception as e:
            # default=str can still fail (circular refs, non-string dict keys,
            # a raising __str__); that is a server-side payload bug.
            logger.error(f"Error serializing message for {connection_id}: {e}")
            return False

        try:
            await websocket.send_text(message_json)
        except Exception as e:
            logger.warning(f"Error sending message to {connection_id}: {e}")
            # Remove broken connection
            await self.disconnect(connection_id)
            return False

        metadata = self.connection_metadata.get(connection_id)
        if metadata is not None:
            metadata['messages_sent'] += 1

        return True

    async def _unsubscribe_connection(self, connection_id: str, subscription_key: str) -> None:
        """
        Remove connection from subscription.

        Drops the subscription entry entirely once it has no subscribers left
        and removes the key from the connection's metadata. Safe to call for
        keys or connections that are not registered.
        """
        subscribers = self.subscriptions.get(subscription_key)
        if subscribers is not None:
            subscribers.discard(connection_id)
            if not subscribers:
                del self.subscriptions[subscription_key]

        metadata = self.connection_metadata.get(connection_id)
        if metadata is not None:
            metadata['subscriptions'].discard(subscription_key)

    async def _send_initial_data(self, connection_id: str, symbol: str, data_type: str) -> None:
        """
        Send initial data to newly subscribed connection.

        Only 'heatmap' is served (from redis_manager's cache); other data
        types send nothing. Errors are logged and swallowed so they never fail
        the subscription itself.
        """
        try:
            if data_type == "heatmap":
                heatmap_data = await redis_manager.get_heatmap(symbol)
                if heatmap_data:
                    message = self.response_formatter.heatmap_response(heatmap_data, symbol)
                    message['update_type'] = 'initial_data'
                    await self._send_to_connection(connection_id, message)

            elif data_type == "orderbook":
                # Could get latest order book from cache
                # This would require knowing which exchange to get data from
                pass

        except Exception as e:
            logger.warning(f"Error sending initial data to {connection_id}: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get WebSocket manager statistics.

        Returns:
            dict with 'active_connections', 'total_subscriptions' (sum of
            subscribers over all keys), 'unique_symbols' (distinct symbol
            parts of the subscription keys) and 'connection_counter'.
        """
        total_subscriptions = sum(len(subs) for subs in self.subscriptions.values())
        unique_symbols = {key.partition(':')[0] for key in self.subscriptions}

        return {
            'active_connections': len(self.connections),
            'total_subscriptions': total_subscriptions,
            'unique_symbols': len(unique_symbols),
            'connection_counter': self.connection_counter
        }
# Global WebSocket manager instance. WebSocketServer.__init__ assigns this same
# object to self.manager, so every server instance shares its connections and
# subscriptions.
websocket_manager = WebSocketManager()
class WebSocketServer:
    """
    WebSocket server for real-time data streaming.

    Runs the per-connection receive loop and dispatches client messages
    (actions 'subscribe', 'unsubscribe', 'ping') to the module-level
    websocket_manager.
    """

    def __init__(self):
        """Initialize WebSocket server"""
        self.manager = websocket_manager
        logger.info("WebSocket server initialized")

    async def handle_connection(self, websocket: WebSocket, client_ip: str) -> None:
        """
        Handle WebSocket connection lifecycle.

        Accepts the connection, processes text messages until the client
        disconnects (or the manager drops the connection), and always
        unregisters the connection on exit.

        Args:
            websocket: WebSocket connection
            client_ip: Client IP address
        """
        connection_id = None

        try:
            connection_id = await self.manager.connect(websocket, client_ip)

            # Stop reading once the manager has dropped the connection (e.g.
            # after a failed send); otherwise later messages would operate on a
            # connection the manager no longer tracks.
            while connection_id in self.manager.connections:
                try:
                    message = await websocket.receive_text()
                except WebSocketDisconnect:
                    logger.info(f"WebSocket client disconnected: {connection_id}")
                    break

                await self._handle_message(connection_id, message)

        except Exception as e:
            logger.error(f"WebSocket connection error: {e}")

        finally:
            if connection_id:
                await self.manager.disconnect(connection_id)

    async def _handle_message(self, connection_id: str, message: str) -> None:
        """
        Handle incoming WebSocket message.

        Expects a JSON object with an 'action' key. 'subscribe'/'unsubscribe'
        require a non-empty string 'symbol' and accept an optional string
        'data_type' (default 'heatmap'); 'ping' is answered with a pong.
        Malformed input is answered with an error message rather than raised.

        Args:
            connection_id: Connection ID
            message: Received message
        """
        formatter = self.manager.response_formatter

        try:
            data = json.loads(message)

            # Valid JSON that is not an object (list, number, string, null)
            # has no .get(); report it as a client error.
            if not isinstance(data, dict):
                error_msg = formatter.error(
                    "Message must be a JSON object",
                    "INVALID_MESSAGE"
                )
                await self.manager._send_to_connection(connection_id, error_msg)
                return

            action = data.get('action')

            if action in ('subscribe', 'unsubscribe'):
                symbol = data.get('symbol')
                data_type = data.get('data_type', 'heatmap')

                if not isinstance(symbol, str) or not symbol:
                    error_msg = formatter.validation_error("symbol", "Symbol is required")
                    await self.manager._send_to_connection(connection_id, error_msg)
                    return

                if not isinstance(data_type, str):
                    error_msg = formatter.validation_error("data_type", "data_type must be a string")
                    await self.manager._send_to_connection(connection_id, error_msg)
                    return

                if action == 'subscribe':
                    await self.manager.subscribe(connection_id, symbol, data_type)
                else:
                    await self.manager.unsubscribe(connection_id, symbol, data_type)

            elif action == 'ping':
                pong_msg = formatter.success(
                    data={'action': 'pong'},
                    message="Pong"
                )
                await self.manager._send_to_connection(connection_id, pong_msg)

            else:
                error_msg = formatter.error(
                    f"Unknown action: {action}",
                    "UNKNOWN_ACTION"
                )
                await self.manager._send_to_connection(connection_id, error_msg)

        except json.JSONDecodeError:
            error_msg = formatter.error(
                "Invalid JSON message",
                "INVALID_JSON"
            )
            await self.manager._send_to_connection(connection_id, error_msg)

        except Exception as e:
            logger.error(f"Error handling WebSocket message: {e}")
            error_msg = formatter.error(
                "Message processing failed",
                "MESSAGE_ERROR"
            )
            await self.manager._send_to_connection(connection_id, error_msg)