"""
Base exchange connector implementation with connection management and error handling.
"""

import asyncio
import json
import websockets
from typing import Dict, List, Optional, Callable, Any
from datetime import datetime, timezone

from ..interfaces.exchange_connector import ExchangeConnector
from ..models.core import ConnectionStatus, OrderBookSnapshot, TradeEvent
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ConnectionError, ValidationError
from ..utils.timing import get_current_timestamp
from .connection_manager import ConnectionManager
from .circuit_breaker import CircuitBreaker, CircuitBreakerOpenError

logger = get_logger(__name__)


class BaseExchangeConnector(ExchangeConnector):
    """
    Base implementation of exchange connector with common functionality.

    Provides:
    - WebSocket connection management (delegated to ``ConnectionManager``)
    - Exponential backoff retry logic (configured on ``ConnectionManager``)
    - Circuit breaker pattern around connection attempts
    - Health monitoring
    - Message handling framework (receiver task -> bounded queue ->
      processor task -> handler selected by ``_get_message_type``)
    - Subscription management (resubscribes after every reconnect)

    Subclasses must implement the subscribe/unsubscribe methods,
    ``get_symbols``, ``normalize_symbol`` and ``get_orderbook_snapshot``,
    and normally override ``_get_message_type`` and fill ``message_handlers``.
    """

    def __init__(self, exchange_name: str, websocket_url: str):
        """
        Initialize base exchange connector.

        Args:
            exchange_name: Name of the exchange
            websocket_url: WebSocket URL for the exchange
        """
        super().__init__(exchange_name)

        self.websocket_url = websocket_url
        # Client connection produced by ``websockets.connect``; None while disconnected.
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self.subscriptions: Dict[str, List[str]] = {}  # symbol -> [subscription_types]
        # message type (as returned by _get_message_type) -> async handler(parsed_message)
        self.message_handlers: Dict[str, Callable] = {}

        # Connection management
        self.connection_manager = ConnectionManager(
            name=f"{exchange_name}_connector",
            max_retries=10,
            initial_delay=1.0,
            max_delay=300.0,
            health_check_interval=30,
        )

        # Circuit breaker
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=5,
            recovery_timeout=60,
            expected_exception=Exception,
            name=f"{exchange_name}_circuit",
        )

        # Statistics
        self.message_count = 0
        self.error_count = 0
        self.last_message_time: Optional[datetime] = None

        # Setup callbacks
        self.connection_manager.on_connect = self._on_connect
        self.connection_manager.on_disconnect = self._on_disconnect
        self.connection_manager.on_error = self._on_error
        self.connection_manager.on_health_check = self._health_check

        # Message processing: the receiver task fills the bounded queue and the
        # processor task drains it. References to both tasks are kept so they
        # are not garbage-collected mid-flight and can be cancelled on disconnect.
        self._message_queue = asyncio.Queue(maxsize=10000)
        self._message_processor_task: Optional[asyncio.Task] = None
        self._message_receiver_task: Optional[asyncio.Task] = None

        logger.info(f"Base connector initialized for {exchange_name}")

    async def connect(self) -> bool:
        """
        Establish connection to the exchange WebSocket.

        Returns:
            bool: The result of ``ConnectionManager.connect``, or False if it raised
            (in which case status callbacks are notified with ERROR).
        """
        try:
            set_correlation_id()
            logger.info(f"Connecting to {self.exchange_name} at {self.websocket_url}")

            return await self.connection_manager.connect(self._establish_websocket_connection)

        except Exception as e:
            logger.error(f"Failed to connect to {self.exchange_name}: {e}")
            self._notify_status_callbacks(ConnectionStatus.ERROR)
            return False

    async def disconnect(self) -> None:
        """Disconnect from the exchange WebSocket; errors are logged, not raised."""
        try:
            set_correlation_id()
            logger.info(f"Disconnecting from {self.exchange_name}")

            await self.connection_manager.disconnect(self._close_websocket_connection)

        except Exception as e:
            logger.error(f"Error during disconnect from {self.exchange_name}: {e}")

    async def _establish_websocket_connection(self) -> None:
        """
        Open the WebSocket through the circuit breaker and start message processing.

        Raises:
            ConnectionError: code "CIRCUIT_BREAKER_OPEN" when the breaker rejects
                the attempt, "WEBSOCKET_CONNECT_FAILED" for any other failure.
                The original exception is chained as ``__cause__``.
        """
        try:
            # Use circuit breaker for connection
            self.websocket = await self.circuit_breaker.call_async(
                websockets.connect,
                self.websocket_url,
                ping_interval=20,
                ping_timeout=10,
                close_timeout=10,
            )

            logger.info(f"WebSocket connected to {self.exchange_name}")

            # Start (or restart) message processing for the new socket
            await self._start_message_processing()

        except CircuitBreakerOpenError as e:
            logger.error(f"Circuit breaker open for {self.exchange_name}: {e}")
            raise ConnectionError(f"Circuit breaker open: {e}", "CIRCUIT_BREAKER_OPEN") from e
        except Exception as e:
            logger.error(f"WebSocket connection failed for {self.exchange_name}: {e}")
            raise ConnectionError(f"WebSocket connection failed: {e}", "WEBSOCKET_CONNECT_FAILED") from e

    async def _close_websocket_connection(self) -> None:
        """Stop message processing and close the WebSocket (best effort, errors are logged)."""
        try:
            # Stop message processing
            await self._stop_message_processing()

            # Close WebSocket
            if self.websocket:
                await self.websocket.close()

            logger.info(f"WebSocket disconnected from {self.exchange_name}")

        except Exception as e:
            logger.warning(f"Error closing WebSocket for {self.exchange_name}: {e}")
        finally:
            # Drop the reference even if close() failed, so _send_message and
            # _health_check never keep using a half-closed socket.
            self.websocket = None

    async def _start_message_processing(self) -> None:
        """
        Start the processor and receiver tasks if they are not already running.

        Each task is checked independently. After a dropped connection the receiver
        has finished while the processor (an endless loop) is still alive, so a
        reconnect must restart the receiver alone. Only one receiver is started at
        a time, so two coroutines never call ``recv()`` on the same socket.
        """
        started = False

        if self._message_processor_task is None or self._message_processor_task.done():
            self._message_processor_task = asyncio.create_task(self._message_processor())
            started = True

        if self._message_receiver_task is None or self._message_receiver_task.done():
            self._message_receiver_task = asyncio.create_task(self._message_receiver())
            started = True

        if started:
            logger.debug(f"Message processing started for {self.exchange_name}")

    async def _stop_message_processing(self) -> None:
        """Cancel the receiver and processor tasks and wait for them to finish."""
        tasks = [
            task
            for task in (self._message_receiver_task, self._message_processor_task)
            if task is not None
        ]
        self._message_receiver_task = None
        self._message_processor_task = None

        if not tasks:
            return

        for task in tasks:
            task.cancel()
        for task in tasks:
            try:
                await task
            except asyncio.CancelledError:
                pass

        logger.debug(f"Message processing stopped for {self.exchange_name}")

    async def _message_receiver(self) -> None:
        """
        Receive messages from the WebSocket and enqueue them for processing.

        Behavior:
        - When the queue is full, the message is dropped and a warning is logged.
        - After 30 s without a message, a ping is sent and the loop continues.
        - The loop stops when the connection closes or on any other receive error.
        - On exit, for any reason including cancellation,
          ``connection_manager.is_connected`` is set to False.
        """
        try:
            while self.websocket and not self.websocket.closed:
                try:
                    message = await asyncio.wait_for(self.websocket.recv(), timeout=30.0)

                    # Queue message for processing
                    try:
                        self._message_queue.put_nowait(message)
                    except asyncio.QueueFull:
                        logger.warning(f"Message queue full for {self.exchange_name}, dropping message")

                except asyncio.TimeoutError:
                    # Send ping to keep connection alive
                    if self.websocket:
                        await self.websocket.ping()

                except websockets.exceptions.ConnectionClosed:
                    logger.warning(f"WebSocket connection closed for {self.exchange_name}")
                    break

                except Exception as e:
                    logger.error(f"Error receiving message from {self.exchange_name}: {e}")
                    self.error_count += 1
                    break

        except Exception as e:
            logger.error(f"Message receiver error for {self.exchange_name}: {e}")

        finally:
            # Mark as disconnected
            self.connection_manager.is_connected = False

    async def _message_processor(self) -> None:
        """
        Process messages from the queue until cancelled.

        ``task_done()`` is called once for every message taken from the queue,
        including when processing fails or is cancelled mid-message, so the
        queue's unfinished-task counter stays balanced.
        """
        while True:
            try:
                message = await self._message_queue.get()
            except asyncio.CancelledError:
                break

            try:
                await self._process_message(message)

                # Update statistics
                self.message_count += 1
                self.last_message_time = get_current_timestamp()

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error processing message for {self.exchange_name}: {e}")
                self.error_count += 1
            finally:
                self._message_queue.task_done()

    async def _process_message(self, message: str) -> None:
        """
        Process incoming WebSocket message.

        Parses the message as JSON and awaits the handler registered in
        ``message_handlers`` for its type. Invalid JSON and handler errors are
        logged, not raised.

        Args:
            message: Raw message string
        """
        try:
            # Parse JSON message
            data = json.loads(message)

            # Determine message type and route to appropriate handler
            message_type = self._get_message_type(data)

            if message_type in self.message_handlers:
                await self.message_handlers[message_type](data)
            else:
                logger.debug(f"Unhandled message type '{message_type}' from {self.exchange_name}")

        except json.JSONDecodeError as e:
            logger.warning(f"Invalid JSON message from {self.exchange_name}: {e}")
        except Exception as e:
            logger.error(f"Error processing message from {self.exchange_name}: {e}")

    def _get_message_type(self, data: Dict) -> str:
        """
        Determine message type from message data.

        Override in subclasses for exchange-specific logic.

        Args:
            data: Parsed message data

        Returns:
            str: Message type identifier (default: ``data['type']`` or 'unknown')
        """
        # Default implementation - override in subclasses
        return data.get('type', 'unknown')

    async def _send_message(self, message: Dict) -> bool:
        """
        Send message to WebSocket.

        Args:
            message: Message to send (serialized with ``json.dumps``)

        Returns:
            bool: True if sent successfully, False otherwise
        """
        try:
            if not self.websocket or self.websocket.closed:
                logger.warning(f"Cannot send message to {self.exchange_name}: not connected")
                return False

            message_str = json.dumps(message)
            await self.websocket.send(message_str)

            logger.debug(f"Sent message to {self.exchange_name}: {message_str[:100]}...")
            return True

        except Exception as e:
            logger.error(f"Error sending message to {self.exchange_name}: {e}")
            return False

    # Callback handlers
    async def _on_connect(self) -> None:
        """Handle successful connection: notify CONNECTED and resubscribe everything."""
        self._notify_status_callbacks(ConnectionStatus.CONNECTED)

        # Resubscribe to all previous subscriptions
        await self._resubscribe_all()

    async def _on_disconnect(self) -> None:
        """Handle disconnection"""
        self._notify_status_callbacks(ConnectionStatus.DISCONNECTED)

    async def _on_error(self, error: Exception) -> None:
        """Handle connection error"""
        logger.error(f"Connection error for {self.exchange_name}: {error}")
        self._notify_status_callbacks(ConnectionStatus.ERROR)

    async def _health_check(self) -> bool:
        """
        Perform health check.

        Returns:
            bool: False if the socket is missing/closed, if no message has been
            processed for more than 60 s, or if pinging fails; True otherwise.
        """
        try:
            if not self.websocket or self.websocket.closed:
                return False

            # Check if we've received messages recently
            if self.last_message_time:
                time_since_last_message = (get_current_timestamp() - self.last_message_time).total_seconds()
                if time_since_last_message > 60:  # No messages for 60 seconds
                    logger.warning(f"No messages received from {self.exchange_name} for {time_since_last_message}s")
                    return False

            # Send ping
            await self.websocket.ping()

            return True

        except Exception as e:
            logger.error(f"Health check failed for {self.exchange_name}: {e}")
            return False

    async def _resubscribe_all(self) -> None:
        """Resubscribe to all previous subscriptions after reconnection."""
        # Iterate over snapshots: subclass subscribe_* implementations may record
        # the subscription in self.subscriptions, and mutating a dict/list while
        # iterating it raises or skips entries.
        for symbol, subscription_types in list(self.subscriptions.items()):
            for sub_type in list(subscription_types):
                try:
                    if sub_type == 'orderbook':
                        await self.subscribe_orderbook(symbol)
                    elif sub_type == 'trades':
                        await self.subscribe_trades(symbol)
                except Exception as e:
                    logger.error(f"Failed to resubscribe to {sub_type} for {symbol}: {e}")

    # Abstract methods that must be implemented by subclasses
    async def subscribe_orderbook(self, symbol: str) -> None:
        """Subscribe to order book updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement subscribe_orderbook")

    async def subscribe_trades(self, symbol: str) -> None:
        """Subscribe to trade updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement subscribe_trades")

    async def unsubscribe_orderbook(self, symbol: str) -> None:
        """Unsubscribe from order book updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement unsubscribe_orderbook")

    async def unsubscribe_trades(self, symbol: str) -> None:
        """Unsubscribe from trade updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement unsubscribe_trades")

    async def get_symbols(self) -> List[str]:
        """Get available symbols - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement get_symbols")

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol format - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement normalize_symbol")

    async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
        """Get order book snapshot - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement get_orderbook_snapshot")

    # Utility methods
    def get_stats(self) -> Dict[str, Any]:
        """
        Get connector statistics.

        Returns:
            dict: Exchange name, connection status/flag, message and error
            counters, ISO timestamp of the last processed message (or None),
            a copy of the subscriptions mapping, connection manager and circuit
            breaker stats, and the current queue size.
        """
        return {
            'exchange': self.exchange_name,
            'connection_status': self.get_connection_status().value,
            'is_connected': self.is_connected,
            'message_count': self.message_count,
            'error_count': self.error_count,
            'last_message_time': self.last_message_time.isoformat() if self.last_message_time else None,
            'subscriptions': dict(self.subscriptions),
            'connection_manager': self.connection_manager.get_stats(),
            'circuit_breaker': self.circuit_breaker.get_stats(),
            'queue_size': self._message_queue.qsize()
        }