271 lines
9.1 KiB
Python
271 lines
9.1 KiB
Python
"""
Connection management with exponential backoff and retry logic.
"""
|
|
|
|
import asyncio
|
|
import random
|
|
from typing import Optional, Callable, Any
|
|
from ..utils.logging import get_logger
|
|
from ..utils.exceptions import ConnectionError
|
|
|
|
# Module-level logger; ConnectionManager reports its lifecycle events through it.
logger = get_logger(__name__)
|
|
|
|
|
|
class ExponentialBackoff:
    """Exponential backoff strategy for connection retries.

    Each call to ``get_delay`` returns the current delay (capped at
    ``max_delay`` and optionally jittered) and then scales the stored delay
    by ``multiplier`` for the following call. ``reset`` returns the strategy
    to its initial state.
    """

    def __init__(
        self,
        initial_delay: float = 1.0,
        max_delay: float = 300.0,
        multiplier: float = 2.0,
        jitter: bool = True
    ):
        """
        Initialize exponential backoff.

        Args:
            initial_delay: Initial delay in seconds
            max_delay: Maximum delay in seconds
            multiplier: Backoff multiplier
            jitter: Whether to add random jitter
        """
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.multiplier = multiplier
        self.jitter = jitter
        # Un-jittered delay that the next get_delay() call will be based on.
        self.current_delay = initial_delay
        # Number of delays handed out since construction or the last reset().
        self.attempt_count = 0

    def get_delay(self) -> float:
        """Return the delay for the next retry and advance the backoff state.

        Returns:
            The delay in seconds: ``current_delay`` capped at ``max_delay``,
            scaled by a random factor in [0.5, 1.0) when jitter is enabled.
        """
        delay = min(self.current_delay, self.max_delay)

        # Add jitter to prevent thundering herd
        if self.jitter:
            delay = delay * (0.5 + random.random() * 0.5)

        # Advance for the next attempt. The stored value is clamped to
        # max_delay so it cannot grow without bound (eventually to inf) and
        # so anything reporting current_delay shows the delay actually in use.
        self.current_delay = min(self.current_delay * self.multiplier, self.max_delay)
        self.attempt_count += 1

        return delay

    def reset(self) -> None:
        """Reset backoff to initial state"""
        self.current_delay = self.initial_delay
        self.attempt_count = 0
|
|
|
|
|
|
class ConnectionManager:
    """
    Manages connection lifecycle with retry logic and health monitoring.

    The optional ``on_connect``, ``on_disconnect``, ``on_error`` and
    ``on_health_check`` callbacks are awaited, so each must return an
    awaitable. ``on_error`` receives the exception raised by the failed
    connection attempt; ``on_health_check`` must resolve to a truthy value
    while the connection is healthy.
    """

    def __init__(
        self,
        name: str,
        max_retries: int = 10,
        initial_delay: float = 1.0,
        max_delay: float = 300.0,
        health_check_interval: int = 30
    ):
        """
        Initialize connection manager.

        Args:
            name: Connection name for logging
            max_retries: Maximum number of retry attempts
            initial_delay: Initial retry delay in seconds
            max_delay: Maximum retry delay in seconds
            health_check_interval: Health check interval in seconds
        """
        self.name = name
        self.max_retries = max_retries
        self.health_check_interval = health_check_interval

        self.backoff = ExponentialBackoff(initial_delay, max_delay)
        self.is_connected = False
        self.connection_attempts = 0
        self.last_error: Optional[Exception] = None
        self.health_check_task: Optional[asyncio.Task] = None

        # Callbacks
        self.on_connect: Optional[Callable] = None
        self.on_disconnect: Optional[Callable] = None
        self.on_error: Optional[Callable] = None
        self.on_health_check: Optional[Callable] = None

        logger.info(f"Connection manager '{name}' initialized")

    async def connect(self, connect_func: Callable) -> bool:
        """
        Attempt to establish connection with retry logic.

        Calls ``connect_func`` up to ``max_retries`` times, sleeping for a
        backoff delay between failed attempts. On success the health check
        task is started and ``on_connect`` is notified; on each failure the
        error is stored in ``last_error`` and ``on_error`` is notified.

        Args:
            connect_func: Async function that establishes the connection

        Returns:
            bool: True if connection successful, False otherwise
        """
        self.connection_attempts = 0
        self.backoff.reset()

        while self.connection_attempts < self.max_retries:
            try:
                logger.info(f"Attempting to connect '{self.name}' (attempt {self.connection_attempts + 1})")

                # Attempt connection
                await connect_func()

                # Connection successful
                self.is_connected = True
                self.connection_attempts = 0
                self.last_error = None
                self.backoff.reset()

                logger.info(f"Connection '{self.name}' established successfully")

                # Start health check
                await self._start_health_check()

                # Notify success
                await self._notify(self.on_connect, "connect")

                return True

            except Exception as e:
                self.connection_attempts += 1
                self.last_error = e

                logger.warning(
                    f"Connection '{self.name}' failed (attempt {self.connection_attempts}): {e}"
                )

                # Notify error
                await self._notify(self.on_error, "error", e)

                # Check if we should retry
                if self.connection_attempts >= self.max_retries:
                    logger.error(f"Connection '{self.name}' failed after {self.max_retries} attempts")
                    break

                # Wait before retry
                delay = self.backoff.get_delay()
                logger.info(f"Retrying connection '{self.name}' in {delay:.1f} seconds")
                await asyncio.sleep(delay)

        self.is_connected = False
        return False

    async def disconnect(self, disconnect_func: Optional[Callable] = None) -> None:
        """
        Disconnect and cleanup.

        Stops the health check task, runs ``disconnect_func`` (errors are
        logged, not raised), marks the manager as disconnected and notifies
        ``on_disconnect``.

        Args:
            disconnect_func: Optional async function to handle disconnection
        """
        logger.info(f"Disconnecting '{self.name}'")

        # Stop health check
        await self._stop_health_check()

        # Execute disconnect function
        if disconnect_func:
            try:
                await disconnect_func()
            except Exception as e:
                logger.warning(f"Error during disconnect: {e}")

        self.is_connected = False

        # Notify disconnect
        await self._notify(self.on_disconnect, "disconnect")

        logger.info(f"Connection '{self.name}' disconnected")

    async def reconnect(self, connect_func: Callable, disconnect_func: Optional[Callable] = None) -> bool:
        """
        Reconnect by disconnecting first then connecting.

        Args:
            connect_func: Async function that establishes the connection
            disconnect_func: Optional async function to handle disconnection

        Returns:
            bool: True if reconnection successful, False otherwise
        """
        logger.info(f"Reconnecting '{self.name}'")

        # Disconnect first
        await self.disconnect(disconnect_func)

        # Wait a bit before reconnecting
        await asyncio.sleep(1.0)

        # Attempt to connect
        return await self.connect(connect_func)

    async def _notify(self, callback: Optional[Callable], label: str, *args: Any) -> None:
        """Await an optional callback, logging instead of raising on failure.

        Args:
            callback: The callback to await, or None to do nothing
            label: Callback name used in the warning message
            *args: Positional arguments forwarded to the callback
        """
        if not callback:
            return
        try:
            await callback(*args)
        except Exception as e:
            logger.warning(f"Error in {label} callback: {e}")

    async def _start_health_check(self) -> None:
        """Start periodic health check unless one is already running"""
        task = self.health_check_task
        # A task that has already finished (the loop ends itself after a
        # failed check) must not block a restart, otherwise a later connect()
        # would leave the connection unmonitored.
        if task is not None and not task.done():
            return

        self.health_check_task = asyncio.create_task(self._health_check_loop())
        logger.debug(f"Health check started for '{self.name}'")

    async def _stop_health_check(self) -> None:
        """Cancel the health check task, if any, and wait for it to finish"""
        if self.health_check_task:
            self.health_check_task.cancel()
            try:
                await self.health_check_task
            except asyncio.CancelledError:
                # Expected: raised by awaiting the task we just cancelled.
                pass
            self.health_check_task = None
            logger.debug(f"Health check stopped for '{self.name}'")

    async def _health_check_loop(self) -> None:
        """Health check loop

        Sleeps ``health_check_interval`` seconds between checks and runs
        while ``is_connected`` is true. A falsy result or an exception from
        ``on_health_check`` marks the connection as not connected and ends
        the loop; cancellation ends it without changing state.
        """
        while self.is_connected:
            try:
                await asyncio.sleep(self.health_check_interval)

                if self.on_health_check:
                    is_healthy = await self.on_health_check()
                    if not is_healthy:
                        logger.warning(f"Health check failed for '{self.name}'")
                        self.is_connected = False
                        break

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Health check error for '{self.name}': {e}")
                self.is_connected = False
                break

    def get_stats(self) -> dict:
        """Get connection statistics

        Returns:
            dict: Snapshot of the manager's name, connection flag, attempt
            counters, backoff state, last error text (or None) and whether
            a health check task is currently running.
        """
        task = self.health_check_task
        return {
            'name': self.name,
            'is_connected': self.is_connected,
            'connection_attempts': self.connection_attempts,
            'max_retries': self.max_retries,
            'current_delay': self.backoff.current_delay,
            'backoff_attempts': self.backoff.attempt_count,
            'last_error': str(self.last_error) if self.last_error else None,
            # A finished task is still referenced until the next stop/start,
            # so only report it as active while it is actually running.
            'health_check_active': task is not None and not task.done()
        }