Files
gogo2/COBY/connectors/base_connector.py
2025-08-04 17:12:26 +03:00

383 lines
15 KiB
Python

"""
Base exchange connector implementation with connection management and error handling.
"""
import asyncio
import json
import websockets
from typing import Dict, List, Optional, Callable, Any
from datetime import datetime, timezone
from ..interfaces.exchange_connector import ExchangeConnector
from ..models.core import ConnectionStatus, OrderBookSnapshot, TradeEvent
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ConnectionError, ValidationError
from ..utils.timing import get_current_timestamp
from .connection_manager import ConnectionManager
from .circuit_breaker import CircuitBreaker, CircuitBreakerOpenError
# Module logger; every log call made by the connector class in this file goes through it.
logger = get_logger(__name__)
class BaseExchangeConnector(ExchangeConnector):
    """
    Base implementation of exchange connector with common functionality.

    Provides:
    - WebSocket connection management (delegated to ConnectionManager)
    - Exponential backoff retry logic
    - Circuit breaker pattern around connection attempts
    - Health monitoring
    - Message handling framework (receiver task -> bounded queue -> processor task)
    - Subscription management, including resubscription after reconnect

    Subclasses implement the exchange-specific subscribe/unsubscribe, symbol and
    snapshot methods, and register async handlers in ``message_handlers`` keyed by
    the value returned from ``_get_message_type``.
    """

    def __init__(self, exchange_name: str, websocket_url: str):
        """
        Initialize base exchange connector.

        Args:
            exchange_name: Name of the exchange
            websocket_url: WebSocket URL for the exchange
        """
        super().__init__(exchange_name)
        self.websocket_url = websocket_url
        # Active client connection, or None when disconnected. Typed loosely because the
        # concrete class returned by websockets.connect differs between websockets versions.
        self.websocket: Optional[Any] = None
        self.subscriptions: Dict[str, List[str]] = {}  # symbol -> [subscription_types]
        # message type -> async handler; each handler is awaited with the parsed JSON message
        self.message_handlers: Dict[str, Callable] = {}

        # Connection management
        self.connection_manager = ConnectionManager(
            name=f"{exchange_name}_connector",
            max_retries=10,
            initial_delay=1.0,
            max_delay=300.0,
            health_check_interval=30
        )

        # Circuit breaker
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=5,
            recovery_timeout=60,
            expected_exception=Exception,
            name=f"{exchange_name}_circuit"
        )

        # Statistics
        self.message_count = 0
        self.error_count = 0
        self.last_message_time: Optional[datetime] = None

        # Setup callbacks
        self.connection_manager.on_connect = self._on_connect
        self.connection_manager.on_disconnect = self._on_disconnect
        self.connection_manager.on_error = self._on_error
        self.connection_manager.on_health_check = self._health_check

        # Message processing: the receiver task enqueues raw frames and the processor task
        # drains the queue. Both task references are kept so the tasks can be cancelled
        # later (asyncio itself only holds weak references to running tasks).
        self._message_queue = asyncio.Queue(maxsize=10000)
        self._message_processor_task: Optional[asyncio.Task] = None
        self._message_receiver_task: Optional[asyncio.Task] = None

        logger.info(f"Base connector initialized for {exchange_name}")

    async def connect(self) -> bool:
        """
        Establish connection to the exchange WebSocket.

        Returns:
            bool: The result of ``ConnectionManager.connect``; False if it raised.
        """
        try:
            set_correlation_id()
            logger.info(f"Connecting to {self.exchange_name} at {self.websocket_url}")
            return await self.connection_manager.connect(self._establish_websocket_connection)
        except Exception as e:
            logger.error(f"Failed to connect to {self.exchange_name}: {e}")
            self._notify_status_callbacks(ConnectionStatus.ERROR)
            return False

    async def disconnect(self) -> None:
        """Disconnect from the exchange WebSocket; errors are logged, not raised."""
        try:
            set_correlation_id()
            logger.info(f"Disconnecting from {self.exchange_name}")
            await self.connection_manager.disconnect(self._close_websocket_connection)
        except Exception as e:
            logger.error(f"Error during disconnect from {self.exchange_name}: {e}")

    @staticmethod
    def _is_websocket_open(websocket: Optional[Any]) -> bool:
        """
        Return True if ``websocket`` is an open connection.

        The legacy ``websockets`` connection class exposes a boolean ``closed``
        property. The newer asyncio implementation exposes only ``state``, so that
        is checked when ``closed`` is not available.
        """
        if websocket is None:
            return False
        closed = getattr(websocket, "closed", None)
        if isinstance(closed, bool):
            return not closed
        state = getattr(websocket, "state", None)
        return getattr(state, "name", None) == "OPEN"

    @staticmethod
    async def _cancel_task(task: Optional[asyncio.Task]) -> None:
        """Cancel ``task`` if it is still running and wait until it has finished."""
        if task is None or task.done():
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def _establish_websocket_connection(self) -> None:
        """
        Open the WebSocket through the circuit breaker and start message processing.

        Raises:
            ConnectionError: code "CIRCUIT_BREAKER_OPEN" when the circuit breaker rejects
                the attempt, "WEBSOCKET_CONNECT_FAILED" for any other failure.
        """
        try:
            # Use circuit breaker for connection so repeated failures trip it
            self.websocket = await self.circuit_breaker.call_async(
                websockets.connect,
                self.websocket_url,
                ping_interval=20,
                ping_timeout=10,
                close_timeout=10
            )
            logger.info(f"WebSocket connected to {self.exchange_name}")

            # Start message processing for the newly opened socket
            await self._start_message_processing()

        except CircuitBreakerOpenError as e:
            logger.error(f"Circuit breaker open for {self.exchange_name}: {e}")
            raise ConnectionError(f"Circuit breaker open: {e}", "CIRCUIT_BREAKER_OPEN") from e
        except Exception as e:
            logger.error(f"WebSocket connection failed for {self.exchange_name}: {e}")
            raise ConnectionError(f"WebSocket connection failed: {e}", "WEBSOCKET_CONNECT_FAILED") from e

    async def _close_websocket_connection(self) -> None:
        """
        Stop message processing and close the WebSocket.

        The socket reference is cleared and ``close()`` is attempted even if stopping
        the processing tasks fails. Errors are logged as warnings, not raised.
        """
        websocket = self.websocket
        try:
            try:
                # Stop message processing first so the receiver is no longer reading
                await self._stop_message_processing()
            finally:
                self.websocket = None
                if websocket is not None:
                    await websocket.close()
            logger.info(f"WebSocket disconnected from {self.exchange_name}")
        except Exception as e:
            logger.warning(f"Error closing WebSocket for {self.exchange_name}: {e}")

    async def _start_message_processing(self) -> None:
        """
        Ensure the queue processor is running and start a receiver for the current socket.

        The processor task is reused across reconnects. A receiver reads from exactly one
        socket, so any receiver left over from a previous socket is cancelled and a new
        one is started for ``self.websocket``.
        """
        if self._message_processor_task is None or self._message_processor_task.done():
            self._message_processor_task = asyncio.create_task(self._message_processor())

        await self._cancel_task(self._message_receiver_task)
        self._message_receiver_task = asyncio.create_task(self._message_receiver(self.websocket))

        logger.debug(f"Message processing started for {self.exchange_name}")

    async def _stop_message_processing(self) -> None:
        """Cancel the receiver and processor tasks and wait for both to finish."""
        receiver, self._message_receiver_task = self._message_receiver_task, None
        processor, self._message_processor_task = self._message_processor_task, None

        # Stop the receiver first so nothing new is queued while the processor shuts down
        await self._cancel_task(receiver)
        await self._cancel_task(processor)

        if receiver is not None or processor is not None:
            logger.debug(f"Message processing stopped for {self.exchange_name}")

    async def _message_receiver(self, websocket: Optional[Any] = None) -> None:
        """
        Receive frames from one WebSocket and enqueue them for the processor.

        Args:
            websocket: Connection to read from; defaults to ``self.websocket``.

        The loop ends when the socket closes, a receive error occurs, or the socket is
        no longer the connector's current one. On exit the connection manager is marked
        disconnected, unless a newer socket has already replaced this one.
        """
        ws = websocket if websocket is not None else self.websocket
        try:
            while ws is not None and ws is self.websocket and self._is_websocket_open(ws):
                try:
                    try:
                        message = await asyncio.wait_for(ws.recv(), timeout=30.0)
                    except asyncio.TimeoutError:
                        # Nothing received for 30s: send ping to keep connection alive.
                        # A failing ping is handled by the except clauses below.
                        await ws.ping()
                        continue
                except websockets.exceptions.ConnectionClosed:
                    logger.warning(f"WebSocket connection closed for {self.exchange_name}")
                    break
                except Exception as e:
                    logger.error(f"Error receiving message from {self.exchange_name}: {e}")
                    self.error_count += 1
                    break

                # Queue message for processing; drop it rather than block when full
                try:
                    self._message_queue.put_nowait(message)
                except asyncio.QueueFull:
                    logger.warning(f"Message queue full for {self.exchange_name}, dropping message")
        except Exception as e:
            logger.error(f"Message receiver error for {self.exchange_name}: {e}")
        finally:
            # Mark as disconnected, but only while this receiver still owns the active
            # connection; a stale receiver must not flag a replacement socket as down.
            if self.websocket is ws or self.websocket is None:
                self.connection_manager.is_connected = False

    async def _message_processor(self) -> None:
        """
        Process messages from the queue until the task is cancelled.

        ``task_done()`` is always called for every dequeued item, even if handling
        fails or is cancelled mid-way, so the queue's unfinished-task count stays correct.
        """
        while True:
            # Cancellation while waiting here simply ends the task
            message = await self._message_queue.get()
            try:
                await self._process_message(message)

                # Update statistics
                self.message_count += 1
                self.last_message_time = get_current_timestamp()
            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error(f"Error processing message for {self.exchange_name}: {e}")
                self.error_count += 1
            finally:
                self._message_queue.task_done()

    async def _process_message(self, message: str) -> None:
        """
        Parse an incoming WebSocket message and dispatch it to its registered handler.

        Args:
            message: Raw message string

        Invalid JSON and handler errors are logged, not raised.
        """
        try:
            # Parse JSON message
            data = json.loads(message)

            # Determine message type and route to appropriate handler
            message_type = self._get_message_type(data)

            if message_type in self.message_handlers:
                await self.message_handlers[message_type](data)
            else:
                logger.debug(f"Unhandled message type '{message_type}' from {self.exchange_name}")

        except json.JSONDecodeError as e:
            logger.warning(f"Invalid JSON message from {self.exchange_name}: {e}")
        except Exception as e:
            logger.error(f"Error processing message from {self.exchange_name}: {e}")

    def _get_message_type(self, data: Dict) -> str:
        """
        Determine message type from message data.

        Override in subclasses for exchange-specific logic.

        Args:
            data: Parsed message data

        Returns:
            str: The message's ``'type'`` field, or ``'unknown'`` if it is absent.
        """
        # Default implementation - override in subclasses
        return data.get('type', 'unknown')

    async def _send_message(self, message: Dict) -> bool:
        """
        Send a JSON-serialized message over the WebSocket.

        Args:
            message: Message to send

        Returns:
            bool: True if sent successfully, False if not connected or sending failed
        """
        try:
            if not self._is_websocket_open(self.websocket):
                logger.warning(f"Cannot send message to {self.exchange_name}: not connected")
                return False

            message_str = json.dumps(message)
            await self.websocket.send(message_str)

            logger.debug(f"Sent message to {self.exchange_name}: {message_str[:100]}...")
            return True

        except Exception as e:
            logger.error(f"Error sending message to {self.exchange_name}: {e}")
            return False

    # Callback handlers
    async def _on_connect(self) -> None:
        """Handle successful connection: notify listeners, then restore subscriptions."""
        self._notify_status_callbacks(ConnectionStatus.CONNECTED)

        # Resubscribe to all previous subscriptions
        await self._resubscribe_all()

    async def _on_disconnect(self) -> None:
        """Handle disconnection by notifying status listeners."""
        self._notify_status_callbacks(ConnectionStatus.DISCONNECTED)

    async def _on_error(self, error: Exception) -> None:
        """Handle a connection error: log it and notify status listeners."""
        logger.error(f"Connection error for {self.exchange_name}: {error}")
        self._notify_status_callbacks(ConnectionStatus.ERROR)

    async def _health_check(self) -> bool:
        """
        Perform health check.

        Returns:
            bool: False if the socket is not open, if no message has been processed for
            more than 60 seconds (only once at least one message has been seen), or if the
            ping fails; True otherwise.
        """
        try:
            if not self._is_websocket_open(self.websocket):
                return False

            # Check if we've received messages recently
            if self.last_message_time:
                time_since_last_message = (get_current_timestamp() - self.last_message_time).total_seconds()
                if time_since_last_message > 60:  # No messages for 60 seconds
                    logger.warning(f"No messages received from {self.exchange_name} for {time_since_last_message}s")
                    return False

            # Send ping
            await self.websocket.ping()

            return True

        except Exception as e:
            logger.error(f"Health check failed for {self.exchange_name}: {e}")
            return False

    async def _resubscribe_all(self) -> None:
        """
        Resubscribe to all previous subscriptions after reconnection.

        Each failed resubscription is logged and does not stop the others.
        """
        # Iterate over a snapshot: subscribe_* implementations may record subscriptions
        # in self.subscriptions (presumably - TODO confirm in subclasses), and mutating the
        # live dict/lists during iteration would raise or revisit entries.
        snapshot = [(symbol, list(types)) for symbol, types in self.subscriptions.items()]
        for symbol, subscription_types in snapshot:
            for sub_type in subscription_types:
                try:
                    if sub_type == 'orderbook':
                        await self.subscribe_orderbook(symbol)
                    elif sub_type == 'trades':
                        await self.subscribe_trades(symbol)
                except Exception as e:
                    logger.error(f"Failed to resubscribe to {sub_type} for {symbol}: {e}")

    # Abstract methods that must be implemented by subclasses
    async def subscribe_orderbook(self, symbol: str) -> None:
        """Subscribe to order book updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement subscribe_orderbook")

    async def subscribe_trades(self, symbol: str) -> None:
        """Subscribe to trade updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement subscribe_trades")

    async def unsubscribe_orderbook(self, symbol: str) -> None:
        """Unsubscribe from order book updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement unsubscribe_orderbook")

    async def unsubscribe_trades(self, symbol: str) -> None:
        """Unsubscribe from trade updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement unsubscribe_trades")

    async def get_symbols(self) -> List[str]:
        """Get available symbols - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement get_symbols")

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol format - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement normalize_symbol")

    async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
        """Get order book snapshot - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement get_orderbook_snapshot")

    # Utility methods
    def get_stats(self) -> Dict[str, Any]:
        """Return connector statistics, including connection-manager and circuit-breaker stats."""
        return {
            'exchange': self.exchange_name,
            'connection_status': self.get_connection_status().value,
            'is_connected': self.is_connected,
            'message_count': self.message_count,
            'error_count': self.error_count,
            'last_message_time': self.last_message_time.isoformat() if self.last_message_time else None,
            'subscriptions': dict(self.subscriptions),
            'connection_manager': self.connection_manager.get_stats(),
            'circuit_breaker': self.circuit_breaker.get_stats(),
            'queue_size': self._message_queue.qsize()
        }