api
This commit is contained in:
400
COBY/api/websocket_server.py
Normal file
400
COBY/api/websocket_server.py
Normal file
@ -0,0 +1,400 @@
|
||||
"""
|
||||
WebSocket server for real-time data streaming.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, Set, Optional, Any
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from ..utils.logging import get_logger, set_correlation_id
|
||||
from ..utils.validation import validate_symbol
|
||||
from ..caching.redis_manager import redis_manager
|
||||
from .response_formatter import ResponseFormatter
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""
|
||||
Manages WebSocket connections and real-time data streaming.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize WebSocket manager"""
|
||||
# Active connections: connection_id -> WebSocket
|
||||
self.connections: Dict[str, WebSocket] = {}
|
||||
|
||||
# Subscriptions: symbol -> set of connection_ids
|
||||
self.subscriptions: Dict[str, Set[str]] = {}
|
||||
|
||||
# Connection metadata: connection_id -> metadata
|
||||
self.connection_metadata: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
self.response_formatter = ResponseFormatter()
|
||||
self.connection_counter = 0
|
||||
|
||||
logger.info("WebSocket manager initialized")
|
||||
|
||||
async def connect(self, websocket: WebSocket, client_ip: str) -> str:
|
||||
"""
|
||||
Accept new WebSocket connection.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
client_ip: Client IP address
|
||||
|
||||
Returns:
|
||||
str: Connection ID
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
# Generate connection ID
|
||||
self.connection_counter += 1
|
||||
connection_id = f"ws_{self.connection_counter}_{client_ip}"
|
||||
|
||||
# Store connection
|
||||
self.connections[connection_id] = websocket
|
||||
self.connection_metadata[connection_id] = {
|
||||
'client_ip': client_ip,
|
||||
'connected_at': asyncio.get_event_loop().time(),
|
||||
'subscriptions': set(),
|
||||
'messages_sent': 0
|
||||
}
|
||||
|
||||
logger.info(f"WebSocket connected: {connection_id}")
|
||||
|
||||
# Send welcome message
|
||||
welcome_msg = self.response_formatter.success(
|
||||
data={'connection_id': connection_id},
|
||||
message="WebSocket connected successfully"
|
||||
)
|
||||
await self._send_to_connection(connection_id, welcome_msg)
|
||||
|
||||
return connection_id
|
||||
|
||||
async def disconnect(self, connection_id: str) -> None:
|
||||
"""
|
||||
Handle WebSocket disconnection.
|
||||
|
||||
Args:
|
||||
connection_id: Connection ID to disconnect
|
||||
"""
|
||||
if connection_id in self.connections:
|
||||
# Remove from all subscriptions
|
||||
metadata = self.connection_metadata.get(connection_id, {})
|
||||
for symbol in metadata.get('subscriptions', set()):
|
||||
await self._unsubscribe_connection(connection_id, symbol)
|
||||
|
||||
# Remove connection
|
||||
del self.connections[connection_id]
|
||||
del self.connection_metadata[connection_id]
|
||||
|
||||
logger.info(f"WebSocket disconnected: {connection_id}")
|
||||
|
||||
async def subscribe(self, connection_id: str, symbol: str,
|
||||
data_type: str = "heatmap") -> bool:
|
||||
"""
|
||||
Subscribe connection to symbol updates.
|
||||
|
||||
Args:
|
||||
connection_id: Connection ID
|
||||
symbol: Trading symbol
|
||||
data_type: Type of data to subscribe to
|
||||
|
||||
Returns:
|
||||
bool: True if subscribed successfully
|
||||
"""
|
||||
try:
|
||||
# Validate symbol
|
||||
if not validate_symbol(symbol):
|
||||
error_msg = self.response_formatter.validation_error("symbol", "Invalid symbol format")
|
||||
await self._send_to_connection(connection_id, error_msg)
|
||||
return False
|
||||
|
||||
symbol = symbol.upper()
|
||||
subscription_key = f"{symbol}:{data_type}"
|
||||
|
||||
# Add to subscriptions
|
||||
if subscription_key not in self.subscriptions:
|
||||
self.subscriptions[subscription_key] = set()
|
||||
|
||||
self.subscriptions[subscription_key].add(connection_id)
|
||||
|
||||
# Update connection metadata
|
||||
if connection_id in self.connection_metadata:
|
||||
self.connection_metadata[connection_id]['subscriptions'].add(subscription_key)
|
||||
|
||||
logger.info(f"WebSocket {connection_id} subscribed to {subscription_key}")
|
||||
|
||||
# Send confirmation
|
||||
confirm_msg = self.response_formatter.success(
|
||||
data={'symbol': symbol, 'data_type': data_type},
|
||||
message=f"Subscribed to {symbol} {data_type} updates"
|
||||
)
|
||||
await self._send_to_connection(connection_id, confirm_msg)
|
||||
|
||||
# Send initial data if available
|
||||
await self._send_initial_data(connection_id, symbol, data_type)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error subscribing {connection_id} to {symbol}: {e}")
|
||||
error_msg = self.response_formatter.error("Subscription failed", "SUBSCRIBE_ERROR")
|
||||
await self._send_to_connection(connection_id, error_msg)
|
||||
return False
|
||||
|
||||
async def unsubscribe(self, connection_id: str, symbol: str,
|
||||
data_type: str = "heatmap") -> bool:
|
||||
"""
|
||||
Unsubscribe connection from symbol updates.
|
||||
|
||||
Args:
|
||||
connection_id: Connection ID
|
||||
symbol: Trading symbol
|
||||
data_type: Type of data to unsubscribe from
|
||||
|
||||
Returns:
|
||||
bool: True if unsubscribed successfully
|
||||
"""
|
||||
try:
|
||||
symbol = symbol.upper()
|
||||
subscription_key = f"{symbol}:{data_type}"
|
||||
|
||||
await self._unsubscribe_connection(connection_id, subscription_key)
|
||||
|
||||
# Send confirmation
|
||||
confirm_msg = self.response_formatter.success(
|
||||
data={'symbol': symbol, 'data_type': data_type},
|
||||
message=f"Unsubscribed from {symbol} {data_type} updates"
|
||||
)
|
||||
await self._send_to_connection(connection_id, confirm_msg)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error unsubscribing {connection_id} from {symbol}: {e}")
|
||||
return False
|
||||
|
||||
async def broadcast_update(self, symbol: str, data_type: str, data: Any) -> int:
|
||||
"""
|
||||
Broadcast data update to all subscribers.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
data_type: Type of data
|
||||
data: Data to broadcast
|
||||
|
||||
Returns:
|
||||
int: Number of connections notified
|
||||
"""
|
||||
try:
|
||||
set_correlation_id()
|
||||
|
||||
subscription_key = f"{symbol.upper()}:{data_type}"
|
||||
subscribers = self.subscriptions.get(subscription_key, set())
|
||||
|
||||
if not subscribers:
|
||||
return 0
|
||||
|
||||
# Format message based on data type
|
||||
if data_type == "heatmap":
|
||||
message = self.response_formatter.heatmap_response(data, symbol)
|
||||
elif data_type == "orderbook":
|
||||
message = self.response_formatter.orderbook_response(data, symbol, "consolidated")
|
||||
else:
|
||||
message = self.response_formatter.success(data, f"{data_type} update for {symbol}")
|
||||
|
||||
# Add update type to message
|
||||
message['update_type'] = data_type
|
||||
message['symbol'] = symbol
|
||||
|
||||
# Send to all subscribers
|
||||
sent_count = 0
|
||||
for connection_id in subscribers.copy(): # Copy to avoid modification during iteration
|
||||
if await self._send_to_connection(connection_id, message):
|
||||
sent_count += 1
|
||||
|
||||
logger.debug(f"Broadcasted {data_type} update for {symbol} to {sent_count} connections")
|
||||
return sent_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error broadcasting update for {symbol}: {e}")
|
||||
return 0
|
||||
|
||||
async def _send_to_connection(self, connection_id: str, message: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Send message to specific connection.
|
||||
|
||||
Args:
|
||||
connection_id: Connection ID
|
||||
message: Message to send
|
||||
|
||||
Returns:
|
||||
bool: True if sent successfully
|
||||
"""
|
||||
try:
|
||||
if connection_id not in self.connections:
|
||||
return False
|
||||
|
||||
websocket = self.connections[connection_id]
|
||||
message_json = json.dumps(message, default=str)
|
||||
|
||||
await websocket.send_text(message_json)
|
||||
|
||||
# Update statistics
|
||||
if connection_id in self.connection_metadata:
|
||||
self.connection_metadata[connection_id]['messages_sent'] += 1
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending message to {connection_id}: {e}")
|
||||
# Remove broken connection
|
||||
await self.disconnect(connection_id)
|
||||
return False
|
||||
|
||||
async def _unsubscribe_connection(self, connection_id: str, subscription_key: str) -> None:
|
||||
"""Remove connection from subscription"""
|
||||
if subscription_key in self.subscriptions:
|
||||
self.subscriptions[subscription_key].discard(connection_id)
|
||||
|
||||
# Clean up empty subscriptions
|
||||
if not self.subscriptions[subscription_key]:
|
||||
del self.subscriptions[subscription_key]
|
||||
|
||||
# Update connection metadata
|
||||
if connection_id in self.connection_metadata:
|
||||
self.connection_metadata[connection_id]['subscriptions'].discard(subscription_key)
|
||||
|
||||
async def _send_initial_data(self, connection_id: str, symbol: str, data_type: str) -> None:
|
||||
"""Send initial data to newly subscribed connection"""
|
||||
try:
|
||||
if data_type == "heatmap":
|
||||
# Get latest heatmap from cache
|
||||
heatmap_data = await redis_manager.get_heatmap(symbol)
|
||||
if heatmap_data:
|
||||
message = self.response_formatter.heatmap_response(heatmap_data, symbol)
|
||||
message['update_type'] = 'initial_data'
|
||||
await self._send_to_connection(connection_id, message)
|
||||
|
||||
elif data_type == "orderbook":
|
||||
# Could get latest order book from cache
|
||||
# This would require knowing which exchange to get data from
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending initial data to {connection_id}: {e}")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get WebSocket manager statistics"""
|
||||
total_subscriptions = sum(len(subs) for subs in self.subscriptions.values())
|
||||
|
||||
return {
|
||||
'active_connections': len(self.connections),
|
||||
'total_subscriptions': total_subscriptions,
|
||||
'unique_symbols': len(set(key.split(':')[0] for key in self.subscriptions.keys())),
|
||||
'connection_counter': self.connection_counter
|
||||
}
|
||||
|
||||
|
||||
# Global WebSocket manager instance
|
||||
websocket_manager = WebSocketManager()
|
||||
|
||||
|
||||
class WebSocketServer:
|
||||
"""
|
||||
WebSocket server for real-time data streaming.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize WebSocket server"""
|
||||
self.manager = websocket_manager
|
||||
logger.info("WebSocket server initialized")
|
||||
|
||||
async def handle_connection(self, websocket: WebSocket, client_ip: str) -> None:
|
||||
"""
|
||||
Handle WebSocket connection lifecycle.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
client_ip: Client IP address
|
||||
"""
|
||||
connection_id = None
|
||||
|
||||
try:
|
||||
# Accept connection
|
||||
connection_id = await self.manager.connect(websocket, client_ip)
|
||||
|
||||
# Handle messages
|
||||
while True:
|
||||
try:
|
||||
# Receive message
|
||||
message = await websocket.receive_text()
|
||||
await self._handle_message(connection_id, message)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket client disconnected: {connection_id}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket connection error: {e}")
|
||||
|
||||
finally:
|
||||
# Clean up connection
|
||||
if connection_id:
|
||||
await self.manager.disconnect(connection_id)
|
||||
|
||||
async def _handle_message(self, connection_id: str, message: str) -> None:
|
||||
"""
|
||||
Handle incoming WebSocket message.
|
||||
|
||||
Args:
|
||||
connection_id: Connection ID
|
||||
message: Received message
|
||||
"""
|
||||
try:
|
||||
# Parse message
|
||||
data = json.loads(message)
|
||||
action = data.get('action')
|
||||
|
||||
if action == 'subscribe':
|
||||
symbol = data.get('symbol')
|
||||
data_type = data.get('data_type', 'heatmap')
|
||||
await self.manager.subscribe(connection_id, symbol, data_type)
|
||||
|
||||
elif action == 'unsubscribe':
|
||||
symbol = data.get('symbol')
|
||||
data_type = data.get('data_type', 'heatmap')
|
||||
await self.manager.unsubscribe(connection_id, symbol, data_type)
|
||||
|
||||
elif action == 'ping':
|
||||
# Send pong response
|
||||
pong_msg = self.manager.response_formatter.success(
|
||||
data={'action': 'pong'},
|
||||
message="Pong"
|
||||
)
|
||||
await self.manager._send_to_connection(connection_id, pong_msg)
|
||||
|
||||
else:
|
||||
# Unknown action
|
||||
error_msg = self.manager.response_formatter.error(
|
||||
f"Unknown action: {action}",
|
||||
"UNKNOWN_ACTION"
|
||||
)
|
||||
await self.manager._send_to_connection(connection_id, error_msg)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
error_msg = self.manager.response_formatter.error(
|
||||
"Invalid JSON message",
|
||||
"INVALID_JSON"
|
||||
)
|
||||
await self.manager._send_to_connection(connection_id, error_msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling WebSocket message: {e}")
|
||||
error_msg = self.manager.response_formatter.error(
|
||||
"Message processing failed",
|
||||
"MESSAGE_ERROR"
|
||||
)
|
||||
await self.manager._send_to_connection(connection_id, error_msg)
|
Reference in New Issue
Block a user