383 lines
15 KiB
Python
383 lines
15 KiB
Python
"""
|
|
Base exchange connector implementation with connection management and error handling.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import websockets
|
|
from typing import Dict, List, Optional, Callable, Any
|
|
from datetime import datetime, timezone
|
|
|
|
from ..interfaces.exchange_connector import ExchangeConnector
|
|
from ..models.core import ConnectionStatus, OrderBookSnapshot, TradeEvent
|
|
from ..utils.logging import get_logger, set_correlation_id
|
|
from ..utils.exceptions import ConnectionError, ValidationError
|
|
from ..utils.timing import get_current_timestamp
|
|
from .connection_manager import ConnectionManager
|
|
from .circuit_breaker import CircuitBreaker, CircuitBreakerOpenError
|
|
|
|
# Module-level logger used by every BaseExchangeConnector method in this module.
logger = get_logger(__name__)
|
|
|
|
|
|
class BaseExchangeConnector(ExchangeConnector):
    """
    Base implementation of exchange connector with common functionality.

    Provides:
    - WebSocket connection management
    - Exponential backoff retry logic
    - Circuit breaker pattern
    - Health monitoring
    - Message handling framework
    - Subscription management

    Subclasses implement the exchange-specific subscribe/unsubscribe, symbol
    and snapshot methods, register coroutine handlers in ``message_handlers``
    (keyed by the value returned from ``_get_message_type``), and may override
    ``_get_message_type`` for their wire format.
    """

    def __init__(self, exchange_name: str, websocket_url: str):
        """
        Initialize base exchange connector.

        Args:
            exchange_name: Name of the exchange
            websocket_url: WebSocket URL for the exchange
        """
        super().__init__(exchange_name)

        self.websocket_url = websocket_url
        # Client connection returned by ``websockets.connect``; None while disconnected.
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self.subscriptions: Dict[str, List[str]] = {}  # symbol -> [subscription_types]
        # message type (see _get_message_type) -> coroutine function awaited with the parsed message
        self.message_handlers: Dict[str, Callable] = {}

        # Connection management
        self.connection_manager = ConnectionManager(
            name=f"{exchange_name}_connector",
            max_retries=10,
            initial_delay=1.0,
            max_delay=300.0,
            health_check_interval=30,
        )

        # Circuit breaker
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=5,
            recovery_timeout=60,
            expected_exception=Exception,
            name=f"{exchange_name}_circuit",
        )

        # Statistics
        self.message_count = 0
        self.error_count = 0
        self.last_message_time: Optional[datetime] = None

        # Setup callbacks
        self.connection_manager.on_connect = self._on_connect
        self.connection_manager.on_disconnect = self._on_disconnect
        self.connection_manager.on_error = self._on_error
        self.connection_manager.on_health_check = self._health_check

        # Message processing: the receiver task reads the socket and enqueues raw
        # frames, the processor task drains the queue. Both task references are
        # kept so the tasks cannot be garbage-collected and can be cancelled.
        self._message_queue: asyncio.Queue = asyncio.Queue(maxsize=10000)
        self._message_processor_task: Optional[asyncio.Task] = None
        self._message_receiver_task: Optional[asyncio.Task] = None

        logger.info(f"Base connector initialized for {exchange_name}")

    async def connect(self) -> bool:
        """
        Establish connection to the exchange WebSocket.

        Returns:
            bool: Result of ``ConnectionManager.connect``; False if it raised.
        """
        try:
            set_correlation_id()
            logger.info(f"Connecting to {self.exchange_name} at {self.websocket_url}")

            return await self.connection_manager.connect(self._establish_websocket_connection)

        except Exception as e:
            logger.error(f"Failed to connect to {self.exchange_name}: {e}")
            self._notify_status_callbacks(ConnectionStatus.ERROR)
            return False

    async def disconnect(self) -> None:
        """Disconnect from the exchange WebSocket; errors are logged, not raised."""
        try:
            set_correlation_id()
            logger.info(f"Disconnecting from {self.exchange_name}")

            await self.connection_manager.disconnect(self._close_websocket_connection)

        except Exception as e:
            logger.error(f"Error during disconnect from {self.exchange_name}: {e}")

    async def _establish_websocket_connection(self) -> None:
        """
        Open a new WebSocket (through the circuit breaker) and start processing.

        Any socket this call replaces is closed afterwards so reconnects do not
        leak connections.

        Raises:
            ConnectionError: code ``CIRCUIT_BREAKER_OPEN`` when the breaker
                rejects the attempt, ``WEBSOCKET_CONNECT_FAILED`` otherwise.
        """
        previous = self.websocket
        try:
            # Use circuit breaker for connection
            self.websocket = await self.circuit_breaker.call_async(
                websockets.connect,
                self.websocket_url,
                ping_interval=20,
                ping_timeout=10,
                close_timeout=10,
            )

            logger.info(f"WebSocket connected to {self.exchange_name}")

            # Start message processing (binds a fresh receiver to the new socket)
            await self._start_message_processing()

        except CircuitBreakerOpenError as e:
            logger.error(f"Circuit breaker open for {self.exchange_name}: {e}")
            raise ConnectionError(f"Circuit breaker open: {e}", "CIRCUIT_BREAKER_OPEN") from e
        except Exception as e:
            logger.error(f"WebSocket connection failed for {self.exchange_name}: {e}")
            raise ConnectionError(f"WebSocket connection failed: {e}", "WEBSOCKET_CONNECT_FAILED") from e

        # A socket superseded by this reconnect would otherwise never be closed.
        if previous is not None and previous is not self.websocket:
            try:
                await previous.close()
            except Exception as e:
                logger.warning(f"Error closing WebSocket for {self.exchange_name}: {e}")

    async def _close_websocket_connection(self) -> None:
        """Stop message processing and close the WebSocket; errors are logged only."""
        try:
            # Stop message processing
            await self._stop_message_processing()
        except Exception as e:
            logger.warning(f"Error closing WebSocket for {self.exchange_name}: {e}")

        # Detach first so the reference is cleared even if close() fails.
        websocket, self.websocket = self.websocket, None
        try:
            if websocket:
                await websocket.close()

            logger.info(f"WebSocket disconnected from {self.exchange_name}")

        except Exception as e:
            logger.warning(f"Error closing WebSocket for {self.exchange_name}: {e}")

    @staticmethod
    async def _cancel_task(task: Optional[asyncio.Task]) -> None:
        """Cancel *task* and wait for it to finish; no-op for None or finished tasks."""
        if task is None or task.done():
            return
        task.cancel()
        if task is asyncio.current_task():
            # A task cannot await itself; the cancellation is delivered at its next await.
            return
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.warning(f"Background task ended with error during cancellation: {e}")

    async def _start_message_processing(self) -> None:
        """
        Start the queue processor (if not running) and a receiver for the current socket.

        Called once per newly established socket. The processor survives
        reconnects; the receiver is bound to one socket, so any receiver left
        over from a previous connection is cancelled and replaced.
        """
        if self._message_processor_task is None or self._message_processor_task.done():
            self._message_processor_task = asyncio.create_task(self._message_processor())

        previous_receiver, self._message_receiver_task = self._message_receiver_task, None
        await self._cancel_task(previous_receiver)
        self._message_receiver_task = asyncio.create_task(self._message_receiver())

        logger.debug(f"Message processing started for {self.exchange_name}")

    async def _stop_message_processing(self) -> None:
        """Cancel the receiver, then the processor, and wait for both to finish."""
        receiver, self._message_receiver_task = self._message_receiver_task, None
        processor, self._message_processor_task = self._message_processor_task, None

        # Receiver first, so nothing new is enqueued while the processor stops.
        await self._cancel_task(receiver)
        await self._cancel_task(processor)

        logger.debug(f"Message processing stopped for {self.exchange_name}")

    async def _message_receiver(self) -> None:
        """
        Read frames from the socket that was current when this task started.

        Each frame is put on the message queue (dropped with a warning when the
        queue is full). After 30 s without a frame a ping is sent. The loop ends
        when the socket closes, on a receive error, or on cancellation. On exit
        the connection manager is marked disconnected, but only if this
        receiver's socket is still the active one, so a superseded receiver
        cannot clobber the state of a newer connection.
        """
        ws = self.websocket
        try:
            while ws and not ws.closed:
                try:
                    message = await asyncio.wait_for(ws.recv(), timeout=30.0)
                except asyncio.TimeoutError:
                    # Send ping to keep connection alive
                    await ws.ping()
                    continue
                except websockets.exceptions.ConnectionClosed:
                    logger.warning(f"WebSocket connection closed for {self.exchange_name}")
                    break
                except Exception as e:
                    logger.error(f"Error receiving message from {self.exchange_name}: {e}")
                    self.error_count += 1
                    break

                # Queue message for processing
                try:
                    self._message_queue.put_nowait(message)
                except asyncio.QueueFull:
                    logger.warning(f"Message queue full for {self.exchange_name}, dropping message")

        except Exception as e:
            logger.error(f"Message receiver error for {self.exchange_name}: {e}")
        finally:
            # Mark as disconnected (only if this socket has not been replaced)
            if self.websocket is ws:
                self.connection_manager.is_connected = False

    async def _message_processor(self) -> None:
        """
        Drain the message queue until cancelled.

        Successful messages update ``message_count`` and ``last_message_time``.
        ``task_done`` is always called for a dequeued message so that
        ``Queue.join()`` cannot hang on a failed one.
        """
        while True:
            message = await self._message_queue.get()
            try:
                await self._process_message(message)

                # Update statistics
                self.message_count += 1
                self.last_message_time = get_current_timestamp()
            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error(f"Error processing message for {self.exchange_name}: {e}")
                self.error_count += 1
            finally:
                self._message_queue.task_done()

    async def _process_message(self, message: str) -> None:
        """
        Parse an incoming WebSocket frame as JSON and dispatch it to its handler.

        Invalid JSON is logged as a warning. Handler failures are logged and
        counted in ``error_count``. Neither is re-raised.

        Args:
            message: Raw message string
        """
        try:
            data = json.loads(message)

            # Determine message type and route to appropriate handler
            message_type = self._get_message_type(data)
            handler = self.message_handlers.get(message_type)

            if handler is None:
                logger.debug(f"Unhandled message type '{message_type}' from {self.exchange_name}")
            else:
                await handler(data)

        except json.JSONDecodeError as e:
            logger.warning(f"Invalid JSON message from {self.exchange_name}: {e}")
        except Exception as e:
            logger.error(f"Error processing message from {self.exchange_name}: {e}")
            self.error_count += 1

    def _get_message_type(self, data: Any) -> str:
        """
        Determine message type from message data.
        Override in subclasses for exchange-specific logic.

        Args:
            data: Parsed message data (any JSON value)

        Returns:
            str: The ``'type'`` field of a JSON object, or ``'unknown'`` when it
            is missing or the payload is not an object (e.g. an array frame).
        """
        # Default implementation - override in subclasses
        if isinstance(data, dict):
            return data.get('type', 'unknown')
        return 'unknown'

    async def _send_message(self, message: Dict) -> bool:
        """
        Send message to WebSocket.

        Args:
            message: Message to send (serialized with ``json.dumps``)

        Returns:
            bool: True if sent successfully, False otherwise
        """
        try:
            if not self.websocket or self.websocket.closed:
                logger.warning(f"Cannot send message to {self.exchange_name}: not connected")
                return False

            message_str = json.dumps(message)
            await self.websocket.send(message_str)

            logger.debug(f"Sent message to {self.exchange_name}: {message_str[:100]}...")
            return True

        except Exception as e:
            logger.error(f"Error sending message to {self.exchange_name}: {e}")
            return False

    # Callback handlers
    async def _on_connect(self) -> None:
        """Handle successful connection: notify listeners, then resubscribe."""
        self._notify_status_callbacks(ConnectionStatus.CONNECTED)

        # Resubscribe to all previous subscriptions
        await self._resubscribe_all()

    async def _on_disconnect(self) -> None:
        """Handle disconnection"""
        self._notify_status_callbacks(ConnectionStatus.DISCONNECTED)

    async def _on_error(self, error: Exception) -> None:
        """Handle connection error"""
        logger.error(f"Connection error for {self.exchange_name}: {error}")
        self._notify_status_callbacks(ConnectionStatus.ERROR)

    async def _health_check(self) -> bool:
        """
        Perform health check.

        Returns:
            bool: False if the socket is missing/closed, if the last processed
            message is more than 60 s old, or if sending a ping fails;
            True otherwise.
        """
        try:
            if not self.websocket or self.websocket.closed:
                return False

            # Check if we've received messages recently
            if self.last_message_time:
                time_since_last_message = (get_current_timestamp() - self.last_message_time).total_seconds()
                if time_since_last_message > 60:  # No messages for 60 seconds
                    logger.warning(f"No messages received from {self.exchange_name} for {time_since_last_message}s")
                    return False

            # Send ping
            await self.websocket.ping()
            return True

        except Exception as e:
            logger.error(f"Health check failed for {self.exchange_name}: {e}")
            return False

    async def _resubscribe_all(self) -> None:
        """
        Resubscribe to all previous subscriptions after reconnection.

        Iterates over a snapshot of ``subscriptions`` because subclass
        ``subscribe_*`` implementations may modify it while we loop.
        Failures are logged per subscription and do not stop the rest.
        """
        snapshot = [(symbol, list(types)) for symbol, types in self.subscriptions.items()]
        for symbol, subscription_types in snapshot:
            for sub_type in subscription_types:
                try:
                    if sub_type == 'orderbook':
                        await self.subscribe_orderbook(symbol)
                    elif sub_type == 'trades':
                        await self.subscribe_trades(symbol)
                except Exception as e:
                    logger.error(f"Failed to resubscribe to {sub_type} for {symbol}: {e}")

    # Abstract methods that must be implemented by subclasses
    async def subscribe_orderbook(self, symbol: str) -> None:
        """Subscribe to order book updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement subscribe_orderbook")

    async def subscribe_trades(self, symbol: str) -> None:
        """Subscribe to trade updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement subscribe_trades")

    async def unsubscribe_orderbook(self, symbol: str) -> None:
        """Unsubscribe from order book updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement unsubscribe_orderbook")

    async def unsubscribe_trades(self, symbol: str) -> None:
        """Unsubscribe from trade updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement unsubscribe_trades")

    async def get_symbols(self) -> List[str]:
        """Get available symbols - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement get_symbols")

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol format - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement normalize_symbol")

    async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
        """Get order book snapshot - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement get_orderbook_snapshot")

    # Utility methods
    def get_stats(self) -> Dict[str, Any]:
        """
        Get connector statistics.

        The returned ``subscriptions`` mapping holds copies of the internal
        lists, so callers cannot mutate connector state through it.
        """
        return {
            'exchange': self.exchange_name,
            'connection_status': self.get_connection_status().value,
            'is_connected': self.is_connected,
            'message_count': self.message_count,
            'error_count': self.error_count,
            'last_message_time': self.last_message_time.isoformat() if self.last_message_time else None,
            'subscriptions': {symbol: list(types) for symbol, types in self.subscriptions.items()},
            'connection_manager': self.connection_manager.get_stats(),
            'circuit_breaker': self.circuit_breaker.get_stats(),
            'queue_size': self._message_queue.qsize(),
        }