cob integration scaffold

This commit is contained in:
Dobromir Popov
2025-08-04 17:12:26 +03:00
parent de9fa4a421
commit 504736c0f7
14 changed files with 2665 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
"""
Exchange connector implementations for the COBY system.
"""
from .base_connector import BaseExchangeConnector
from .connection_manager import ConnectionManager
from .circuit_breaker import CircuitBreaker
__all__ = [
'BaseExchangeConnector',
'ConnectionManager',
'CircuitBreaker'
]

View File

@@ -0,0 +1,383 @@
"""
Base exchange connector implementation with connection management and error handling.
"""
import asyncio
import json
import websockets
from typing import Dict, List, Optional, Callable, Any
from datetime import datetime, timezone
from ..interfaces.exchange_connector import ExchangeConnector
from ..models.core import ConnectionStatus, OrderBookSnapshot, TradeEvent
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ConnectionError, ValidationError
from ..utils.timing import get_current_timestamp
from .connection_manager import ConnectionManager
from .circuit_breaker import CircuitBreaker, CircuitBreakerOpenError
logger = get_logger(__name__)
class BaseExchangeConnector(ExchangeConnector):
"""
Base implementation of exchange connector with common functionality.
Provides:
- WebSocket connection management
- Exponential backoff retry logic
- Circuit breaker pattern
- Health monitoring
- Message handling framework
- Subscription management
"""
def __init__(self, exchange_name: str, websocket_url: str):
"""
Initialize base exchange connector.
Args:
exchange_name: Name of the exchange
websocket_url: WebSocket URL for the exchange
"""
super().__init__(exchange_name)
self.websocket_url = websocket_url
self.websocket: Optional[websockets.WebSocketServerProtocol] = None
self.subscriptions: Dict[str, List[str]] = {} # symbol -> [subscription_types]
self.message_handlers: Dict[str, Callable] = {}
# Connection management
self.connection_manager = ConnectionManager(
name=f"{exchange_name}_connector",
max_retries=10,
initial_delay=1.0,
max_delay=300.0,
health_check_interval=30
)
# Circuit breaker
self.circuit_breaker = CircuitBreaker(
failure_threshold=5,
recovery_timeout=60,
expected_exception=Exception,
name=f"{exchange_name}_circuit"
)
# Statistics
self.message_count = 0
self.error_count = 0
self.last_message_time: Optional[datetime] = None
# Setup callbacks
self.connection_manager.on_connect = self._on_connect
self.connection_manager.on_disconnect = self._on_disconnect
self.connection_manager.on_error = self._on_error
self.connection_manager.on_health_check = self._health_check
# Message processing
self._message_queue = asyncio.Queue(maxsize=10000)
self._message_processor_task: Optional[asyncio.Task] = None
logger.info(f"Base connector initialized for {exchange_name}")
async def connect(self) -> bool:
"""Establish connection to the exchange WebSocket"""
try:
set_correlation_id()
logger.info(f"Connecting to {self.exchange_name} at {self.websocket_url}")
return await self.connection_manager.connect(self._establish_websocket_connection)
except Exception as e:
logger.error(f"Failed to connect to {self.exchange_name}: {e}")
self._notify_status_callbacks(ConnectionStatus.ERROR)
return False
async def disconnect(self) -> None:
"""Disconnect from the exchange WebSocket"""
try:
set_correlation_id()
logger.info(f"Disconnecting from {self.exchange_name}")
await self.connection_manager.disconnect(self._close_websocket_connection)
except Exception as e:
logger.error(f"Error during disconnect from {self.exchange_name}: {e}")
async def _establish_websocket_connection(self) -> None:
"""Establish WebSocket connection"""
try:
# Use circuit breaker for connection
self.websocket = await self.circuit_breaker.call_async(
websockets.connect,
self.websocket_url,
ping_interval=20,
ping_timeout=10,
close_timeout=10
)
logger.info(f"WebSocket connected to {self.exchange_name}")
# Start message processing
await self._start_message_processing()
except CircuitBreakerOpenError as e:
logger.error(f"Circuit breaker open for {self.exchange_name}: {e}")
raise ConnectionError(f"Circuit breaker open: {e}", "CIRCUIT_BREAKER_OPEN")
except Exception as e:
logger.error(f"WebSocket connection failed for {self.exchange_name}: {e}")
raise ConnectionError(f"WebSocket connection failed: {e}", "WEBSOCKET_CONNECT_FAILED")
async def _close_websocket_connection(self) -> None:
"""Close WebSocket connection"""
try:
# Stop message processing
await self._stop_message_processing()
# Close WebSocket
if self.websocket:
await self.websocket.close()
self.websocket = None
logger.info(f"WebSocket disconnected from {self.exchange_name}")
except Exception as e:
logger.warning(f"Error closing WebSocket for {self.exchange_name}: {e}")
async def _start_message_processing(self) -> None:
"""Start message processing tasks"""
if self._message_processor_task:
return
# Start message processor
self._message_processor_task = asyncio.create_task(self._message_processor())
# Start message receiver
asyncio.create_task(self._message_receiver())
logger.debug(f"Message processing started for {self.exchange_name}")
async def _stop_message_processing(self) -> None:
"""Stop message processing tasks"""
if self._message_processor_task:
self._message_processor_task.cancel()
try:
await self._message_processor_task
except asyncio.CancelledError:
pass
self._message_processor_task = None
logger.debug(f"Message processing stopped for {self.exchange_name}")
async def _message_receiver(self) -> None:
"""Receive messages from WebSocket"""
try:
while self.websocket and not self.websocket.closed:
try:
message = await asyncio.wait_for(self.websocket.recv(), timeout=30.0)
# Queue message for processing
try:
self._message_queue.put_nowait(message)
except asyncio.QueueFull:
logger.warning(f"Message queue full for {self.exchange_name}, dropping message")
except asyncio.TimeoutError:
# Send ping to keep connection alive
if self.websocket:
await self.websocket.ping()
except websockets.exceptions.ConnectionClosed:
logger.warning(f"WebSocket connection closed for {self.exchange_name}")
break
except Exception as e:
logger.error(f"Error receiving message from {self.exchange_name}: {e}")
self.error_count += 1
break
except Exception as e:
logger.error(f"Message receiver error for {self.exchange_name}: {e}")
finally:
# Mark as disconnected
self.connection_manager.is_connected = False
async def _message_processor(self) -> None:
"""Process messages from the queue"""
while True:
try:
# Get message from queue
message = await self._message_queue.get()
# Process message
await self._process_message(message)
# Update statistics
self.message_count += 1
self.last_message_time = get_current_timestamp()
# Mark task as done
self._message_queue.task_done()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error processing message for {self.exchange_name}: {e}")
self.error_count += 1
async def _process_message(self, message: str) -> None:
"""
Process incoming WebSocket message.
Args:
message: Raw message string
"""
try:
# Parse JSON message
data = json.loads(message)
# Determine message type and route to appropriate handler
message_type = self._get_message_type(data)
if message_type in self.message_handlers:
await self.message_handlers[message_type](data)
else:
logger.debug(f"Unhandled message type '{message_type}' from {self.exchange_name}")
except json.JSONDecodeError as e:
logger.warning(f"Invalid JSON message from {self.exchange_name}: {e}")
except Exception as e:
logger.error(f"Error processing message from {self.exchange_name}: {e}")
def _get_message_type(self, data: Dict) -> str:
"""
Determine message type from message data.
Override in subclasses for exchange-specific logic.
Args:
data: Parsed message data
Returns:
str: Message type identifier
"""
# Default implementation - override in subclasses
return data.get('type', 'unknown')
async def _send_message(self, message: Dict) -> bool:
"""
Send message to WebSocket.
Args:
message: Message to send
Returns:
bool: True if sent successfully, False otherwise
"""
try:
if not self.websocket or self.websocket.closed:
logger.warning(f"Cannot send message to {self.exchange_name}: not connected")
return False
message_str = json.dumps(message)
await self.websocket.send(message_str)
logger.debug(f"Sent message to {self.exchange_name}: {message_str[:100]}...")
return True
except Exception as e:
logger.error(f"Error sending message to {self.exchange_name}: {e}")
return False
# Callback handlers
async def _on_connect(self) -> None:
"""Handle successful connection"""
self._notify_status_callbacks(ConnectionStatus.CONNECTED)
# Resubscribe to all previous subscriptions
await self._resubscribe_all()
async def _on_disconnect(self) -> None:
"""Handle disconnection"""
self._notify_status_callbacks(ConnectionStatus.DISCONNECTED)
async def _on_error(self, error: Exception) -> None:
"""Handle connection error"""
logger.error(f"Connection error for {self.exchange_name}: {error}")
self._notify_status_callbacks(ConnectionStatus.ERROR)
async def _health_check(self) -> bool:
"""Perform health check"""
try:
if not self.websocket or self.websocket.closed:
return False
# Check if we've received messages recently
if self.last_message_time:
time_since_last_message = (get_current_timestamp() - self.last_message_time).total_seconds()
if time_since_last_message > 60: # No messages for 60 seconds
logger.warning(f"No messages received from {self.exchange_name} for {time_since_last_message}s")
return False
# Send ping
await self.websocket.ping()
return True
except Exception as e:
logger.error(f"Health check failed for {self.exchange_name}: {e}")
return False
async def _resubscribe_all(self) -> None:
"""Resubscribe to all previous subscriptions after reconnection"""
for symbol, subscription_types in self.subscriptions.items():
for sub_type in subscription_types:
try:
if sub_type == 'orderbook':
await self.subscribe_orderbook(symbol)
elif sub_type == 'trades':
await self.subscribe_trades(symbol)
except Exception as e:
logger.error(f"Failed to resubscribe to {sub_type} for {symbol}: {e}")
# Abstract methods that must be implemented by subclasses
async def subscribe_orderbook(self, symbol: str) -> None:
"""Subscribe to order book updates - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement subscribe_orderbook")
async def subscribe_trades(self, symbol: str) -> None:
"""Subscribe to trade updates - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement subscribe_trades")
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""Unsubscribe from order book updates - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement unsubscribe_orderbook")
async def unsubscribe_trades(self, symbol: str) -> None:
"""Unsubscribe from trade updates - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement unsubscribe_trades")
async def get_symbols(self) -> List[str]:
"""Get available symbols - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement get_symbols")
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol format - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement normalize_symbol")
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""Get order book snapshot - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement get_orderbook_snapshot")
# Utility methods
def get_stats(self) -> Dict[str, Any]:
"""Get connector statistics"""
return {
'exchange': self.exchange_name,
'connection_status': self.get_connection_status().value,
'is_connected': self.is_connected,
'message_count': self.message_count,
'error_count': self.error_count,
'last_message_time': self.last_message_time.isoformat() if self.last_message_time else None,
'subscriptions': dict(self.subscriptions),
'connection_manager': self.connection_manager.get_stats(),
'circuit_breaker': self.circuit_breaker.get_stats(),
'queue_size': self._message_queue.qsize()
}

View File

@@ -0,0 +1,206 @@
"""
Circuit breaker pattern implementation for exchange connections.
"""
import time
from enum import Enum
from typing import Optional, Callable, Any
from ..utils.logging import get_logger
logger = get_logger(__name__)
class CircuitState(Enum):
"""Circuit breaker states"""
CLOSED = "closed" # Normal operation
OPEN = "open" # Circuit is open, calls fail fast
HALF_OPEN = "half_open" # Testing if service is back
class CircuitBreaker:
"""
Circuit breaker to prevent cascading failures in exchange connections.
States:
- CLOSED: Normal operation, requests pass through
- OPEN: Circuit is open, requests fail immediately
- HALF_OPEN: Testing if service is back, limited requests allowed
"""
def __init__(
self,
failure_threshold: int = 5,
recovery_timeout: int = 60,
expected_exception: type = Exception,
name: str = "CircuitBreaker"
):
"""
Initialize circuit breaker.
Args:
failure_threshold: Number of failures before opening circuit
recovery_timeout: Time in seconds before attempting recovery
expected_exception: Exception type that triggers circuit breaker
name: Name for logging purposes
"""
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
self.expected_exception = expected_exception
self.name = name
# State tracking
self._state = CircuitState.CLOSED
self._failure_count = 0
self._last_failure_time: Optional[float] = None
self._next_attempt_time: Optional[float] = None
logger.info(f"Circuit breaker '{name}' initialized with threshold={failure_threshold}")
@property
def state(self) -> CircuitState:
"""Get current circuit state"""
return self._state
@property
def failure_count(self) -> int:
"""Get current failure count"""
return self._failure_count
def _should_attempt_reset(self) -> bool:
"""Check if we should attempt to reset the circuit"""
if self._state != CircuitState.OPEN:
return False
if self._next_attempt_time is None:
return False
return time.time() >= self._next_attempt_time
def _on_success(self) -> None:
"""Handle successful operation"""
if self._state == CircuitState.HALF_OPEN:
logger.info(f"Circuit breaker '{self.name}' reset to CLOSED after successful test")
self._state = CircuitState.CLOSED
self._failure_count = 0
self._last_failure_time = None
self._next_attempt_time = None
def _on_failure(self) -> None:
"""Handle failed operation"""
self._failure_count += 1
self._last_failure_time = time.time()
if self._state == CircuitState.HALF_OPEN:
# Failed during test, go back to OPEN
logger.warning(f"Circuit breaker '{self.name}' failed during test, returning to OPEN")
self._state = CircuitState.OPEN
self._next_attempt_time = time.time() + self.recovery_timeout
elif self._failure_count >= self.failure_threshold:
# Too many failures, open the circuit
logger.error(
f"Circuit breaker '{self.name}' OPENED after {self._failure_count} failures"
)
self._state = CircuitState.OPEN
self._next_attempt_time = time.time() + self.recovery_timeout
def call(self, func: Callable, *args, **kwargs) -> Any:
"""
Execute function with circuit breaker protection.
Args:
func: Function to execute
*args: Function arguments
**kwargs: Function keyword arguments
Returns:
Function result
Raises:
CircuitBreakerOpenError: When circuit is open
Original exception: When function fails
"""
# Check if we should attempt reset
if self._should_attempt_reset():
logger.info(f"Circuit breaker '{self.name}' attempting reset to HALF_OPEN")
self._state = CircuitState.HALF_OPEN
# Fail fast if circuit is open
if self._state == CircuitState.OPEN:
raise CircuitBreakerOpenError(
f"Circuit breaker '{self.name}' is OPEN. "
f"Next attempt in {self._next_attempt_time - time.time():.1f}s"
)
try:
# Execute the function
result = func(*args, **kwargs)
self._on_success()
return result
except self.expected_exception as e:
self._on_failure()
raise e
async def call_async(self, func: Callable, *args, **kwargs) -> Any:
"""
Execute async function with circuit breaker protection.
Args:
func: Async function to execute
*args: Function arguments
**kwargs: Function keyword arguments
Returns:
Function result
Raises:
CircuitBreakerOpenError: When circuit is open
Original exception: When function fails
"""
# Check if we should attempt reset
if self._should_attempt_reset():
logger.info(f"Circuit breaker '{self.name}' attempting reset to HALF_OPEN")
self._state = CircuitState.HALF_OPEN
# Fail fast if circuit is open
if self._state == CircuitState.OPEN:
raise CircuitBreakerOpenError(
f"Circuit breaker '{self.name}' is OPEN. "
f"Next attempt in {self._next_attempt_time - time.time():.1f}s"
)
try:
# Execute the async function
result = await func(*args, **kwargs)
self._on_success()
return result
except self.expected_exception as e:
self._on_failure()
raise e
def reset(self) -> None:
"""Manually reset the circuit breaker"""
logger.info(f"Circuit breaker '{self.name}' manually reset")
self._state = CircuitState.CLOSED
self._failure_count = 0
self._last_failure_time = None
self._next_attempt_time = None
def get_stats(self) -> dict:
"""Get circuit breaker statistics"""
return {
'name': self.name,
'state': self._state.value,
'failure_count': self._failure_count,
'failure_threshold': self.failure_threshold,
'last_failure_time': self._last_failure_time,
'next_attempt_time': self._next_attempt_time,
'recovery_timeout': self.recovery_timeout
}
class CircuitBreakerOpenError(Exception):
"""Exception raised when circuit breaker is open"""
pass

View File

@@ -0,0 +1,271 @@
"""
Connection management with exponential backoff and retry logic.
"""
import asyncio
import random
from typing import Optional, Callable, Any
from ..utils.logging import get_logger
from ..utils.exceptions import ConnectionError
logger = get_logger(__name__)
class ExponentialBackoff:
"""Exponential backoff strategy for connection retries"""
def __init__(
self,
initial_delay: float = 1.0,
max_delay: float = 300.0,
multiplier: float = 2.0,
jitter: bool = True
):
"""
Initialize exponential backoff.
Args:
initial_delay: Initial delay in seconds
max_delay: Maximum delay in seconds
multiplier: Backoff multiplier
jitter: Whether to add random jitter
"""
self.initial_delay = initial_delay
self.max_delay = max_delay
self.multiplier = multiplier
self.jitter = jitter
self.current_delay = initial_delay
self.attempt_count = 0
def get_delay(self) -> float:
"""Get next delay value"""
delay = min(self.current_delay, self.max_delay)
# Add jitter to prevent thundering herd
if self.jitter:
delay = delay * (0.5 + random.random() * 0.5)
# Update for next attempt
self.current_delay *= self.multiplier
self.attempt_count += 1
return delay
def reset(self) -> None:
"""Reset backoff to initial state"""
self.current_delay = self.initial_delay
self.attempt_count = 0
class ConnectionManager:
"""
Manages connection lifecycle with retry logic and health monitoring.
"""
def __init__(
self,
name: str,
max_retries: int = 10,
initial_delay: float = 1.0,
max_delay: float = 300.0,
health_check_interval: int = 30
):
"""
Initialize connection manager.
Args:
name: Connection name for logging
max_retries: Maximum number of retry attempts
initial_delay: Initial retry delay in seconds
max_delay: Maximum retry delay in seconds
health_check_interval: Health check interval in seconds
"""
self.name = name
self.max_retries = max_retries
self.health_check_interval = health_check_interval
self.backoff = ExponentialBackoff(initial_delay, max_delay)
self.is_connected = False
self.connection_attempts = 0
self.last_error: Optional[Exception] = None
self.health_check_task: Optional[asyncio.Task] = None
# Callbacks
self.on_connect: Optional[Callable] = None
self.on_disconnect: Optional[Callable] = None
self.on_error: Optional[Callable] = None
self.on_health_check: Optional[Callable] = None
logger.info(f"Connection manager '{name}' initialized")
async def connect(self, connect_func: Callable) -> bool:
"""
Attempt to establish connection with retry logic.
Args:
connect_func: Async function that establishes the connection
Returns:
bool: True if connection successful, False otherwise
"""
self.connection_attempts = 0
self.backoff.reset()
while self.connection_attempts < self.max_retries:
try:
logger.info(f"Attempting to connect '{self.name}' (attempt {self.connection_attempts + 1})")
# Attempt connection
await connect_func()
# Connection successful
self.is_connected = True
self.connection_attempts = 0
self.last_error = None
self.backoff.reset()
logger.info(f"Connection '{self.name}' established successfully")
# Start health check
await self._start_health_check()
# Notify success
if self.on_connect:
try:
await self.on_connect()
except Exception as e:
logger.warning(f"Error in connect callback: {e}")
return True
except Exception as e:
self.connection_attempts += 1
self.last_error = e
logger.warning(
f"Connection '{self.name}' failed (attempt {self.connection_attempts}): {e}"
)
# Notify error
if self.on_error:
try:
await self.on_error(e)
except Exception as callback_error:
logger.warning(f"Error in error callback: {callback_error}")
# Check if we should retry
if self.connection_attempts >= self.max_retries:
logger.error(f"Connection '{self.name}' failed after {self.max_retries} attempts")
break
# Wait before retry
delay = self.backoff.get_delay()
logger.info(f"Retrying connection '{self.name}' in {delay:.1f} seconds")
await asyncio.sleep(delay)
self.is_connected = False
return False
async def disconnect(self, disconnect_func: Optional[Callable] = None) -> None:
"""
Disconnect and cleanup.
Args:
disconnect_func: Optional async function to handle disconnection
"""
logger.info(f"Disconnecting '{self.name}'")
# Stop health check
await self._stop_health_check()
# Execute disconnect function
if disconnect_func:
try:
await disconnect_func()
except Exception as e:
logger.warning(f"Error during disconnect: {e}")
self.is_connected = False
# Notify disconnect
if self.on_disconnect:
try:
await self.on_disconnect()
except Exception as e:
logger.warning(f"Error in disconnect callback: {e}")
logger.info(f"Connection '{self.name}' disconnected")
async def reconnect(self, connect_func: Callable, disconnect_func: Optional[Callable] = None) -> bool:
"""
Reconnect by disconnecting first then connecting.
Args:
connect_func: Async function that establishes the connection
disconnect_func: Optional async function to handle disconnection
Returns:
bool: True if reconnection successful, False otherwise
"""
logger.info(f"Reconnecting '{self.name}'")
# Disconnect first
await self.disconnect(disconnect_func)
# Wait a bit before reconnecting
await asyncio.sleep(1.0)
# Attempt to connect
return await self.connect(connect_func)
async def _start_health_check(self) -> None:
"""Start periodic health check"""
if self.health_check_task:
return
self.health_check_task = asyncio.create_task(self._health_check_loop())
logger.debug(f"Health check started for '{self.name}'")
async def _stop_health_check(self) -> None:
"""Stop health check"""
if self.health_check_task:
self.health_check_task.cancel()
try:
await self.health_check_task
except asyncio.CancelledError:
pass
self.health_check_task = None
logger.debug(f"Health check stopped for '{self.name}'")
async def _health_check_loop(self) -> None:
"""Health check loop"""
while self.is_connected:
try:
await asyncio.sleep(self.health_check_interval)
if self.on_health_check:
is_healthy = await self.on_health_check()
if not is_healthy:
logger.warning(f"Health check failed for '{self.name}'")
self.is_connected = False
break
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Health check error for '{self.name}': {e}")
self.is_connected = False
break
def get_stats(self) -> dict:
"""Get connection statistics"""
return {
'name': self.name,
'is_connected': self.is_connected,
'connection_attempts': self.connection_attempts,
'max_retries': self.max_retries,
'current_delay': self.backoff.current_delay,
'backoff_attempts': self.backoff.attempt_count,
'last_error': str(self.last_error) if self.last_error else None,
'health_check_active': self.health_check_task is not None
}