""" Connection management with exponential backoff and retry logic. """ import asyncio import random from typing import Optional, Callable, Any from ..utils.logging import get_logger from ..utils.exceptions import ConnectionError logger = get_logger(__name__) class ExponentialBackoff: """Exponential backoff strategy for connection retries""" def __init__( self, initial_delay: float = 1.0, max_delay: float = 300.0, multiplier: float = 2.0, jitter: bool = True ): """ Initialize exponential backoff. Args: initial_delay: Initial delay in seconds max_delay: Maximum delay in seconds multiplier: Backoff multiplier jitter: Whether to add random jitter """ self.initial_delay = initial_delay self.max_delay = max_delay self.multiplier = multiplier self.jitter = jitter self.current_delay = initial_delay self.attempt_count = 0 def get_delay(self) -> float: """Get next delay value""" delay = min(self.current_delay, self.max_delay) # Add jitter to prevent thundering herd if self.jitter: delay = delay * (0.5 + random.random() * 0.5) # Update for next attempt self.current_delay *= self.multiplier self.attempt_count += 1 return delay def reset(self) -> None: """Reset backoff to initial state""" self.current_delay = self.initial_delay self.attempt_count = 0 class ConnectionManager: """ Manages connection lifecycle with retry logic and health monitoring. """ def __init__( self, name: str, max_retries: int = 10, initial_delay: float = 1.0, max_delay: float = 300.0, health_check_interval: int = 30 ): """ Initialize connection manager. Args: name: Connection name for logging max_retries: Maximum number of retry attempts initial_delay: Initial retry delay in seconds max_delay: Maximum retry delay in seconds health_check_interval: Health check interval in seconds """ self.name = name self.max_retries = max_retries self.health_check_interval = health_check_interval self.backoff = ExponentialBackoff(initial_delay, max_delay) self.is_connected = False self.connection_attempts = 0 self.last_error: Optional[Exception] = None self.health_check_task: Optional[asyncio.Task] = None # Callbacks self.on_connect: Optional[Callable] = None self.on_disconnect: Optional[Callable] = None self.on_error: Optional[Callable] = None self.on_health_check: Optional[Callable] = None logger.info(f"Connection manager '{name}' initialized") async def connect(self, connect_func: Callable) -> bool: """ Attempt to establish connection with retry logic. Args: connect_func: Async function that establishes the connection Returns: bool: True if connection successful, False otherwise """ self.connection_attempts = 0 self.backoff.reset() while self.connection_attempts < self.max_retries: try: logger.info(f"Attempting to connect '{self.name}' (attempt {self.connection_attempts + 1})") # Attempt connection await connect_func() # Connection successful self.is_connected = True self.connection_attempts = 0 self.last_error = None self.backoff.reset() logger.info(f"Connection '{self.name}' established successfully") # Start health check await self._start_health_check() # Notify success if self.on_connect: try: await self.on_connect() except Exception as e: logger.warning(f"Error in connect callback: {e}") return True except Exception as e: self.connection_attempts += 1 self.last_error = e logger.warning( f"Connection '{self.name}' failed (attempt {self.connection_attempts}): {e}" ) # Notify error if self.on_error: try: await self.on_error(e) except Exception as callback_error: logger.warning(f"Error in error callback: {callback_error}") # Check if we should retry if self.connection_attempts >= self.max_retries: logger.error(f"Connection '{self.name}' failed after {self.max_retries} attempts") break # Wait before retry delay = self.backoff.get_delay() logger.info(f"Retrying connection '{self.name}' in {delay:.1f} seconds") await asyncio.sleep(delay) self.is_connected = False return False async def disconnect(self, disconnect_func: Optional[Callable] = None) -> None: """ Disconnect and cleanup. Args: disconnect_func: Optional async function to handle disconnection """ logger.info(f"Disconnecting '{self.name}'") # Stop health check await self._stop_health_check() # Execute disconnect function if disconnect_func: try: await disconnect_func() except Exception as e: logger.warning(f"Error during disconnect: {e}") self.is_connected = False # Notify disconnect if self.on_disconnect: try: await self.on_disconnect() except Exception as e: logger.warning(f"Error in disconnect callback: {e}") logger.info(f"Connection '{self.name}' disconnected") async def reconnect(self, connect_func: Callable, disconnect_func: Optional[Callable] = None) -> bool: """ Reconnect by disconnecting first then connecting. Args: connect_func: Async function that establishes the connection disconnect_func: Optional async function to handle disconnection Returns: bool: True if reconnection successful, False otherwise """ logger.info(f"Reconnecting '{self.name}'") # Disconnect first await self.disconnect(disconnect_func) # Wait a bit before reconnecting await asyncio.sleep(1.0) # Attempt to connect return await self.connect(connect_func) async def _start_health_check(self) -> None: """Start periodic health check""" if self.health_check_task: return self.health_check_task = asyncio.create_task(self._health_check_loop()) logger.debug(f"Health check started for '{self.name}'") async def _stop_health_check(self) -> None: """Stop health check""" if self.health_check_task: self.health_check_task.cancel() try: await self.health_check_task except asyncio.CancelledError: pass self.health_check_task = None logger.debug(f"Health check stopped for '{self.name}'") async def _health_check_loop(self) -> None: """Health check loop""" while self.is_connected: try: await asyncio.sleep(self.health_check_interval) if self.on_health_check: is_healthy = await self.on_health_check() if not is_healthy: logger.warning(f"Health check failed for '{self.name}'") self.is_connected = False break except asyncio.CancelledError: break except Exception as e: logger.error(f"Health check error for '{self.name}': {e}") self.is_connected = False break def get_stats(self) -> dict: """Get connection statistics""" return { 'name': self.name, 'is_connected': self.is_connected, 'connection_attempts': self.connection_attempts, 'max_retries': self.max_retries, 'current_delay': self.backoff.current_delay, 'backoff_attempts': self.backoff.attempt_count, 'last_error': str(self.last_error) if self.last_error else None, 'health_check_active': self.health_check_task is not None }