cob integration scaffold
This commit is contained in:
13
COBY/connectors/__init__.py
Normal file
13
COBY/connectors/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Exchange connector implementations for the COBY system.
|
||||
"""
|
||||
|
||||
from .base_connector import BaseExchangeConnector
|
||||
from .connection_manager import ConnectionManager
|
||||
from .circuit_breaker import CircuitBreaker
|
||||
|
||||
__all__ = [
|
||||
'BaseExchangeConnector',
|
||||
'ConnectionManager',
|
||||
'CircuitBreaker'
|
||||
]
|
||||
383
COBY/connectors/base_connector.py
Normal file
383
COBY/connectors/base_connector.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""
|
||||
Base exchange connector implementation with connection management and error handling.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import websockets
|
||||
from typing import Dict, List, Optional, Callable, Any
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from ..interfaces.exchange_connector import ExchangeConnector
|
||||
from ..models.core import ConnectionStatus, OrderBookSnapshot, TradeEvent
|
||||
from ..utils.logging import get_logger, set_correlation_id
|
||||
from ..utils.exceptions import ConnectionError, ValidationError
|
||||
from ..utils.timing import get_current_timestamp
|
||||
from .connection_manager import ConnectionManager
|
||||
from .circuit_breaker import CircuitBreaker, CircuitBreakerOpenError
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseExchangeConnector(ExchangeConnector):
    """
    Base implementation of exchange connector with common functionality.

    Provides:
    - WebSocket connection management
    - Exponential backoff retry logic
    - Circuit breaker pattern
    - Health monitoring
    - Message handling framework
    - Subscription management

    Incoming frames are read by a receiver task, pushed onto a bounded queue
    and handled by a separate processor task, so a slow handler does not
    block reading from the socket.

    NOTE(review): ``exchange_name``, ``is_connected``,
    ``get_connection_status()`` and ``_notify_status_callbacks()`` are used
    here but not defined in this class; presumably they come from
    ``ExchangeConnector`` - confirm against the interface.
    """

    def __init__(self, exchange_name: str, websocket_url: str):
        """
        Initialize base exchange connector.

        Args:
            exchange_name: Name of the exchange
            websocket_url: WebSocket URL for the exchange
        """
        super().__init__(exchange_name)

        self.websocket_url = websocket_url
        # This is an outbound (client) connection, hence the client protocol
        # type. The annotation is not evaluated at runtime.
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self.subscriptions: Dict[str, List[str]] = {}  # symbol -> [subscription_types]
        self.message_handlers: Dict[str, Callable] = {}

        # Connection management
        self.connection_manager = ConnectionManager(
            name=f"{exchange_name}_connector",
            max_retries=10,
            initial_delay=1.0,
            max_delay=300.0,
            health_check_interval=30
        )

        # Circuit breaker
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=5,
            recovery_timeout=60,
            expected_exception=Exception,
            name=f"{exchange_name}_circuit"
        )

        # Statistics
        self.message_count = 0
        self.error_count = 0
        self.last_message_time: Optional[datetime] = None

        # Setup callbacks
        self.connection_manager.on_connect = self._on_connect
        self.connection_manager.on_disconnect = self._on_disconnect
        self.connection_manager.on_error = self._on_error
        self.connection_manager.on_health_check = self._health_check

        # Message processing
        self._message_queue = asyncio.Queue(maxsize=10000)
        self._message_processor_task: Optional[asyncio.Task] = None
        # A reference to the receiver task is kept so that it cannot be
        # garbage-collected while running and can be cancelled on disconnect.
        self._message_receiver_task: Optional[asyncio.Task] = None

        logger.info(f"Base connector initialized for {exchange_name}")

    async def connect(self) -> bool:
        """
        Establish connection to the exchange WebSocket.

        Returns:
            bool: True if the connection manager reports success, False otherwise
        """
        try:
            set_correlation_id()
            logger.info(f"Connecting to {self.exchange_name} at {self.websocket_url}")

            return await self.connection_manager.connect(self._establish_websocket_connection)

        except Exception as e:
            logger.error(f"Failed to connect to {self.exchange_name}: {e}")
            self._notify_status_callbacks(ConnectionStatus.ERROR)
            return False

    async def disconnect(self) -> None:
        """Disconnect from the exchange WebSocket; errors are logged, not raised."""
        try:
            set_correlation_id()
            logger.info(f"Disconnecting from {self.exchange_name}")

            await self.connection_manager.disconnect(self._close_websocket_connection)

        except Exception as e:
            logger.error(f"Error during disconnect from {self.exchange_name}: {e}")

    async def _establish_websocket_connection(self) -> None:
        """
        Open the WebSocket through the circuit breaker and start message tasks.

        Raises:
            ConnectionError: If the circuit breaker is open or the connection
                attempt fails; the original exception is chained as the cause.
        """
        try:
            # Use circuit breaker for connection
            self.websocket = await self.circuit_breaker.call_async(
                websockets.connect,
                self.websocket_url,
                ping_interval=20,
                ping_timeout=10,
                close_timeout=10
            )

            logger.info(f"WebSocket connected to {self.exchange_name}")

            # Start message processing
            await self._start_message_processing()

        except CircuitBreakerOpenError as e:
            logger.error(f"Circuit breaker open for {self.exchange_name}: {e}")
            raise ConnectionError(f"Circuit breaker open: {e}", "CIRCUIT_BREAKER_OPEN") from e
        except Exception as e:
            logger.error(f"WebSocket connection failed for {self.exchange_name}: {e}")
            raise ConnectionError(f"WebSocket connection failed: {e}", "WEBSOCKET_CONNECT_FAILED") from e

    async def _close_websocket_connection(self) -> None:
        """Stop the message tasks and close the WebSocket; errors are logged as warnings."""
        try:
            # Stop message processing
            await self._stop_message_processing()

            # Close WebSocket
            if self.websocket:
                try:
                    await self.websocket.close()
                finally:
                    # Drop the reference even if close() fails, so a broken
                    # socket is never reused by later sends or health checks.
                    self.websocket = None

            logger.info(f"WebSocket disconnected from {self.exchange_name}")

        except Exception as e:
            logger.warning(f"Error closing WebSocket for {self.exchange_name}: {e}")

    async def _start_message_processing(self) -> None:
        """
        Start the processor and receiver tasks, (re)creating any that are not running.

        Each task is checked independently: after a dropped connection the
        receiver has finished while the processor is still alive, and a
        reconnect must start a fresh receiver for the new socket.
        """
        processor = self._message_processor_task
        if processor is None or processor.done():
            self._message_processor_task = asyncio.create_task(self._message_processor())

        receiver = self._message_receiver_task
        if receiver is None or receiver.done():
            self._message_receiver_task = asyncio.create_task(self._message_receiver())

        logger.debug(f"Message processing started for {self.exchange_name}")

    async def _stop_message_processing(self) -> None:
        """Cancel the receiver and processor tasks and wait for them to finish."""
        # Stop the receiver first so nothing new is queued while shutting down.
        receiver, self._message_receiver_task = self._message_receiver_task, None
        processor, self._message_processor_task = self._message_processor_task, None

        for task in (receiver, processor):
            if task is None:
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        logger.debug(f"Message processing stopped for {self.exchange_name}")

    async def _message_receiver(self) -> None:
        """
        Receive messages from the WebSocket and queue them for processing.

        Runs until the socket is closed or a receive error occurs. Messages
        are dropped (with a warning) when the queue is full. On exit the
        connection manager is marked as disconnected.
        """
        try:
            while self.websocket and not self.websocket.closed:
                try:
                    message = await asyncio.wait_for(self.websocket.recv(), timeout=30.0)

                    # Queue message for processing
                    try:
                        self._message_queue.put_nowait(message)
                    except asyncio.QueueFull:
                        logger.warning(f"Message queue full for {self.exchange_name}, dropping message")

                except asyncio.TimeoutError:
                    # Send ping to keep connection alive
                    if self.websocket:
                        await self.websocket.ping()
                except websockets.exceptions.ConnectionClosed:
                    logger.warning(f"WebSocket connection closed for {self.exchange_name}")
                    break
                except Exception as e:
                    logger.error(f"Error receiving message from {self.exchange_name}: {e}")
                    self.error_count += 1
                    break

        except Exception as e:
            logger.error(f"Message receiver error for {self.exchange_name}: {e}")
        finally:
            # Mark as disconnected
            self.connection_manager.is_connected = False

    async def _message_processor(self) -> None:
        """
        Process messages from the queue until cancelled.

        Every dequeued message is acknowledged with ``task_done()`` whether
        or not processing succeeded, so ``Queue.join()`` cannot hang on a
        failed message.
        """
        while True:
            try:
                # Get message from queue
                message = await self._message_queue.get()
            except asyncio.CancelledError:
                break

            try:
                # Process message
                await self._process_message(message)

                # Update statistics
                self.message_count += 1
                self.last_message_time = get_current_timestamp()

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error processing message for {self.exchange_name}: {e}")
                self.error_count += 1
            finally:
                # Mark task as done
                self._message_queue.task_done()

    async def _process_message(self, message: str) -> None:
        """
        Process incoming WebSocket message.

        The message is parsed as JSON and dispatched to the handler registered
        in ``message_handlers`` for its type. Parse and handler errors are
        logged and swallowed.

        Args:
            message: Raw message string
        """
        try:
            # Parse JSON message
            data = json.loads(message)

            # Determine message type and route to appropriate handler
            message_type = self._get_message_type(data)

            handler = self.message_handlers.get(message_type)
            if handler is not None:
                await handler(data)
            else:
                logger.debug(f"Unhandled message type '{message_type}' from {self.exchange_name}")

        except json.JSONDecodeError as e:
            logger.warning(f"Invalid JSON message from {self.exchange_name}: {e}")
        except Exception as e:
            logger.error(f"Error processing message from {self.exchange_name}: {e}")

    def _get_message_type(self, data: Dict) -> str:
        """
        Determine message type from message data.
        Override in subclasses for exchange-specific logic.

        Args:
            data: Parsed message data

        Returns:
            str: Message type identifier ('unknown' when no 'type' key is present)
        """
        # Default implementation - override in subclasses
        return data.get('type', 'unknown')

    async def _send_message(self, message: Dict) -> bool:
        """
        Send message to WebSocket.

        Args:
            message: Message to send (serialized with ``json.dumps``)

        Returns:
            bool: True if sent successfully, False otherwise
        """
        try:
            if not self.websocket or self.websocket.closed:
                logger.warning(f"Cannot send message to {self.exchange_name}: not connected")
                return False

            message_str = json.dumps(message)
            await self.websocket.send(message_str)

            logger.debug(f"Sent message to {self.exchange_name}: {message_str[:100]}...")
            return True

        except Exception as e:
            logger.error(f"Error sending message to {self.exchange_name}: {e}")
            return False

    # Callback handlers
    async def _on_connect(self) -> None:
        """Handle successful connection"""
        self._notify_status_callbacks(ConnectionStatus.CONNECTED)

        # Resubscribe to all previous subscriptions
        await self._resubscribe_all()

    async def _on_disconnect(self) -> None:
        """Handle disconnection"""
        self._notify_status_callbacks(ConnectionStatus.DISCONNECTED)

    async def _on_error(self, error: Exception) -> None:
        """Handle connection error"""
        logger.error(f"Connection error for {self.exchange_name}: {error}")
        self._notify_status_callbacks(ConnectionStatus.ERROR)

    async def _health_check(self) -> bool:
        """
        Perform health check.

        Returns:
            bool: False if the socket is missing/closed, if the last message is
            older than 60 seconds, or if the ping fails; True otherwise.
        """
        try:
            if not self.websocket or self.websocket.closed:
                return False

            # Check if we've received messages recently
            if self.last_message_time:
                time_since_last_message = (get_current_timestamp() - self.last_message_time).total_seconds()
                if time_since_last_message > 60:  # No messages for 60 seconds
                    logger.warning(f"No messages received from {self.exchange_name} for {time_since_last_message}s")
                    return False

            # Send ping
            await self.websocket.ping()
            return True

        except Exception as e:
            logger.error(f"Health check failed for {self.exchange_name}: {e}")
            return False

    async def _resubscribe_all(self) -> None:
        """Resubscribe to all previous subscriptions after reconnection"""
        # Iterate over a snapshot: subscribe implementations presumably record
        # the subscription in self.subscriptions, and mutating the dict or its
        # lists while iterating them live would raise or loop indefinitely.
        snapshot = [(symbol, list(sub_types)) for symbol, sub_types in self.subscriptions.items()]

        for symbol, subscription_types in snapshot:
            for sub_type in subscription_types:
                try:
                    if sub_type == 'orderbook':
                        await self.subscribe_orderbook(symbol)
                    elif sub_type == 'trades':
                        await self.subscribe_trades(symbol)
                except Exception as e:
                    logger.error(f"Failed to resubscribe to {sub_type} for {symbol}: {e}")

    # Abstract methods that must be implemented by subclasses
    async def subscribe_orderbook(self, symbol: str) -> None:
        """Subscribe to order book updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement subscribe_orderbook")

    async def subscribe_trades(self, symbol: str) -> None:
        """Subscribe to trade updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement subscribe_trades")

    async def unsubscribe_orderbook(self, symbol: str) -> None:
        """Unsubscribe from order book updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement unsubscribe_orderbook")

    async def unsubscribe_trades(self, symbol: str) -> None:
        """Unsubscribe from trade updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement unsubscribe_trades")

    async def get_symbols(self) -> List[str]:
        """Get available symbols - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement get_symbols")

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol format - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement normalize_symbol")

    async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
        """Get order book snapshot - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement get_orderbook_snapshot")

    # Utility methods
    def get_stats(self) -> Dict[str, Any]:
        """
        Get connector statistics.

        Returns:
            Dict[str, Any]: Connection state, message/error counters, current
            subscriptions, queue depth, and the nested stats of the connection
            manager and circuit breaker.
        """
        return {
            'exchange': self.exchange_name,
            'connection_status': self.get_connection_status().value,
            'is_connected': self.is_connected,
            'message_count': self.message_count,
            'error_count': self.error_count,
            'last_message_time': self.last_message_time.isoformat() if self.last_message_time else None,
            'subscriptions': dict(self.subscriptions),
            'connection_manager': self.connection_manager.get_stats(),
            'circuit_breaker': self.circuit_breaker.get_stats(),
            'queue_size': self._message_queue.qsize()
        }
|
||||
206
COBY/connectors/circuit_breaker.py
Normal file
206
COBY/connectors/circuit_breaker.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Circuit breaker pattern implementation for exchange connections.
|
||||
"""
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional, Callable, Any
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CircuitState(Enum):
    """Circuit breaker states"""

    CLOSED = "closed"        # Normal operation
    OPEN = "open"            # Circuit is open, calls fail fast
    HALF_OPEN = "half_open"  # Testing if service is back


class CircuitBreaker:
    """
    Circuit breaker to prevent cascading failures in exchange connections.

    States:
    - CLOSED: Normal operation, requests pass through
    - OPEN: Circuit is open, requests fail immediately
    - HALF_OPEN: Testing if service is back; the next success closes the
      circuit and the next counted failure re-opens it
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        expected_exception: type = Exception,
        name: str = "CircuitBreaker"
    ):
        """
        Initialize circuit breaker.

        Args:
            failure_threshold: Number of failures before opening circuit
            recovery_timeout: Time in seconds before attempting recovery
            expected_exception: Exception type that triggers circuit breaker
            name: Name for logging purposes
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.name = name

        # State tracking (timestamps are time.time() values)
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time: Optional[float] = None
        self._next_attempt_time: Optional[float] = None

        logger.info(f"Circuit breaker '{name}' initialized with threshold={failure_threshold}")

    @property
    def state(self) -> CircuitState:
        """Get current circuit state"""
        return self._state

    @property
    def failure_count(self) -> int:
        """Get current failure count"""
        return self._failure_count

    def _should_attempt_reset(self) -> bool:
        """Return True when the circuit is OPEN and the recovery timeout has elapsed."""
        if self._state != CircuitState.OPEN:
            return False

        if self._next_attempt_time is None:
            return False

        return time.time() >= self._next_attempt_time

    def _on_success(self) -> None:
        """Handle successful operation: close the circuit and clear failure tracking."""
        if self._state == CircuitState.HALF_OPEN:
            logger.info(f"Circuit breaker '{self.name}' reset to CLOSED after successful test")
            self._state = CircuitState.CLOSED

        self._failure_count = 0
        self._last_failure_time = None
        self._next_attempt_time = None

    def _on_failure(self) -> None:
        """Handle failed operation: count it and open the circuit when warranted."""
        self._failure_count += 1
        self._last_failure_time = time.time()

        if self._state == CircuitState.HALF_OPEN:
            # Failed during test, go back to OPEN
            logger.warning(f"Circuit breaker '{self.name}' failed during test, returning to OPEN")
            self._state = CircuitState.OPEN
            self._next_attempt_time = time.time() + self.recovery_timeout
        elif self._failure_count >= self.failure_threshold:
            # Too many failures, open the circuit
            logger.error(
                f"Circuit breaker '{self.name}' OPENED after {self._failure_count} failures"
            )
            self._state = CircuitState.OPEN
            self._next_attempt_time = time.time() + self.recovery_timeout

    def _before_call(self) -> None:
        """
        Gate shared by ``call`` and ``call_async``.

        Moves OPEN -> HALF_OPEN once the recovery timeout has elapsed, and
        fails fast while the circuit is still OPEN.

        Raises:
            CircuitBreakerOpenError: When circuit is open
        """
        # Check if we should attempt reset
        if self._should_attempt_reset():
            logger.info(f"Circuit breaker '{self.name}' attempting reset to HALF_OPEN")
            self._state = CircuitState.HALF_OPEN

        # Fail fast if circuit is open
        if self._state == CircuitState.OPEN:
            message = f"Circuit breaker '{self.name}' is OPEN."
            # The retry time can be missing if the state was forced to OPEN
            # without going through _on_failure; subtracting from None would
            # raise TypeError instead of the documented error.
            if self._next_attempt_time is not None:
                remaining = max(0.0, self._next_attempt_time - time.time())
                message += f" Next attempt in {remaining:.1f}s"
            raise CircuitBreakerOpenError(message)

    def call(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute function with circuit breaker protection.

        Args:
            func: Function to execute
            *args: Function arguments
            **kwargs: Function keyword arguments

        Returns:
            Function result

        Raises:
            CircuitBreakerOpenError: When circuit is open
            Original exception: When function fails
        """
        self._before_call()

        try:
            # Execute the function
            result = func(*args, **kwargs)
        except self.expected_exception:
            self._on_failure()
            raise  # bare raise keeps the original traceback

        self._on_success()
        return result

    async def call_async(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute async function with circuit breaker protection.

        Args:
            func: Async function to execute
            *args: Function arguments
            **kwargs: Function keyword arguments

        Returns:
            Function result

        Raises:
            CircuitBreakerOpenError: When circuit is open
            Original exception: When function fails
        """
        self._before_call()

        try:
            # Execute the async function
            result = await func(*args, **kwargs)
        except self.expected_exception:
            self._on_failure()
            raise  # bare raise keeps the original traceback

        self._on_success()
        return result

    def reset(self) -> None:
        """Manually reset the circuit breaker"""
        logger.info(f"Circuit breaker '{self.name}' manually reset")
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time = None
        self._next_attempt_time = None

    def get_stats(self) -> dict:
        """Get circuit breaker statistics"""
        return {
            'name': self.name,
            'state': self._state.value,
            'failure_count': self._failure_count,
            'failure_threshold': self.failure_threshold,
            'last_failure_time': self._last_failure_time,
            'next_attempt_time': self._next_attempt_time,
            'recovery_timeout': self.recovery_timeout
        }


class CircuitBreakerOpenError(Exception):
    """Exception raised when circuit breaker is open"""
    pass
|
||||
271
COBY/connectors/connection_manager.py
Normal file
271
COBY/connectors/connection_manager.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
Connection management with exponential backoff and retry logic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from typing import Optional, Callable, Any
|
||||
from ..utils.logging import get_logger
|
||||
from ..utils.exceptions import ConnectionError
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ExponentialBackoff:
    """Exponential backoff strategy for connection retries"""

    def __init__(
        self,
        initial_delay: float = 1.0,
        max_delay: float = 300.0,
        multiplier: float = 2.0,
        jitter: bool = True
    ):
        """
        Initialize exponential backoff.

        Args:
            initial_delay: Initial delay in seconds
            max_delay: Maximum delay in seconds
            multiplier: Backoff multiplier
            jitter: Whether to add random jitter
        """
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.multiplier = multiplier
        self.jitter = jitter
        self.current_delay = initial_delay
        self.attempt_count = 0

    def get_delay(self) -> float:
        """
        Get next delay value and advance the backoff state.

        Returns:
            float: Delay in seconds, never above ``max_delay``. With jitter
            enabled the value is scaled by a random factor in [0.5, 1.0).
        """
        delay = min(self.current_delay, self.max_delay)

        # Add jitter to prevent thundering herd
        if self.jitter:
            delay = delay * (0.5 + random.random() * 0.5)

        # Update for next attempt. The stored delay is capped at max_delay so
        # it does not grow without bound and the value reported in stats stays
        # meaningful; the delays returned are unaffected by the cap.
        self.current_delay = min(self.current_delay * self.multiplier, self.max_delay)
        self.attempt_count += 1

        return delay

    def reset(self) -> None:
        """Reset backoff to initial state"""
        self.current_delay = self.initial_delay
        self.attempt_count = 0
|
||||
|
||||
|
||||
class ConnectionManager:
    """
    Manages connection lifecycle with retry logic and health monitoring.

    The manager does not own a transport: callers pass async connect and
    disconnect callables, and may assign the optional async callbacks
    ``on_connect``, ``on_disconnect``, ``on_error`` and ``on_health_check``
    after construction.
    """

    def __init__(
        self,
        name: str,
        max_retries: int = 10,
        initial_delay: float = 1.0,
        max_delay: float = 300.0,
        health_check_interval: int = 30
    ):
        """
        Initialize connection manager.

        Args:
            name: Connection name for logging
            max_retries: Maximum number of retry attempts
            initial_delay: Initial retry delay in seconds
            max_delay: Maximum retry delay in seconds
            health_check_interval: Health check interval in seconds
        """
        self.name = name
        self.max_retries = max_retries
        self.health_check_interval = health_check_interval

        self.backoff = ExponentialBackoff(initial_delay, max_delay)
        self.is_connected = False
        self.connection_attempts = 0
        self.last_error: Optional[Exception] = None
        self.health_check_task: Optional[asyncio.Task] = None

        # Callbacks
        self.on_connect: Optional[Callable] = None
        self.on_disconnect: Optional[Callable] = None
        self.on_error: Optional[Callable] = None
        self.on_health_check: Optional[Callable] = None

        logger.info(f"Connection manager '{name}' initialized")

    async def _run_callback(self, callback: Optional[Callable], label: str, *args: Any) -> None:
        """Await an optional callback; a failing callback is logged, never raised."""
        if not callback:
            return
        try:
            await callback(*args)
        except Exception as e:
            logger.warning(f"Error in {label} callback: {e}")

    async def connect(self, connect_func: Callable) -> bool:
        """
        Attempt to establish connection with retry logic.

        Args:
            connect_func: Async function that establishes the connection

        Returns:
            bool: True if connection successful, False otherwise
        """
        self.connection_attempts = 0
        self.backoff.reset()

        while self.connection_attempts < self.max_retries:
            try:
                logger.info(f"Attempting to connect '{self.name}' (attempt {self.connection_attempts + 1})")

                # Attempt connection
                await connect_func()

                # Connection successful
                self.is_connected = True
                self.connection_attempts = 0
                self.last_error = None
                self.backoff.reset()

                logger.info(f"Connection '{self.name}' established successfully")

                # Start health check
                await self._start_health_check()

                # Notify success
                await self._run_callback(self.on_connect, "connect")

                return True

            except Exception as e:
                self.connection_attempts += 1
                self.last_error = e

                logger.warning(
                    f"Connection '{self.name}' failed (attempt {self.connection_attempts}): {e}"
                )

                # Notify error
                await self._run_callback(self.on_error, "error", e)

                # Check if we should retry
                if self.connection_attempts >= self.max_retries:
                    logger.error(f"Connection '{self.name}' failed after {self.max_retries} attempts")
                    break

                # Wait before retry
                delay = self.backoff.get_delay()
                logger.info(f"Retrying connection '{self.name}' in {delay:.1f} seconds")
                await asyncio.sleep(delay)

        self.is_connected = False
        return False

    async def disconnect(self, disconnect_func: Optional[Callable] = None) -> None:
        """
        Disconnect and cleanup.

        Args:
            disconnect_func: Optional async function to handle disconnection
        """
        logger.info(f"Disconnecting '{self.name}'")

        # Stop health check
        await self._stop_health_check()

        # Execute disconnect function
        if disconnect_func:
            try:
                await disconnect_func()
            except Exception as e:
                logger.warning(f"Error during disconnect: {e}")

        self.is_connected = False

        # Notify disconnect
        await self._run_callback(self.on_disconnect, "disconnect")

        logger.info(f"Connection '{self.name}' disconnected")

    async def reconnect(self, connect_func: Callable, disconnect_func: Optional[Callable] = None) -> bool:
        """
        Reconnect by disconnecting first then connecting.

        Args:
            connect_func: Async function that establishes the connection
            disconnect_func: Optional async function to handle disconnection

        Returns:
            bool: True if reconnection successful, False otherwise
        """
        logger.info(f"Reconnecting '{self.name}'")

        # Disconnect first
        await self.disconnect(disconnect_func)

        # Wait a bit before reconnecting
        await asyncio.sleep(1.0)

        # Attempt to connect
        return await self.connect(connect_func)

    async def _start_health_check(self) -> None:
        """Start periodic health check unless one is already running."""
        task = self.health_check_task
        # A finished task (the loop exits by itself on an unhealthy result)
        # must not block a restart on the next successful connect().
        if task is not None and not task.done():
            return

        self.health_check_task = asyncio.create_task(self._health_check_loop())
        logger.debug(f"Health check started for '{self.name}'")

    async def _stop_health_check(self) -> None:
        """Stop health check"""
        task, self.health_check_task = self.health_check_task, None
        if task is None:
            return

        # If we are running inside the health check task itself (a health
        # callback triggered a disconnect), the task cannot cancel or await
        # itself; the loop ends once is_connected is cleared.
        if task is not asyncio.current_task():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        logger.debug(f"Health check stopped for '{self.name}'")

    async def _health_check_loop(self) -> None:
        """
        Health check loop.

        Sleeps ``health_check_interval`` seconds between checks and ends when
        the connection is flagged as down, the callback reports unhealthy or
        raises, or the task is cancelled. An unhealthy result only clears
        ``is_connected``.
        """
        while self.is_connected:
            try:
                await asyncio.sleep(self.health_check_interval)

                if self.on_health_check:
                    is_healthy = await self.on_health_check()
                    if not is_healthy:
                        logger.warning(f"Health check failed for '{self.name}'")
                        self.is_connected = False
                        break

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Health check error for '{self.name}': {e}")
                self.is_connected = False
                break

    def get_stats(self) -> dict:
        """Get connection statistics"""
        task = self.health_check_task
        return {
            'name': self.name,
            'is_connected': self.is_connected,
            'connection_attempts': self.connection_attempts,
            'max_retries': self.max_retries,
            'current_delay': self.backoff.current_delay,
            'backoff_attempts': self.backoff.attempt_count,
            'last_error': str(self.last_error) if self.last_error else None,
            # A task that has already finished is not an active health check.
            'health_check_active': task is not None and not task.done()
        }
|
||||
Reference in New Issue
Block a user