From 504736c0f74a156c975cf63ce7c7af9402767dc5 Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Mon, 4 Aug 2025 17:12:26 +0300
Subject: [PATCH] cob integration scaffold

---
 .../multi-exchange-data-aggregation/tasks.md |   7 +
 COBY/connectors/__init__.py                  |  13 +
 COBY/connectors/base_connector.py            | 383 +++++++++++
 COBY/connectors/circuit_breaker.py           | 206 ++++++
 COBY/connectors/connection_manager.py        | 271 ++++++++
 COBY/requirements.txt                        |  34 +
 COBY/storage/__init__.py                     |  11 +
 COBY/storage/connection_pool.py              | 140 ++++
 COBY/storage/migrations.py                   | 271 ++++++++
 COBY/storage/schema.py                       | 256 ++++++++
 COBY/storage/timescale_manager.py            | 604 ++++++++++++++++++
 COBY/test_integration.py                     | 274 ++++++++
 COBY/tests/__init__.py                       |   3 +
 COBY/tests/test_timescale_manager.py         | 192 ++++++
 14 files changed, 2665 insertions(+)
 create mode 100644 COBY/connectors/__init__.py
 create mode 100644 COBY/connectors/base_connector.py
 create mode 100644 COBY/connectors/circuit_breaker.py
 create mode 100644 COBY/connectors/connection_manager.py
 create mode 100644 COBY/requirements.txt
 create mode 100644 COBY/storage/__init__.py
 create mode 100644 COBY/storage/connection_pool.py
 create mode 100644 COBY/storage/migrations.py
 create mode 100644 COBY/storage/schema.py
 create mode 100644 COBY/storage/timescale_manager.py
 create mode 100644 COBY/test_integration.py
 create mode 100644 COBY/tests/__init__.py
 create mode 100644 COBY/tests/test_timescale_manager.py

diff --git a/.kiro/specs/multi-exchange-data-aggregation/tasks.md b/.kiro/specs/multi-exchange-data-aggregation/tasks.md
index c2cd82c..a325052 100644
--- a/.kiro/specs/multi-exchange-data-aggregation/tasks.md
+++ b/.kiro/specs/multi-exchange-data-aggregation/tasks.md
@@ -7,15 +7,22 @@
   - Create directory structure in `.\COBY` subfolder for the multi-exchange data aggregation system
   - Define base interfaces and data models for exchange connectors, data processing, and storage
   - Implement configuration management system with environment variable support
+
+
+
   - _Requirements: 1.1, 6.1, 7.3_
 
 - [ ] 2. Implement TimescaleDB integration and database schema
   - Create TimescaleDB connection manager with connection pooling
+
+
+
   - Implement database schema creation with hypertables for time-series optimization
   - Write database operations for storing order book snapshots and trade events
   - Create database migration system for schema updates
   - _Requirements: 3.1, 3.2, 3.3, 3.4_
 
+
 - [ ] 3. Create base exchange connector framework
   - Implement abstract base class for exchange WebSocket connectors
   - Create connection management with exponential backoff and circuit breaker patterns
diff --git a/COBY/connectors/__init__.py b/COBY/connectors/__init__.py
new file mode 100644
index 0000000..d2dae8e
--- /dev/null
+++ b/COBY/connectors/__init__.py
@@ -0,0 +1,13 @@
+"""
+Exchange connector implementations for the COBY system.
+"""
+
+from .base_connector import BaseExchangeConnector
+from .connection_manager import ConnectionManager
+from .circuit_breaker import CircuitBreaker
+
+__all__ = [
+    'BaseExchangeConnector',
+    'ConnectionManager',
+    'CircuitBreaker'
+]
\ No newline at end of file
diff --git a/COBY/connectors/base_connector.py b/COBY/connectors/base_connector.py
new file mode 100644
index 0000000..cd2c82b
--- /dev/null
+++ b/COBY/connectors/base_connector.py
@@ -0,0 +1,383 @@
+"""
+Base exchange connector implementation with connection management and error handling.
+""" + +import asyncio +import json +import websockets +from typing import Dict, List, Optional, Callable, Any +from datetime import datetime, timezone + +from ..interfaces.exchange_connector import ExchangeConnector +from ..models.core import ConnectionStatus, OrderBookSnapshot, TradeEvent +from ..utils.logging import get_logger, set_correlation_id +from ..utils.exceptions import ConnectionError, ValidationError +from ..utils.timing import get_current_timestamp +from .connection_manager import ConnectionManager +from .circuit_breaker import CircuitBreaker, CircuitBreakerOpenError + +logger = get_logger(__name__) + + +class BaseExchangeConnector(ExchangeConnector): + """ + Base implementation of exchange connector with common functionality. + + Provides: + - WebSocket connection management + - Exponential backoff retry logic + - Circuit breaker pattern + - Health monitoring + - Message handling framework + - Subscription management + """ + + def __init__(self, exchange_name: str, websocket_url: str): + """ + Initialize base exchange connector. + + Args: + exchange_name: Name of the exchange + websocket_url: WebSocket URL for the exchange + """ + super().__init__(exchange_name) + + self.websocket_url = websocket_url + self.websocket: Optional[websockets.WebSocketServerProtocol] = None + self.subscriptions: Dict[str, List[str]] = {} # symbol -> [subscription_types] + self.message_handlers: Dict[str, Callable] = {} + + # Connection management + self.connection_manager = ConnectionManager( + name=f"{exchange_name}_connector", + max_retries=10, + initial_delay=1.0, + max_delay=300.0, + health_check_interval=30 + ) + + # Circuit breaker + self.circuit_breaker = CircuitBreaker( + failure_threshold=5, + recovery_timeout=60, + expected_exception=Exception, + name=f"{exchange_name}_circuit" + ) + + # Statistics + self.message_count = 0 + self.error_count = 0 + self.last_message_time: Optional[datetime] = None + + # Setup callbacks + self.connection_manager.on_connect = self._on_connect + self.connection_manager.on_disconnect = self._on_disconnect + self.connection_manager.on_error = self._on_error + self.connection_manager.on_health_check = self._health_check + + # Message processing + self._message_queue = asyncio.Queue(maxsize=10000) + self._message_processor_task: Optional[asyncio.Task] = None + + logger.info(f"Base connector initialized for {exchange_name}") + + async def connect(self) -> bool: + """Establish connection to the exchange WebSocket""" + try: + set_correlation_id() + logger.info(f"Connecting to {self.exchange_name} at {self.websocket_url}") + + return await self.connection_manager.connect(self._establish_websocket_connection) + + except Exception as e: + logger.error(f"Failed to connect to {self.exchange_name}: {e}") + self._notify_status_callbacks(ConnectionStatus.ERROR) + return False + + async def disconnect(self) -> None: + """Disconnect from the exchange WebSocket""" + try: + set_correlation_id() + logger.info(f"Disconnecting from {self.exchange_name}") + + await self.connection_manager.disconnect(self._close_websocket_connection) + + except Exception as e: + logger.error(f"Error during disconnect from {self.exchange_name}: {e}") + + async def _establish_websocket_connection(self) -> None: + """Establish WebSocket connection""" + try: + # Use circuit breaker for connection + self.websocket = await self.circuit_breaker.call_async( + websockets.connect, + self.websocket_url, + ping_interval=20, + ping_timeout=10, + close_timeout=10 + ) + + logger.info(f"WebSocket connected to 
{self.exchange_name}") + + # Start message processing + await self._start_message_processing() + + except CircuitBreakerOpenError as e: + logger.error(f"Circuit breaker open for {self.exchange_name}: {e}") + raise ConnectionError(f"Circuit breaker open: {e}", "CIRCUIT_BREAKER_OPEN") + except Exception as e: + logger.error(f"WebSocket connection failed for {self.exchange_name}: {e}") + raise ConnectionError(f"WebSocket connection failed: {e}", "WEBSOCKET_CONNECT_FAILED") + + async def _close_websocket_connection(self) -> None: + """Close WebSocket connection""" + try: + # Stop message processing + await self._stop_message_processing() + + # Close WebSocket + if self.websocket: + await self.websocket.close() + self.websocket = None + + logger.info(f"WebSocket disconnected from {self.exchange_name}") + + except Exception as e: + logger.warning(f"Error closing WebSocket for {self.exchange_name}: {e}") + + async def _start_message_processing(self) -> None: + """Start message processing tasks""" + if self._message_processor_task: + return + + # Start message processor + self._message_processor_task = asyncio.create_task(self._message_processor()) + + # Start message receiver + asyncio.create_task(self._message_receiver()) + + logger.debug(f"Message processing started for {self.exchange_name}") + + async def _stop_message_processing(self) -> None: + """Stop message processing tasks""" + if self._message_processor_task: + self._message_processor_task.cancel() + try: + await self._message_processor_task + except asyncio.CancelledError: + pass + self._message_processor_task = None + + logger.debug(f"Message processing stopped for {self.exchange_name}") + + async def _message_receiver(self) -> None: + """Receive messages from WebSocket""" + try: + while self.websocket and not self.websocket.closed: + try: + message = await asyncio.wait_for(self.websocket.recv(), timeout=30.0) + + # Queue message for processing + try: + self._message_queue.put_nowait(message) + except asyncio.QueueFull: + logger.warning(f"Message queue full for {self.exchange_name}, dropping message") + + except asyncio.TimeoutError: + # Send ping to keep connection alive + if self.websocket: + await self.websocket.ping() + except websockets.exceptions.ConnectionClosed: + logger.warning(f"WebSocket connection closed for {self.exchange_name}") + break + except Exception as e: + logger.error(f"Error receiving message from {self.exchange_name}: {e}") + self.error_count += 1 + break + + except Exception as e: + logger.error(f"Message receiver error for {self.exchange_name}: {e}") + finally: + # Mark as disconnected + self.connection_manager.is_connected = False + + async def _message_processor(self) -> None: + """Process messages from the queue""" + while True: + try: + # Get message from queue + message = await self._message_queue.get() + + # Process message + await self._process_message(message) + + # Update statistics + self.message_count += 1 + self.last_message_time = get_current_timestamp() + + # Mark task as done + self._message_queue.task_done() + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error processing message for {self.exchange_name}: {e}") + self.error_count += 1 + + async def _process_message(self, message: str) -> None: + """ + Process incoming WebSocket message. 
+ + Args: + message: Raw message string + """ + try: + # Parse JSON message + data = json.loads(message) + + # Determine message type and route to appropriate handler + message_type = self._get_message_type(data) + + if message_type in self.message_handlers: + await self.message_handlers[message_type](data) + else: + logger.debug(f"Unhandled message type '{message_type}' from {self.exchange_name}") + + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON message from {self.exchange_name}: {e}") + except Exception as e: + logger.error(f"Error processing message from {self.exchange_name}: {e}") + + def _get_message_type(self, data: Dict) -> str: + """ + Determine message type from message data. + Override in subclasses for exchange-specific logic. + + Args: + data: Parsed message data + + Returns: + str: Message type identifier + """ + # Default implementation - override in subclasses + return data.get('type', 'unknown') + + async def _send_message(self, message: Dict) -> bool: + """ + Send message to WebSocket. + + Args: + message: Message to send + + Returns: + bool: True if sent successfully, False otherwise + """ + try: + if not self.websocket or self.websocket.closed: + logger.warning(f"Cannot send message to {self.exchange_name}: not connected") + return False + + message_str = json.dumps(message) + await self.websocket.send(message_str) + + logger.debug(f"Sent message to {self.exchange_name}: {message_str[:100]}...") + return True + + except Exception as e: + logger.error(f"Error sending message to {self.exchange_name}: {e}") + return False + + # Callback handlers + async def _on_connect(self) -> None: + """Handle successful connection""" + self._notify_status_callbacks(ConnectionStatus.CONNECTED) + + # Resubscribe to all previous subscriptions + await self._resubscribe_all() + + async def _on_disconnect(self) -> None: + """Handle disconnection""" + self._notify_status_callbacks(ConnectionStatus.DISCONNECTED) + + async def _on_error(self, error: Exception) -> None: + """Handle connection error""" + logger.error(f"Connection error for {self.exchange_name}: {error}") + self._notify_status_callbacks(ConnectionStatus.ERROR) + + async def _health_check(self) -> bool: + """Perform health check""" + try: + if not self.websocket or self.websocket.closed: + return False + + # Check if we've received messages recently + if self.last_message_time: + time_since_last_message = (get_current_timestamp() - self.last_message_time).total_seconds() + if time_since_last_message > 60: # No messages for 60 seconds + logger.warning(f"No messages received from {self.exchange_name} for {time_since_last_message}s") + return False + + # Send ping + await self.websocket.ping() + return True + + except Exception as e: + logger.error(f"Health check failed for {self.exchange_name}: {e}") + return False + + async def _resubscribe_all(self) -> None: + """Resubscribe to all previous subscriptions after reconnection""" + for symbol, subscription_types in self.subscriptions.items(): + for sub_type in subscription_types: + try: + if sub_type == 'orderbook': + await self.subscribe_orderbook(symbol) + elif sub_type == 'trades': + await self.subscribe_trades(symbol) + except Exception as e: + logger.error(f"Failed to resubscribe to {sub_type} for {symbol}: {e}") + + # Abstract methods that must be implemented by subclasses + async def subscribe_orderbook(self, symbol: str) -> None: + """Subscribe to order book updates - must be implemented by subclasses""" + raise NotImplementedError("Subclasses must implement 
subscribe_orderbook") + + async def subscribe_trades(self, symbol: str) -> None: + """Subscribe to trade updates - must be implemented by subclasses""" + raise NotImplementedError("Subclasses must implement subscribe_trades") + + async def unsubscribe_orderbook(self, symbol: str) -> None: + """Unsubscribe from order book updates - must be implemented by subclasses""" + raise NotImplementedError("Subclasses must implement unsubscribe_orderbook") + + async def unsubscribe_trades(self, symbol: str) -> None: + """Unsubscribe from trade updates - must be implemented by subclasses""" + raise NotImplementedError("Subclasses must implement unsubscribe_trades") + + async def get_symbols(self) -> List[str]: + """Get available symbols - must be implemented by subclasses""" + raise NotImplementedError("Subclasses must implement get_symbols") + + def normalize_symbol(self, symbol: str) -> str: + """Normalize symbol format - must be implemented by subclasses""" + raise NotImplementedError("Subclasses must implement normalize_symbol") + + async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]: + """Get order book snapshot - must be implemented by subclasses""" + raise NotImplementedError("Subclasses must implement get_orderbook_snapshot") + + # Utility methods + def get_stats(self) -> Dict[str, Any]: + """Get connector statistics""" + return { + 'exchange': self.exchange_name, + 'connection_status': self.get_connection_status().value, + 'is_connected': self.is_connected, + 'message_count': self.message_count, + 'error_count': self.error_count, + 'last_message_time': self.last_message_time.isoformat() if self.last_message_time else None, + 'subscriptions': dict(self.subscriptions), + 'connection_manager': self.connection_manager.get_stats(), + 'circuit_breaker': self.circuit_breaker.get_stats(), + 'queue_size': self._message_queue.qsize() + } \ No newline at end of file diff --git a/COBY/connectors/circuit_breaker.py b/COBY/connectors/circuit_breaker.py new file mode 100644 index 0000000..a546572 --- /dev/null +++ b/COBY/connectors/circuit_breaker.py @@ -0,0 +1,206 @@ +""" +Circuit breaker pattern implementation for exchange connections. +""" + +import time +from enum import Enum +from typing import Optional, Callable, Any +from ..utils.logging import get_logger + +logger = get_logger(__name__) + + +class CircuitState(Enum): + """Circuit breaker states""" + CLOSED = "closed" # Normal operation + OPEN = "open" # Circuit is open, calls fail fast + HALF_OPEN = "half_open" # Testing if service is back + + +class CircuitBreaker: + """ + Circuit breaker to prevent cascading failures in exchange connections. + + States: + - CLOSED: Normal operation, requests pass through + - OPEN: Circuit is open, requests fail immediately + - HALF_OPEN: Testing if service is back, limited requests allowed + """ + + def __init__( + self, + failure_threshold: int = 5, + recovery_timeout: int = 60, + expected_exception: type = Exception, + name: str = "CircuitBreaker" + ): + """ + Initialize circuit breaker. 
+ + Args: + failure_threshold: Number of failures before opening circuit + recovery_timeout: Time in seconds before attempting recovery + expected_exception: Exception type that triggers circuit breaker + name: Name for logging purposes + """ + self.failure_threshold = failure_threshold + self.recovery_timeout = recovery_timeout + self.expected_exception = expected_exception + self.name = name + + # State tracking + self._state = CircuitState.CLOSED + self._failure_count = 0 + self._last_failure_time: Optional[float] = None + self._next_attempt_time: Optional[float] = None + + logger.info(f"Circuit breaker '{name}' initialized with threshold={failure_threshold}") + + @property + def state(self) -> CircuitState: + """Get current circuit state""" + return self._state + + @property + def failure_count(self) -> int: + """Get current failure count""" + return self._failure_count + + def _should_attempt_reset(self) -> bool: + """Check if we should attempt to reset the circuit""" + if self._state != CircuitState.OPEN: + return False + + if self._next_attempt_time is None: + return False + + return time.time() >= self._next_attempt_time + + def _on_success(self) -> None: + """Handle successful operation""" + if self._state == CircuitState.HALF_OPEN: + logger.info(f"Circuit breaker '{self.name}' reset to CLOSED after successful test") + self._state = CircuitState.CLOSED + + self._failure_count = 0 + self._last_failure_time = None + self._next_attempt_time = None + + def _on_failure(self) -> None: + """Handle failed operation""" + self._failure_count += 1 + self._last_failure_time = time.time() + + if self._state == CircuitState.HALF_OPEN: + # Failed during test, go back to OPEN + logger.warning(f"Circuit breaker '{self.name}' failed during test, returning to OPEN") + self._state = CircuitState.OPEN + self._next_attempt_time = time.time() + self.recovery_timeout + elif self._failure_count >= self.failure_threshold: + # Too many failures, open the circuit + logger.error( + f"Circuit breaker '{self.name}' OPENED after {self._failure_count} failures" + ) + self._state = CircuitState.OPEN + self._next_attempt_time = time.time() + self.recovery_timeout + + def call(self, func: Callable, *args, **kwargs) -> Any: + """ + Execute function with circuit breaker protection. + + Args: + func: Function to execute + *args: Function arguments + **kwargs: Function keyword arguments + + Returns: + Function result + + Raises: + CircuitBreakerOpenError: When circuit is open + Original exception: When function fails + """ + # Check if we should attempt reset + if self._should_attempt_reset(): + logger.info(f"Circuit breaker '{self.name}' attempting reset to HALF_OPEN") + self._state = CircuitState.HALF_OPEN + + # Fail fast if circuit is open + if self._state == CircuitState.OPEN: + raise CircuitBreakerOpenError( + f"Circuit breaker '{self.name}' is OPEN. " + f"Next attempt in {self._next_attempt_time - time.time():.1f}s" + ) + + try: + # Execute the function + result = func(*args, **kwargs) + self._on_success() + return result + + except self.expected_exception as e: + self._on_failure() + raise e + + async def call_async(self, func: Callable, *args, **kwargs) -> Any: + """ + Execute async function with circuit breaker protection. 
+ + Args: + func: Async function to execute + *args: Function arguments + **kwargs: Function keyword arguments + + Returns: + Function result + + Raises: + CircuitBreakerOpenError: When circuit is open + Original exception: When function fails + """ + # Check if we should attempt reset + if self._should_attempt_reset(): + logger.info(f"Circuit breaker '{self.name}' attempting reset to HALF_OPEN") + self._state = CircuitState.HALF_OPEN + + # Fail fast if circuit is open + if self._state == CircuitState.OPEN: + raise CircuitBreakerOpenError( + f"Circuit breaker '{self.name}' is OPEN. " + f"Next attempt in {self._next_attempt_time - time.time():.1f}s" + ) + + try: + # Execute the async function + result = await func(*args, **kwargs) + self._on_success() + return result + + except self.expected_exception as e: + self._on_failure() + raise e + + def reset(self) -> None: + """Manually reset the circuit breaker""" + logger.info(f"Circuit breaker '{self.name}' manually reset") + self._state = CircuitState.CLOSED + self._failure_count = 0 + self._last_failure_time = None + self._next_attempt_time = None + + def get_stats(self) -> dict: + """Get circuit breaker statistics""" + return { + 'name': self.name, + 'state': self._state.value, + 'failure_count': self._failure_count, + 'failure_threshold': self.failure_threshold, + 'last_failure_time': self._last_failure_time, + 'next_attempt_time': self._next_attempt_time, + 'recovery_timeout': self.recovery_timeout + } + + +class CircuitBreakerOpenError(Exception): + """Exception raised when circuit breaker is open""" + pass \ No newline at end of file diff --git a/COBY/connectors/connection_manager.py b/COBY/connectors/connection_manager.py new file mode 100644 index 0000000..7b838ea --- /dev/null +++ b/COBY/connectors/connection_manager.py @@ -0,0 +1,271 @@ +""" +Connection management with exponential backoff and retry logic. +""" + +import asyncio +import random +from typing import Optional, Callable, Any +from ..utils.logging import get_logger +from ..utils.exceptions import ConnectionError + +logger = get_logger(__name__) + + +class ExponentialBackoff: + """Exponential backoff strategy for connection retries""" + + def __init__( + self, + initial_delay: float = 1.0, + max_delay: float = 300.0, + multiplier: float = 2.0, + jitter: bool = True + ): + """ + Initialize exponential backoff. + + Args: + initial_delay: Initial delay in seconds + max_delay: Maximum delay in seconds + multiplier: Backoff multiplier + jitter: Whether to add random jitter + """ + self.initial_delay = initial_delay + self.max_delay = max_delay + self.multiplier = multiplier + self.jitter = jitter + self.current_delay = initial_delay + self.attempt_count = 0 + + def get_delay(self) -> float: + """Get next delay value""" + delay = min(self.current_delay, self.max_delay) + + # Add jitter to prevent thundering herd + if self.jitter: + delay = delay * (0.5 + random.random() * 0.5) + + # Update for next attempt + self.current_delay *= self.multiplier + self.attempt_count += 1 + + return delay + + def reset(self) -> None: + """Reset backoff to initial state""" + self.current_delay = self.initial_delay + self.attempt_count = 0 + + +class ConnectionManager: + """ + Manages connection lifecycle with retry logic and health monitoring. + """ + + def __init__( + self, + name: str, + max_retries: int = 10, + initial_delay: float = 1.0, + max_delay: float = 300.0, + health_check_interval: int = 30 + ): + """ + Initialize connection manager. 
+ + Args: + name: Connection name for logging + max_retries: Maximum number of retry attempts + initial_delay: Initial retry delay in seconds + max_delay: Maximum retry delay in seconds + health_check_interval: Health check interval in seconds + """ + self.name = name + self.max_retries = max_retries + self.health_check_interval = health_check_interval + + self.backoff = ExponentialBackoff(initial_delay, max_delay) + self.is_connected = False + self.connection_attempts = 0 + self.last_error: Optional[Exception] = None + self.health_check_task: Optional[asyncio.Task] = None + + # Callbacks + self.on_connect: Optional[Callable] = None + self.on_disconnect: Optional[Callable] = None + self.on_error: Optional[Callable] = None + self.on_health_check: Optional[Callable] = None + + logger.info(f"Connection manager '{name}' initialized") + + async def connect(self, connect_func: Callable) -> bool: + """ + Attempt to establish connection with retry logic. + + Args: + connect_func: Async function that establishes the connection + + Returns: + bool: True if connection successful, False otherwise + """ + self.connection_attempts = 0 + self.backoff.reset() + + while self.connection_attempts < self.max_retries: + try: + logger.info(f"Attempting to connect '{self.name}' (attempt {self.connection_attempts + 1})") + + # Attempt connection + await connect_func() + + # Connection successful + self.is_connected = True + self.connection_attempts = 0 + self.last_error = None + self.backoff.reset() + + logger.info(f"Connection '{self.name}' established successfully") + + # Start health check + await self._start_health_check() + + # Notify success + if self.on_connect: + try: + await self.on_connect() + except Exception as e: + logger.warning(f"Error in connect callback: {e}") + + return True + + except Exception as e: + self.connection_attempts += 1 + self.last_error = e + + logger.warning( + f"Connection '{self.name}' failed (attempt {self.connection_attempts}): {e}" + ) + + # Notify error + if self.on_error: + try: + await self.on_error(e) + except Exception as callback_error: + logger.warning(f"Error in error callback: {callback_error}") + + # Check if we should retry + if self.connection_attempts >= self.max_retries: + logger.error(f"Connection '{self.name}' failed after {self.max_retries} attempts") + break + + # Wait before retry + delay = self.backoff.get_delay() + logger.info(f"Retrying connection '{self.name}' in {delay:.1f} seconds") + await asyncio.sleep(delay) + + self.is_connected = False + return False + + async def disconnect(self, disconnect_func: Optional[Callable] = None) -> None: + """ + Disconnect and cleanup. + + Args: + disconnect_func: Optional async function to handle disconnection + """ + logger.info(f"Disconnecting '{self.name}'") + + # Stop health check + await self._stop_health_check() + + # Execute disconnect function + if disconnect_func: + try: + await disconnect_func() + except Exception as e: + logger.warning(f"Error during disconnect: {e}") + + self.is_connected = False + + # Notify disconnect + if self.on_disconnect: + try: + await self.on_disconnect() + except Exception as e: + logger.warning(f"Error in disconnect callback: {e}") + + logger.info(f"Connection '{self.name}' disconnected") + + async def reconnect(self, connect_func: Callable, disconnect_func: Optional[Callable] = None) -> bool: + """ + Reconnect by disconnecting first then connecting. 
+ + Args: + connect_func: Async function that establishes the connection + disconnect_func: Optional async function to handle disconnection + + Returns: + bool: True if reconnection successful, False otherwise + """ + logger.info(f"Reconnecting '{self.name}'") + + # Disconnect first + await self.disconnect(disconnect_func) + + # Wait a bit before reconnecting + await asyncio.sleep(1.0) + + # Attempt to connect + return await self.connect(connect_func) + + async def _start_health_check(self) -> None: + """Start periodic health check""" + if self.health_check_task: + return + + self.health_check_task = asyncio.create_task(self._health_check_loop()) + logger.debug(f"Health check started for '{self.name}'") + + async def _stop_health_check(self) -> None: + """Stop health check""" + if self.health_check_task: + self.health_check_task.cancel() + try: + await self.health_check_task + except asyncio.CancelledError: + pass + self.health_check_task = None + logger.debug(f"Health check stopped for '{self.name}'") + + async def _health_check_loop(self) -> None: + """Health check loop""" + while self.is_connected: + try: + await asyncio.sleep(self.health_check_interval) + + if self.on_health_check: + is_healthy = await self.on_health_check() + if not is_healthy: + logger.warning(f"Health check failed for '{self.name}'") + self.is_connected = False + break + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Health check error for '{self.name}': {e}") + self.is_connected = False + break + + def get_stats(self) -> dict: + """Get connection statistics""" + return { + 'name': self.name, + 'is_connected': self.is_connected, + 'connection_attempts': self.connection_attempts, + 'max_retries': self.max_retries, + 'current_delay': self.backoff.current_delay, + 'backoff_attempts': self.backoff.attempt_count, + 'last_error': str(self.last_error) if self.last_error else None, + 'health_check_active': self.health_check_task is not None + } \ No newline at end of file diff --git a/COBY/requirements.txt b/COBY/requirements.txt new file mode 100644 index 0000000..d55d855 --- /dev/null +++ b/COBY/requirements.txt @@ -0,0 +1,34 @@ +# Core dependencies for COBY system +asyncpg>=0.29.0 # PostgreSQL/TimescaleDB async driver +redis>=5.0.0 # Redis client +websockets>=12.0 # WebSocket client library +aiohttp>=3.9.0 # Async HTTP client/server +fastapi>=0.104.0 # API framework +uvicorn>=0.24.0 # ASGI server +pydantic>=2.5.0 # Data validation +python-multipart>=0.0.6 # Form data parsing + +# Data processing +pandas>=2.1.0 # Data manipulation +numpy>=1.24.0 # Numerical computing +scipy>=1.11.0 # Scientific computing + +# Utilities +python-dotenv>=1.0.0 # Environment variable loading +structlog>=23.2.0 # Structured logging +click>=8.1.0 # CLI framework +rich>=13.7.0 # Rich text and beautiful formatting + +# Development dependencies +pytest>=7.4.0 # Testing framework +pytest-asyncio>=0.21.0 # Async testing +pytest-cov>=4.1.0 # Coverage reporting +black>=23.11.0 # Code formatting +isort>=5.12.0 # Import sorting +flake8>=6.1.0 # Linting +mypy>=1.7.0 # Type checking + +# Optional dependencies for enhanced features +prometheus-client>=0.19.0 # Metrics collection +grafana-api>=1.0.3 # Grafana integration +psutil>=5.9.0 # System monitoring \ No newline at end of file diff --git a/COBY/storage/__init__.py b/COBY/storage/__init__.py new file mode 100644 index 0000000..9d35ca4 --- /dev/null +++ b/COBY/storage/__init__.py @@ -0,0 +1,11 @@ +""" +Storage layer for the COBY system. 
+""" + +from .timescale_manager import TimescaleManager +from .connection_pool import DatabaseConnectionPool + +__all__ = [ + 'TimescaleManager', + 'DatabaseConnectionPool' +] \ No newline at end of file diff --git a/COBY/storage/connection_pool.py b/COBY/storage/connection_pool.py new file mode 100644 index 0000000..407785a --- /dev/null +++ b/COBY/storage/connection_pool.py @@ -0,0 +1,140 @@ +""" +Database connection pool management for TimescaleDB. +""" + +import asyncio +import asyncpg +from typing import Optional, Dict, Any +from contextlib import asynccontextmanager +from ..config import config +from ..utils.logging import get_logger +from ..utils.exceptions import StorageError + +logger = get_logger(__name__) + + +class DatabaseConnectionPool: + """Manages database connection pool for TimescaleDB""" + + def __init__(self): + self._pool: Optional[asyncpg.Pool] = None + self._is_initialized = False + + async def initialize(self) -> None: + """Initialize the connection pool""" + if self._is_initialized: + return + + try: + # Build connection string + dsn = ( + f"postgresql://{config.database.user}:{config.database.password}" + f"@{config.database.host}:{config.database.port}/{config.database.name}" + ) + + # Create connection pool + self._pool = await asyncpg.create_pool( + dsn, + min_size=5, + max_size=config.database.pool_size, + max_queries=50000, + max_inactive_connection_lifetime=300, + command_timeout=config.database.pool_timeout, + server_settings={ + 'search_path': config.database.schema, + 'timezone': 'UTC' + } + ) + + self._is_initialized = True + logger.info(f"Database connection pool initialized with {config.database.pool_size} connections") + + # Test connection + await self.health_check() + + except Exception as e: + logger.error(f"Failed to initialize database connection pool: {e}") + raise StorageError(f"Database connection failed: {e}", "DB_INIT_ERROR") + + async def close(self) -> None: + """Close the connection pool""" + if self._pool: + await self._pool.close() + self._pool = None + self._is_initialized = False + logger.info("Database connection pool closed") + + @asynccontextmanager + async def get_connection(self): + """Get a database connection from the pool""" + if not self._is_initialized: + await self.initialize() + + if not self._pool: + raise StorageError("Connection pool not initialized", "POOL_NOT_READY") + + async with self._pool.acquire() as connection: + try: + yield connection + except Exception as e: + logger.error(f"Database operation failed: {e}") + raise + + @asynccontextmanager + async def get_transaction(self): + """Get a database transaction""" + async with self.get_connection() as conn: + async with conn.transaction(): + yield conn + + async def execute_query(self, query: str, *args) -> Any: + """Execute a query and return results""" + async with self.get_connection() as conn: + return await conn.fetch(query, *args) + + async def execute_command(self, command: str, *args) -> str: + """Execute a command and return status""" + async with self.get_connection() as conn: + return await conn.execute(command, *args) + + async def execute_many(self, command: str, args_list) -> None: + """Execute a command multiple times with different arguments""" + async with self.get_connection() as conn: + await conn.executemany(command, args_list) + + async def health_check(self) -> bool: + """Check database health""" + try: + async with self.get_connection() as conn: + result = await conn.fetchval("SELECT 1") + if result == 1: + logger.debug("Database health check 
passed") + return True + else: + logger.warning("Database health check returned unexpected result") + return False + except Exception as e: + logger.error(f"Database health check failed: {e}") + return False + + async def get_pool_stats(self) -> Dict[str, Any]: + """Get connection pool statistics""" + if not self._pool: + return {} + + return { + 'size': self._pool.get_size(), + 'min_size': self._pool.get_min_size(), + 'max_size': self._pool.get_max_size(), + 'idle_size': self._pool.get_idle_size(), + 'is_closing': self._pool.is_closing() + } + + @property + def is_initialized(self) -> bool: + """Check if pool is initialized""" + return self._is_initialized + + +# Global connection pool instance +db_pool = DatabaseConnectionPool() \ No newline at end of file diff --git a/COBY/storage/migrations.py b/COBY/storage/migrations.py new file mode 100644 index 0000000..da83898 --- /dev/null +++ b/COBY/storage/migrations.py @@ -0,0 +1,271 @@ +""" +Database migration system for schema updates. +""" + +from typing import List, Dict, Any +from datetime import datetime +from ..utils.logging import get_logger +from ..utils.exceptions import StorageError +from .connection_pool import db_pool + +logger = get_logger(__name__) + + +class Migration: + """Base class for database migrations""" + + def __init__(self, version: str, description: str): + self.version = version + self.description = description + + async def up(self) -> None: + """Apply the migration""" + raise NotImplementedError + + async def down(self) -> None: + """Rollback the migration""" + raise NotImplementedError + + +class MigrationManager: + """Manages database schema migrations""" + + def __init__(self): + self.migrations: List[Migration] = [] + + def register_migration(self, migration: Migration) -> None: + """Register a migration""" + self.migrations.append(migration) + # Sort by version + self.migrations.sort(key=lambda m: m.version) + + async def initialize_migration_table(self) -> None: + """Create migration tracking table""" + query = """ + CREATE TABLE IF NOT EXISTS market_data.schema_migrations ( + version VARCHAR(50) PRIMARY KEY, + description TEXT NOT NULL, + applied_at TIMESTAMPTZ DEFAULT NOW() + ); + """ + + await db_pool.execute_command(query) + logger.debug("Migration table initialized") + + async def get_applied_migrations(self) -> List[str]: + """Get list of applied migration versions""" + try: + query = "SELECT version FROM market_data.schema_migrations ORDER BY version" + rows = await db_pool.execute_query(query) + return [row['version'] for row in rows] + except Exception: + # Table might not exist yet + return [] + + async def apply_migration(self, migration: Migration) -> bool: + """Apply a single migration""" + try: + logger.info(f"Applying migration {migration.version}: {migration.description}") + + async with db_pool.get_transaction() as conn: + # Apply the migration + await migration.up() + + # Record the migration + await conn.execute( + "INSERT INTO market_data.schema_migrations (version, description) VALUES ($1, $2)", + migration.version, + migration.description + ) + + logger.info(f"Migration {migration.version} applied successfully") + return True + + except Exception as e: + logger.error(f"Failed to apply migration {migration.version}: {e}") + return False + + async def rollback_migration(self, migration: Migration) -> bool: + """Rollback a single migration""" + try: + logger.info(f"Rolling back migration {migration.version}: {migration.description}") + + async with db_pool.get_transaction() as conn: + # 
Rollback the migration + await migration.down() + + # Remove the migration record + await conn.execute( + "DELETE FROM market_data.schema_migrations WHERE version = $1", + migration.version + ) + + logger.info(f"Migration {migration.version} rolled back successfully") + return True + + except Exception as e: + logger.error(f"Failed to rollback migration {migration.version}: {e}") + return False + + async def migrate_up(self, target_version: str = None) -> bool: + """Apply all pending migrations up to target version""" + try: + await self.initialize_migration_table() + applied_migrations = await self.get_applied_migrations() + + pending_migrations = [ + m for m in self.migrations + if m.version not in applied_migrations + ] + + if target_version: + pending_migrations = [ + m for m in pending_migrations + if m.version <= target_version + ] + + if not pending_migrations: + logger.info("No pending migrations to apply") + return True + + logger.info(f"Applying {len(pending_migrations)} pending migrations") + + for migration in pending_migrations: + if not await self.apply_migration(migration): + return False + + logger.info("All migrations applied successfully") + return True + + except Exception as e: + logger.error(f"Migration failed: {e}") + return False + + async def migrate_down(self, target_version: str) -> bool: + """Rollback migrations down to target version""" + try: + applied_migrations = await self.get_applied_migrations() + + migrations_to_rollback = [ + m for m in reversed(self.migrations) + if m.version in applied_migrations and m.version > target_version + ] + + if not migrations_to_rollback: + logger.info("No migrations to rollback") + return True + + logger.info(f"Rolling back {len(migrations_to_rollback)} migrations") + + for migration in migrations_to_rollback: + if not await self.rollback_migration(migration): + return False + + logger.info("All migrations rolled back successfully") + return True + + except Exception as e: + logger.error(f"Migration rollback failed: {e}") + return False + + async def get_migration_status(self) -> Dict[str, Any]: + """Get current migration status""" + try: + applied_migrations = await self.get_applied_migrations() + + status = { + 'total_migrations': len(self.migrations), + 'applied_migrations': len(applied_migrations), + 'pending_migrations': len(self.migrations) - len(applied_migrations), + 'current_version': applied_migrations[-1] if applied_migrations else None, + 'latest_version': self.migrations[-1].version if self.migrations else None, + 'migrations': [] + } + + for migration in self.migrations: + status['migrations'].append({ + 'version': migration.version, + 'description': migration.description, + 'applied': migration.version in applied_migrations + }) + + return status + + except Exception as e: + logger.error(f"Failed to get migration status: {e}") + return {} + + +# Example migrations +class InitialSchemaMigration(Migration): + """Initial schema creation migration""" + + def __init__(self): + super().__init__("001", "Create initial schema and tables") + + async def up(self) -> None: + """Create initial schema""" + from .schema import DatabaseSchema + + queries = DatabaseSchema.get_all_creation_queries() + for query in queries: + await db_pool.execute_command(query) + + async def down(self) -> None: + """Drop initial schema""" + # Drop tables in reverse order + tables = [ + 'system_metrics', + 'exchange_status', + 'ohlcv_data', + 'heatmap_data', + 'trade_events', + 'order_book_snapshots' + ] + + for table in tables: + await 
db_pool.execute_command(f"DROP TABLE IF EXISTS market_data.{table} CASCADE") + + +class AddIndexesMigration(Migration): + """Add performance indexes migration""" + + def __init__(self): + super().__init__("002", "Add performance indexes") + + async def up(self) -> None: + """Add indexes""" + from .schema import DatabaseSchema + + queries = DatabaseSchema.get_index_creation_queries() + for query in queries: + await db_pool.execute_command(query) + + async def down(self) -> None: + """Drop indexes""" + indexes = [ + 'idx_order_book_symbol_exchange', + 'idx_order_book_timestamp', + 'idx_trade_events_symbol_exchange', + 'idx_trade_events_timestamp', + 'idx_trade_events_price', + 'idx_heatmap_symbol_bucket', + 'idx_heatmap_timestamp', + 'idx_ohlcv_symbol_timeframe', + 'idx_ohlcv_timestamp', + 'idx_exchange_status_exchange', + 'idx_exchange_status_timestamp', + 'idx_system_metrics_name', + 'idx_system_metrics_timestamp' + ] + + for index in indexes: + await db_pool.execute_command(f"DROP INDEX IF EXISTS market_data.{index}") + + +# Global migration manager +migration_manager = MigrationManager() + +# Register default migrations +migration_manager.register_migration(InitialSchemaMigration()) +migration_manager.register_migration(AddIndexesMigration()) \ No newline at end of file diff --git a/COBY/storage/schema.py b/COBY/storage/schema.py new file mode 100644 index 0000000..f4e3413 --- /dev/null +++ b/COBY/storage/schema.py @@ -0,0 +1,256 @@ +""" +Database schema management for TimescaleDB. +""" + +from typing import List +from ..utils.logging import get_logger + +logger = get_logger(__name__) + + +class DatabaseSchema: + """Manages database schema creation and migrations""" + + @staticmethod + def get_schema_creation_queries() -> List[str]: + """Get list of queries to create the database schema""" + return [ + # Create TimescaleDB extension + "CREATE EXTENSION IF NOT EXISTS timescaledb;", + + # Create schema + "CREATE SCHEMA IF NOT EXISTS market_data;", + + # Order book snapshots table + """ + CREATE TABLE IF NOT EXISTS market_data.order_book_snapshots ( + id BIGSERIAL, + symbol VARCHAR(20) NOT NULL, + exchange VARCHAR(20) NOT NULL, + timestamp TIMESTAMPTZ NOT NULL, + bids JSONB NOT NULL, + asks JSONB NOT NULL, + sequence_id BIGINT, + mid_price DECIMAL(20,8), + spread DECIMAL(20,8), + bid_volume DECIMAL(30,8), + ask_volume DECIMAL(30,8), + created_at TIMESTAMPTZ DEFAULT NOW(), + PRIMARY KEY (timestamp, symbol, exchange) + ); + """, + + # Trade events table + """ + CREATE TABLE IF NOT EXISTS market_data.trade_events ( + id BIGSERIAL, + symbol VARCHAR(20) NOT NULL, + exchange VARCHAR(20) NOT NULL, + timestamp TIMESTAMPTZ NOT NULL, + price DECIMAL(20,8) NOT NULL, + size DECIMAL(30,8) NOT NULL, + side VARCHAR(4) NOT NULL, + trade_id VARCHAR(100) NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW(), + PRIMARY KEY (timestamp, symbol, exchange, trade_id) + ); + """, + + # Aggregated heatmap data table + """ + CREATE TABLE IF NOT EXISTS market_data.heatmap_data ( + symbol VARCHAR(20) NOT NULL, + timestamp TIMESTAMPTZ NOT NULL, + bucket_size DECIMAL(10,2) NOT NULL, + price_bucket DECIMAL(20,8) NOT NULL, + volume DECIMAL(30,8) NOT NULL, + side VARCHAR(3) NOT NULL, + exchange_count INTEGER NOT NULL, + exchanges JSONB, + created_at TIMESTAMPTZ DEFAULT NOW(), + PRIMARY KEY (timestamp, symbol, bucket_size, price_bucket, side) + ); + """, + + # OHLCV data table + """ + CREATE TABLE IF NOT EXISTS market_data.ohlcv_data ( + symbol VARCHAR(20) NOT NULL, + timestamp TIMESTAMPTZ NOT NULL, + timeframe VARCHAR(10) NOT 
NULL, + open_price DECIMAL(20,8) NOT NULL, + high_price DECIMAL(20,8) NOT NULL, + low_price DECIMAL(20,8) NOT NULL, + close_price DECIMAL(20,8) NOT NULL, + volume DECIMAL(30,8) NOT NULL, + trade_count INTEGER, + vwap DECIMAL(20,8), + created_at TIMESTAMPTZ DEFAULT NOW(), + PRIMARY KEY (timestamp, symbol, timeframe) + ); + """, + + # Exchange status tracking table + """ + CREATE TABLE IF NOT EXISTS market_data.exchange_status ( + exchange VARCHAR(20) NOT NULL, + timestamp TIMESTAMPTZ NOT NULL, + status VARCHAR(20) NOT NULL, + last_message_time TIMESTAMPTZ, + error_message TEXT, + connection_count INTEGER DEFAULT 0, + created_at TIMESTAMPTZ DEFAULT NOW(), + PRIMARY KEY (timestamp, exchange) + ); + """, + + # System metrics table + """ + CREATE TABLE IF NOT EXISTS market_data.system_metrics ( + metric_name VARCHAR(50) NOT NULL, + timestamp TIMESTAMPTZ NOT NULL, + value DECIMAL(20,8) NOT NULL, + labels JSONB, + created_at TIMESTAMPTZ DEFAULT NOW(), + PRIMARY KEY (timestamp, metric_name) + ); + """ + ] + + @staticmethod + def get_hypertable_creation_queries() -> List[str]: + """Get queries to create hypertables""" + return [ + "SELECT create_hypertable('market_data.order_book_snapshots', 'timestamp', if_not_exists => TRUE);", + "SELECT create_hypertable('market_data.trade_events', 'timestamp', if_not_exists => TRUE);", + "SELECT create_hypertable('market_data.heatmap_data', 'timestamp', if_not_exists => TRUE);", + "SELECT create_hypertable('market_data.ohlcv_data', 'timestamp', if_not_exists => TRUE);", + "SELECT create_hypertable('market_data.exchange_status', 'timestamp', if_not_exists => TRUE);", + "SELECT create_hypertable('market_data.system_metrics', 'timestamp', if_not_exists => TRUE);" + ] + + @staticmethod + def get_index_creation_queries() -> List[str]: + """Get queries to create indexes""" + return [ + # Order book indexes + "CREATE INDEX IF NOT EXISTS idx_order_book_symbol_exchange ON market_data.order_book_snapshots (symbol, exchange, timestamp DESC);", + "CREATE INDEX IF NOT EXISTS idx_order_book_timestamp ON market_data.order_book_snapshots (timestamp DESC);", + + # Trade events indexes + "CREATE INDEX IF NOT EXISTS idx_trade_events_symbol_exchange ON market_data.trade_events (symbol, exchange, timestamp DESC);", + "CREATE INDEX IF NOT EXISTS idx_trade_events_timestamp ON market_data.trade_events (timestamp DESC);", + "CREATE INDEX IF NOT EXISTS idx_trade_events_price ON market_data.trade_events (symbol, price, timestamp DESC);", + + # Heatmap data indexes + "CREATE INDEX IF NOT EXISTS idx_heatmap_symbol_bucket ON market_data.heatmap_data (symbol, bucket_size, timestamp DESC);", + "CREATE INDEX IF NOT EXISTS idx_heatmap_timestamp ON market_data.heatmap_data (timestamp DESC);", + + # OHLCV data indexes + "CREATE INDEX IF NOT EXISTS idx_ohlcv_symbol_timeframe ON market_data.ohlcv_data (symbol, timeframe, timestamp DESC);", + "CREATE INDEX IF NOT EXISTS idx_ohlcv_timestamp ON market_data.ohlcv_data (timestamp DESC);", + + # Exchange status indexes + "CREATE INDEX IF NOT EXISTS idx_exchange_status_exchange ON market_data.exchange_status (exchange, timestamp DESC);", + "CREATE INDEX IF NOT EXISTS idx_exchange_status_timestamp ON market_data.exchange_status (timestamp DESC);", + + # System metrics indexes + "CREATE INDEX IF NOT EXISTS idx_system_metrics_name ON market_data.system_metrics (metric_name, timestamp DESC);", + "CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON market_data.system_metrics (timestamp DESC);" + ] + + @staticmethod + def get_retention_policy_queries() 
-> List[str]: + """Get queries to create retention policies""" + return [ + "SELECT add_retention_policy('market_data.order_book_snapshots', INTERVAL '90 days', if_not_exists => TRUE);", + "SELECT add_retention_policy('market_data.trade_events', INTERVAL '90 days', if_not_exists => TRUE);", + "SELECT add_retention_policy('market_data.heatmap_data', INTERVAL '90 days', if_not_exists => TRUE);", + "SELECT add_retention_policy('market_data.ohlcv_data', INTERVAL '365 days', if_not_exists => TRUE);", + "SELECT add_retention_policy('market_data.exchange_status', INTERVAL '30 days', if_not_exists => TRUE);", + "SELECT add_retention_policy('market_data.system_metrics', INTERVAL '30 days', if_not_exists => TRUE);" + ] + + @staticmethod + def get_continuous_aggregate_queries() -> List[str]: + """Get queries to create continuous aggregates""" + return [ + # Hourly OHLCV aggregate + """ + CREATE MATERIALIZED VIEW IF NOT EXISTS market_data.hourly_ohlcv + WITH (timescaledb.continuous) AS + SELECT + symbol, + exchange, + time_bucket('1 hour', timestamp) AS hour, + first(price, timestamp) AS open_price, + max(price) AS high_price, + min(price) AS low_price, + last(price, timestamp) AS close_price, + sum(size) AS volume, + count(*) AS trade_count, + avg(price) AS vwap + FROM market_data.trade_events + GROUP BY symbol, exchange, hour + WITH NO DATA; + """, + + # Add refresh policy for continuous aggregate + """ + SELECT add_continuous_aggregate_policy('market_data.hourly_ohlcv', + start_offset => INTERVAL '3 hours', + end_offset => INTERVAL '1 hour', + schedule_interval => INTERVAL '1 hour', + if_not_exists => TRUE); + """ + ] + + @staticmethod + def get_view_creation_queries() -> List[str]: + """Get queries to create views""" + return [ + # Latest order books view + """ + CREATE OR REPLACE VIEW market_data.latest_order_books AS + SELECT DISTINCT ON (symbol, exchange) + symbol, + exchange, + timestamp, + bids, + asks, + mid_price, + spread, + bid_volume, + ask_volume + FROM market_data.order_book_snapshots + ORDER BY symbol, exchange, timestamp DESC; + """, + + # Latest heatmaps view + """ + CREATE OR REPLACE VIEW market_data.latest_heatmaps AS + SELECT DISTINCT ON (symbol, bucket_size, price_bucket, side) + symbol, + bucket_size, + price_bucket, + side, + timestamp, + volume, + exchange_count, + exchanges + FROM market_data.heatmap_data + ORDER BY symbol, bucket_size, price_bucket, side, timestamp DESC; + """ + ] + + @staticmethod + def get_all_creation_queries() -> List[str]: + """Get all schema creation queries in order""" + queries = [] + queries.extend(DatabaseSchema.get_schema_creation_queries()) + queries.extend(DatabaseSchema.get_hypertable_creation_queries()) + queries.extend(DatabaseSchema.get_index_creation_queries()) + queries.extend(DatabaseSchema.get_retention_policy_queries()) + queries.extend(DatabaseSchema.get_continuous_aggregate_queries()) + queries.extend(DatabaseSchema.get_view_creation_queries()) + return queries \ No newline at end of file diff --git a/COBY/storage/timescale_manager.py b/COBY/storage/timescale_manager.py new file mode 100644 index 0000000..9098e2e --- /dev/null +++ b/COBY/storage/timescale_manager.py @@ -0,0 +1,604 @@ +""" +TimescaleDB storage manager implementation. 
+""" + +import json +from datetime import datetime +from typing import List, Dict, Optional, Any +from ..interfaces.storage_manager import StorageManager +from ..models.core import OrderBookSnapshot, TradeEvent, HeatmapData, SystemMetrics, PriceLevel +from ..utils.logging import get_logger, set_correlation_id +from ..utils.exceptions import StorageError, ValidationError +from ..utils.timing import get_current_timestamp +from .connection_pool import db_pool +from .schema import DatabaseSchema + +logger = get_logger(__name__) + + +class TimescaleManager(StorageManager): + """TimescaleDB implementation of StorageManager interface""" + + def __init__(self): + self._schema_initialized = False + + async def initialize(self) -> None: + """Initialize the storage manager""" + await db_pool.initialize() + await self.setup_database_schema() + logger.info("TimescaleDB storage manager initialized") + + async def close(self) -> None: + """Close the storage manager""" + await db_pool.close() + logger.info("TimescaleDB storage manager closed") + + def setup_database_schema(self) -> None: + """Set up database schema and tables""" + async def _setup(): + if self._schema_initialized: + return + + try: + queries = DatabaseSchema.get_all_creation_queries() + + for query in queries: + try: + await db_pool.execute_command(query) + logger.debug(f"Executed schema query: {query[:50]}...") + except Exception as e: + # Log but continue - some queries might fail if already exists + logger.warning(f"Schema query failed (continuing): {e}") + + self._schema_initialized = True + logger.info("Database schema setup completed") + + except Exception as e: + logger.error(f"Failed to setup database schema: {e}") + raise StorageError(f"Schema setup failed: {e}", "SCHEMA_SETUP_ERROR") + + # Run async setup + import asyncio + if asyncio.get_event_loop().is_running(): + asyncio.create_task(_setup()) + else: + asyncio.run(_setup()) + + async def store_orderbook(self, data: OrderBookSnapshot) -> bool: + """Store order book snapshot to database""" + try: + set_correlation_id() + + # Convert price levels to JSON + bids_json = json.dumps([ + {"price": float(level.price), "size": float(level.size), "count": level.count} + for level in data.bids + ]) + asks_json = json.dumps([ + {"price": float(level.price), "size": float(level.size), "count": level.count} + for level in data.asks + ]) + + query = """ + INSERT INTO market_data.order_book_snapshots + (symbol, exchange, timestamp, bids, asks, sequence_id, mid_price, spread, bid_volume, ask_volume) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + """ + + await db_pool.execute_command( + query, + data.symbol, + data.exchange, + data.timestamp, + bids_json, + asks_json, + data.sequence_id, + float(data.mid_price) if data.mid_price else None, + float(data.spread) if data.spread else None, + float(data.bid_volume), + float(data.ask_volume) + ) + + logger.debug(f"Stored order book: {data.symbol}@{data.exchange}") + return True + + except Exception as e: + logger.error(f"Failed to store order book: {e}") + return False + + async def store_trade(self, data: TradeEvent) -> bool: + """Store trade event to database""" + try: + set_correlation_id() + + query = """ + INSERT INTO market_data.trade_events + (symbol, exchange, timestamp, price, size, side, trade_id) + VALUES ($1, $2, $3, $4, $5, $6, $7) + """ + + await db_pool.execute_command( + query, + data.symbol, + data.exchange, + data.timestamp, + float(data.price), + float(data.size), + data.side, + data.trade_id + ) + + logger.debug(f"Stored 
trade: {data.symbol}@{data.exchange} - {data.trade_id}") + return True + + except Exception as e: + logger.error(f"Failed to store trade: {e}") + return False + + async def store_heatmap(self, data: HeatmapData) -> bool: + """Store heatmap data to database""" + try: + set_correlation_id() + + # Store each heatmap point + for point in data.data: + query = """ + INSERT INTO market_data.heatmap_data + (symbol, timestamp, bucket_size, price_bucket, volume, side, exchange_count, exchanges) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (timestamp, symbol, bucket_size, price_bucket, side) + DO UPDATE SET + volume = EXCLUDED.volume, + exchange_count = EXCLUDED.exchange_count, + exchanges = EXCLUDED.exchanges + """ + + await db_pool.execute_command( + query, + data.symbol, + data.timestamp, + float(data.bucket_size), + float(point.price), + float(point.volume), + point.side, + 1, # exchange_count - will be updated by aggregation + json.dumps([]) # exchanges - will be updated by aggregation + ) + + logger.debug(f"Stored heatmap: {data.symbol} with {len(data.data)} points") + return True + + except Exception as e: + logger.error(f"Failed to store heatmap: {e}") + return False + + async def store_metrics(self, data: SystemMetrics) -> bool: + """Store system metrics to database""" + try: + set_correlation_id() + + # Store multiple metrics + metrics = [ + ('cpu_usage', data.cpu_usage), + ('memory_usage', data.memory_usage), + ('disk_usage', data.disk_usage), + ('database_connections', data.database_connections), + ('redis_connections', data.redis_connections), + ('active_websockets', data.active_websockets), + ('messages_per_second', data.messages_per_second), + ('processing_latency', data.processing_latency) + ] + + query = """ + INSERT INTO market_data.system_metrics + (metric_name, timestamp, value, labels) + VALUES ($1, $2, $3, $4) + """ + + for metric_name, value in metrics: + await db_pool.execute_command( + query, + metric_name, + data.timestamp, + float(value), + json.dumps(data.network_io) + ) + + logger.debug("Stored system metrics") + return True + + except Exception as e: + logger.error(f"Failed to store metrics: {e}") + return False + + async def get_historical_orderbooks(self, symbol: str, exchange: str, + start: datetime, end: datetime, + limit: Optional[int] = None) -> List[OrderBookSnapshot]: + """Retrieve historical order book data""" + try: + query = """ + SELECT symbol, exchange, timestamp, bids, asks, sequence_id, mid_price, spread + FROM market_data.order_book_snapshots + WHERE symbol = $1 AND exchange = $2 AND timestamp >= $3 AND timestamp <= $4 + ORDER BY timestamp DESC + """ + + if limit: + query += f" LIMIT {limit}" + + rows = await db_pool.execute_query(query, symbol, exchange, start, end) + + orderbooks = [] + for row in rows: + # Parse JSON bid/ask data + bids_data = json.loads(row['bids']) + asks_data = json.loads(row['asks']) + + bids = [PriceLevel(price=b['price'], size=b['size'], count=b.get('count')) + for b in bids_data] + asks = [PriceLevel(price=a['price'], size=a['size'], count=a.get('count')) + for a in asks_data] + + orderbook = OrderBookSnapshot( + symbol=row['symbol'], + exchange=row['exchange'], + timestamp=row['timestamp'], + bids=bids, + asks=asks, + sequence_id=row['sequence_id'] + ) + orderbooks.append(orderbook) + + logger.debug(f"Retrieved {len(orderbooks)} historical order books") + return orderbooks + + except Exception as e: + logger.error(f"Failed to get historical order books: {e}") + return [] + + async def get_historical_trades(self, 
symbol: str, exchange: str, + start: datetime, end: datetime, + limit: Optional[int] = None) -> List[TradeEvent]: + """Retrieve historical trade data""" + try: + query = """ + SELECT symbol, exchange, timestamp, price, size, side, trade_id + FROM market_data.trade_events + WHERE symbol = $1 AND exchange = $2 AND timestamp >= $3 AND timestamp <= $4 + ORDER BY timestamp DESC + """ + + if limit: + query += f" LIMIT {limit}" + + rows = await db_pool.execute_query(query, symbol, exchange, start, end) + + trades = [] + for row in rows: + trade = TradeEvent( + symbol=row['symbol'], + exchange=row['exchange'], + timestamp=row['timestamp'], + price=float(row['price']), + size=float(row['size']), + side=row['side'], + trade_id=row['trade_id'] + ) + trades.append(trade) + + logger.debug(f"Retrieved {len(trades)} historical trades") + return trades + + except Exception as e: + logger.error(f"Failed to get historical trades: {e}") + return [] + + async def get_latest_orderbook(self, symbol: str, exchange: str) -> Optional[OrderBookSnapshot]: + """Get latest order book snapshot""" + try: + query = """ + SELECT symbol, exchange, timestamp, bids, asks, sequence_id + FROM market_data.order_book_snapshots + WHERE symbol = $1 AND exchange = $2 + ORDER BY timestamp DESC + LIMIT 1 + """ + + rows = await db_pool.execute_query(query, symbol, exchange) + + if not rows: + return None + + row = rows[0] + bids_data = json.loads(row['bids']) + asks_data = json.loads(row['asks']) + + bids = [PriceLevel(price=b['price'], size=b['size'], count=b.get('count')) + for b in bids_data] + asks = [PriceLevel(price=a['price'], size=a['size'], count=a.get('count')) + for a in asks_data] + + return OrderBookSnapshot( + symbol=row['symbol'], + exchange=row['exchange'], + timestamp=row['timestamp'], + bids=bids, + asks=asks, + sequence_id=row['sequence_id'] + ) + + except Exception as e: + logger.error(f"Failed to get latest order book: {e}") + return None + + async def get_latest_heatmap(self, symbol: str, bucket_size: float) -> Optional[HeatmapData]: + """Get latest heatmap data""" + try: + query = """ + SELECT price_bucket, volume, side, timestamp + FROM market_data.heatmap_data + WHERE symbol = $1 AND bucket_size = $2 + AND timestamp = ( + SELECT MAX(timestamp) + FROM market_data.heatmap_data + WHERE symbol = $1 AND bucket_size = $2 + ) + ORDER BY price_bucket + """ + + rows = await db_pool.execute_query(query, symbol, bucket_size) + + if not rows: + return None + + from ..models.core import HeatmapPoint + heatmap = HeatmapData( + symbol=symbol, + timestamp=rows[0]['timestamp'], + bucket_size=bucket_size + ) + + # Calculate max volume for intensity + max_volume = max(float(row['volume']) for row in rows) + + for row in rows: + volume = float(row['volume']) + intensity = volume / max_volume if max_volume > 0 else 0.0 + + point = HeatmapPoint( + price=float(row['price_bucket']), + volume=volume, + intensity=intensity, + side=row['side'] + ) + heatmap.data.append(point) + + return heatmap + + except Exception as e: + logger.error(f"Failed to get latest heatmap: {e}") + return None + + async def get_ohlcv_data(self, symbol: str, exchange: str, timeframe: str, + start: datetime, end: datetime) -> List[Dict[str, Any]]: + """Get OHLCV candlestick data""" + try: + query = """ + SELECT timestamp, open_price, high_price, low_price, close_price, volume, trade_count, vwap + FROM market_data.ohlcv_data + WHERE symbol = $1 AND exchange = $2 AND timeframe = $3 + AND timestamp >= $4 AND timestamp <= $5 + ORDER BY timestamp + """ + + rows = 
await db_pool.execute_query(query, symbol, exchange, timeframe, start, end) + + ohlcv_data = [] + for row in rows: + ohlcv_data.append({ + 'timestamp': row['timestamp'], + 'open': float(row['open_price']), + 'high': float(row['high_price']), + 'low': float(row['low_price']), + 'close': float(row['close_price']), + 'volume': float(row['volume']), + 'trade_count': row['trade_count'], + 'vwap': float(row['vwap']) if row['vwap'] else None + }) + + logger.debug(f"Retrieved {len(ohlcv_data)} OHLCV records") + return ohlcv_data + + except Exception as e: + logger.error(f"Failed to get OHLCV data: {e}") + return [] + + async def batch_store_orderbooks(self, data: List[OrderBookSnapshot]) -> int: + """Store multiple order book snapshots in batch""" + if not data: + return 0 + + try: + set_correlation_id() + + # Prepare batch data + batch_data = [] + for orderbook in data: + bids_json = json.dumps([ + {"price": float(level.price), "size": float(level.size), "count": level.count} + for level in orderbook.bids + ]) + asks_json = json.dumps([ + {"price": float(level.price), "size": float(level.size), "count": level.count} + for level in orderbook.asks + ]) + + batch_data.append(( + orderbook.symbol, + orderbook.exchange, + orderbook.timestamp, + bids_json, + asks_json, + orderbook.sequence_id, + float(orderbook.mid_price) if orderbook.mid_price else None, + float(orderbook.spread) if orderbook.spread else None, + float(orderbook.bid_volume), + float(orderbook.ask_volume) + )) + + query = """ + INSERT INTO market_data.order_book_snapshots + (symbol, exchange, timestamp, bids, asks, sequence_id, mid_price, spread, bid_volume, ask_volume) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + """ + + await db_pool.execute_many(query, batch_data) + + logger.debug(f"Batch stored {len(data)} order books") + return len(data) + + except Exception as e: + logger.error(f"Failed to batch store order books: {e}") + return 0 + + async def batch_store_trades(self, data: List[TradeEvent]) -> int: + """Store multiple trade events in batch""" + if not data: + return 0 + + try: + set_correlation_id() + + # Prepare batch data + batch_data = [ + (trade.symbol, trade.exchange, trade.timestamp, float(trade.price), + float(trade.size), trade.side, trade.trade_id) + for trade in data + ] + + query = """ + INSERT INTO market_data.trade_events + (symbol, exchange, timestamp, price, size, side, trade_id) + VALUES ($1, $2, $3, $4, $5, $6, $7) + """ + + await db_pool.execute_many(query, batch_data) + + logger.debug(f"Batch stored {len(data)} trades") + return len(data) + + except Exception as e: + logger.error(f"Failed to batch store trades: {e}") + return 0 + + async def cleanup_old_data(self, retention_days: int) -> int: + """Clean up old data based on retention policy""" + try: + cutoff_time = get_current_timestamp().replace( + day=get_current_timestamp().day - retention_days + ) + + tables = [ + 'order_book_snapshots', + 'trade_events', + 'heatmap_data', + 'exchange_status', + 'system_metrics' + ] + + total_deleted = 0 + for table in tables: + query = f""" + DELETE FROM market_data.{table} + WHERE timestamp < $1 + """ + + result = await db_pool.execute_command(query, cutoff_time) + # Extract number from result like "DELETE 1234" + deleted = int(result.split()[-1]) if result.split()[-1].isdigit() else 0 + total_deleted += deleted + + logger.debug(f"Cleaned up {deleted} records from {table}") + + logger.info(f"Cleaned up {total_deleted} total records older than {retention_days} days") + return total_deleted + + except Exception as 
e: + logger.error(f"Failed to cleanup old data: {e}") + return 0 + + async def get_storage_stats(self) -> Dict[str, Any]: + """Get storage statistics""" + try: + stats = {} + + # Table sizes + size_query = """ + SELECT + schemaname, + tablename, + pg_size_pretty(pg_total_relation_size(schemaname||'.'||tablename)) as size, + pg_total_relation_size(schemaname||'.'||tablename) as size_bytes + FROM pg_tables + WHERE schemaname = 'market_data' + ORDER BY size_bytes DESC + """ + + size_rows = await db_pool.execute_query(size_query) + stats['table_sizes'] = [ + { + 'table': row['tablename'], + 'size': row['size'], + 'size_bytes': row['size_bytes'] + } + for row in size_rows + ] + + # Record counts + tables = ['order_book_snapshots', 'trade_events', 'heatmap_data', + 'ohlcv_data', 'exchange_status', 'system_metrics'] + + record_counts = {} + for table in tables: + count_query = f"SELECT COUNT(*) as count FROM market_data.{table}" + count_rows = await db_pool.execute_query(count_query) + record_counts[table] = count_rows[0]['count'] if count_rows else 0 + + stats['record_counts'] = record_counts + + # Connection pool stats + stats['connection_pool'] = await db_pool.get_pool_stats() + + return stats + + except Exception as e: + logger.error(f"Failed to get storage stats: {e}") + return {} + + async def health_check(self) -> bool: + """Check storage system health""" + try: + # Check database connection + if not await db_pool.health_check(): + return False + + # Check if tables exist + query = """ + SELECT COUNT(*) as count + FROM information_schema.tables + WHERE table_schema = 'market_data' + """ + + rows = await db_pool.execute_query(query) + table_count = rows[0]['count'] if rows else 0 + + if table_count < 6: # We expect 6 main tables + logger.warning(f"Expected 6 tables, found {table_count}") + return False + + logger.debug("Storage health check passed") + return True + + except Exception as e: + logger.error(f"Storage health check failed: {e}") + return False \ No newline at end of file diff --git a/COBY/test_integration.py b/COBY/test_integration.py new file mode 100644 index 0000000..b1b0e92 --- /dev/null +++ b/COBY/test_integration.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +""" +Integration test script for COBY system components. +Run this to test the TimescaleDB integration and basic functionality. 
+""" + +import asyncio +import sys +from datetime import datetime, timezone +from pathlib import Path + +# Add COBY to path +sys.path.insert(0, str(Path(__file__).parent)) + +from config import config +from storage.timescale_manager import TimescaleManager +from models.core import OrderBookSnapshot, TradeEvent, PriceLevel +from utils.logging import setup_logging, get_logger + +# Setup logging +setup_logging(level='INFO', console_output=True) +logger = get_logger(__name__) + + +async def test_database_connection(): + """Test basic database connectivity""" + logger.info("๐Ÿ”Œ Testing database connection...") + + try: + manager = TimescaleManager() + await manager.initialize() + + # Test health check + is_healthy = await manager.health_check() + if is_healthy: + logger.info("โœ… Database connection: HEALTHY") + else: + logger.error("โŒ Database connection: UNHEALTHY") + return False + + # Test storage stats + stats = await manager.get_storage_stats() + logger.info(f"๐Ÿ“Š Found {len(stats.get('table_sizes', []))} tables") + + for table_info in stats.get('table_sizes', []): + logger.info(f" ๐Ÿ“‹ {table_info['table']}: {table_info['size']}") + + await manager.close() + return True + + except Exception as e: + logger.error(f"โŒ Database test failed: {e}") + return False + + +async def test_data_storage(): + """Test storing and retrieving data""" + logger.info("๐Ÿ’พ Testing data storage operations...") + + try: + manager = TimescaleManager() + await manager.initialize() + + # Create test order book + test_orderbook = OrderBookSnapshot( + symbol="BTCUSDT", + exchange="test_exchange", + timestamp=datetime.now(timezone.utc), + bids=[ + PriceLevel(price=50000.0, size=1.5, count=3), + PriceLevel(price=49999.0, size=2.0, count=5) + ], + asks=[ + PriceLevel(price=50001.0, size=1.0, count=2), + PriceLevel(price=50002.0, size=1.5, count=4) + ], + sequence_id=12345 + ) + + # Test storing order book + result = await manager.store_orderbook(test_orderbook) + if result: + logger.info("โœ… Order book storage: SUCCESS") + else: + logger.error("โŒ Order book storage: FAILED") + return False + + # Test retrieving order book + retrieved = await manager.get_latest_orderbook("BTCUSDT", "test_exchange") + if retrieved: + logger.info(f"โœ… Order book retrieval: SUCCESS (mid_price: {retrieved.mid_price})") + else: + logger.error("โŒ Order book retrieval: FAILED") + return False + + # Create test trade + test_trade = TradeEvent( + symbol="BTCUSDT", + exchange="test_exchange", + timestamp=datetime.now(timezone.utc), + price=50000.5, + size=0.1, + side="buy", + trade_id="test_trade_123" + ) + + # Test storing trade + result = await manager.store_trade(test_trade) + if result: + logger.info("โœ… Trade storage: SUCCESS") + else: + logger.error("โŒ Trade storage: FAILED") + return False + + await manager.close() + return True + + except Exception as e: + logger.error(f"โŒ Data storage test failed: {e}") + return False + + +async def test_batch_operations(): + """Test batch storage operations""" + logger.info("๐Ÿ“ฆ Testing batch operations...") + + try: + manager = TimescaleManager() + await manager.initialize() + + # Create batch of order books + orderbooks = [] + for i in range(5): + orderbook = OrderBookSnapshot( + symbol="ETHUSDT", + exchange="test_exchange", + timestamp=datetime.now(timezone.utc), + bids=[PriceLevel(price=3000.0 + i, size=1.0)], + asks=[PriceLevel(price=3001.0 + i, size=1.0)], + sequence_id=i + ) + orderbooks.append(orderbook) + + # Test batch storage + result = await 
manager.batch_store_orderbooks(orderbooks) + if result == 5: + logger.info(f"โœ… Batch order book storage: SUCCESS ({result} records)") + else: + logger.error(f"โŒ Batch order book storage: PARTIAL ({result}/5 records)") + return False + + # Create batch of trades + trades = [] + for i in range(10): + trade = TradeEvent( + symbol="ETHUSDT", + exchange="test_exchange", + timestamp=datetime.now(timezone.utc), + price=3000.0 + (i * 0.1), + size=0.05, + side="buy" if i % 2 == 0 else "sell", + trade_id=f"batch_trade_{i}" + ) + trades.append(trade) + + # Test batch trade storage + result = await manager.batch_store_trades(trades) + if result == 10: + logger.info(f"โœ… Batch trade storage: SUCCESS ({result} records)") + else: + logger.error(f"โŒ Batch trade storage: PARTIAL ({result}/10 records)") + return False + + await manager.close() + return True + + except Exception as e: + logger.error(f"โŒ Batch operations test failed: {e}") + return False + + +async def test_configuration(): + """Test configuration system""" + logger.info("โš™๏ธ Testing configuration system...") + + try: + # Test database configuration + db_url = config.get_database_url() + logger.info(f"โœ… Database URL: {db_url.replace(config.database.password, '***')}") + + # Test Redis configuration + redis_url = config.get_redis_url() + logger.info(f"โœ… Redis URL: {redis_url.replace(config.redis.password, '***')}") + + # Test bucket sizes + btc_bucket = config.get_bucket_size('BTCUSDT') + eth_bucket = config.get_bucket_size('ETHUSDT') + logger.info(f"โœ… Bucket sizes: BTC=${btc_bucket}, ETH=${eth_bucket}") + + # Test configuration dict + config_dict = config.to_dict() + logger.info(f"โœ… Configuration loaded: {len(config_dict)} sections") + + return True + + except Exception as e: + logger.error(f"โŒ Configuration test failed: {e}") + return False + + +async def run_all_tests(): + """Run all integration tests""" + logger.info("๐Ÿš€ Starting COBY Integration Tests") + logger.info("=" * 50) + + tests = [ + ("Configuration", test_configuration), + ("Database Connection", test_database_connection), + ("Data Storage", test_data_storage), + ("Batch Operations", test_batch_operations) + ] + + results = [] + + for test_name, test_func in tests: + logger.info(f"\n๐Ÿงช Running {test_name} test...") + try: + result = await test_func() + results.append((test_name, result)) + if result: + logger.info(f"โœ… {test_name}: PASSED") + else: + logger.error(f"โŒ {test_name}: FAILED") + except Exception as e: + logger.error(f"โŒ {test_name}: ERROR - {e}") + results.append((test_name, False)) + + # Summary + logger.info("\n" + "=" * 50) + logger.info("๐Ÿ“‹ TEST SUMMARY") + logger.info("=" * 50) + + passed = sum(1 for _, result in results if result) + total = len(results) + + for test_name, result in results: + status = "โœ… PASSED" if result else "โŒ FAILED" + logger.info(f"{test_name:20} {status}") + + logger.info(f"\nOverall: {passed}/{total} tests passed") + + if passed == total: + logger.info("๐ŸŽ‰ All tests passed! System is ready.") + return True + else: + logger.error("โš ๏ธ Some tests failed. 
Check configuration and database connection.") + return False + + +if __name__ == "__main__": + print("COBY Integration Test Suite") + print("=" * 30) + + # Run tests + success = asyncio.run(run_all_tests()) + + if success: + print("\n๐ŸŽ‰ Integration tests completed successfully!") + print("The system is ready for the next development phase.") + sys.exit(0) + else: + print("\nโŒ Integration tests failed!") + print("Please check the logs and fix any issues before proceeding.") + sys.exit(1) \ No newline at end of file diff --git a/COBY/tests/__init__.py b/COBY/tests/__init__.py new file mode 100644 index 0000000..22555c0 --- /dev/null +++ b/COBY/tests/__init__.py @@ -0,0 +1,3 @@ +""" +Test suite for the COBY system. +""" \ No newline at end of file diff --git a/COBY/tests/test_timescale_manager.py b/COBY/tests/test_timescale_manager.py new file mode 100644 index 0000000..d30af73 --- /dev/null +++ b/COBY/tests/test_timescale_manager.py @@ -0,0 +1,192 @@ +""" +Tests for TimescaleDB storage manager. +""" + +import pytest +import asyncio +from datetime import datetime, timezone +from ..storage.timescale_manager import TimescaleManager +from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel +from ..config import config + + +@pytest.fixture +async def storage_manager(): + """Create and initialize storage manager for testing""" + manager = TimescaleManager() + await manager.initialize() + yield manager + await manager.close() + + +@pytest.fixture +def sample_orderbook(): + """Create sample order book for testing""" + return OrderBookSnapshot( + symbol="BTCUSDT", + exchange="binance", + timestamp=datetime.now(timezone.utc), + bids=[ + PriceLevel(price=50000.0, size=1.5, count=3), + PriceLevel(price=49999.0, size=2.0, count=5) + ], + asks=[ + PriceLevel(price=50001.0, size=1.0, count=2), + PriceLevel(price=50002.0, size=1.5, count=4) + ], + sequence_id=12345 + ) + + +@pytest.fixture +def sample_trade(): + """Create sample trade for testing""" + return TradeEvent( + symbol="BTCUSDT", + exchange="binance", + timestamp=datetime.now(timezone.utc), + price=50000.5, + size=0.1, + side="buy", + trade_id="test_trade_123" + ) + + +class TestTimescaleManager: + """Test cases for TimescaleManager""" + + @pytest.mark.asyncio + async def test_health_check(self, storage_manager): + """Test storage health check""" + is_healthy = await storage_manager.health_check() + assert is_healthy is True + + @pytest.mark.asyncio + async def test_store_orderbook(self, storage_manager, sample_orderbook): + """Test storing order book snapshot""" + result = await storage_manager.store_orderbook(sample_orderbook) + assert result is True + + @pytest.mark.asyncio + async def test_store_trade(self, storage_manager, sample_trade): + """Test storing trade event""" + result = await storage_manager.store_trade(sample_trade) + assert result is True + + @pytest.mark.asyncio + async def test_get_latest_orderbook(self, storage_manager, sample_orderbook): + """Test retrieving latest order book""" + # Store the order book first + await storage_manager.store_orderbook(sample_orderbook) + + # Retrieve it + retrieved = await storage_manager.get_latest_orderbook( + sample_orderbook.symbol, + sample_orderbook.exchange + ) + + assert retrieved is not None + assert retrieved.symbol == sample_orderbook.symbol + assert retrieved.exchange == sample_orderbook.exchange + assert len(retrieved.bids) == len(sample_orderbook.bids) + assert len(retrieved.asks) == len(sample_orderbook.asks) + + @pytest.mark.asyncio + async def 
test_batch_store_orderbooks(self, storage_manager): + """Test batch storing order books""" + orderbooks = [] + for i in range(5): + orderbook = OrderBookSnapshot( + symbol="ETHUSDT", + exchange="binance", + timestamp=datetime.now(timezone.utc), + bids=[PriceLevel(price=3000.0 + i, size=1.0)], + asks=[PriceLevel(price=3001.0 + i, size=1.0)], + sequence_id=i + ) + orderbooks.append(orderbook) + + result = await storage_manager.batch_store_orderbooks(orderbooks) + assert result == 5 + + @pytest.mark.asyncio + async def test_batch_store_trades(self, storage_manager): + """Test batch storing trades""" + trades = [] + for i in range(5): + trade = TradeEvent( + symbol="ETHUSDT", + exchange="binance", + timestamp=datetime.now(timezone.utc), + price=3000.0 + i, + size=0.1, + side="buy" if i % 2 == 0 else "sell", + trade_id=f"test_trade_{i}" + ) + trades.append(trade) + + result = await storage_manager.batch_store_trades(trades) + assert result == 5 + + @pytest.mark.asyncio + async def test_get_storage_stats(self, storage_manager): + """Test getting storage statistics""" + stats = await storage_manager.get_storage_stats() + + assert isinstance(stats, dict) + assert 'table_sizes' in stats + assert 'record_counts' in stats + assert 'connection_pool' in stats + + @pytest.mark.asyncio + async def test_historical_data_retrieval(self, storage_manager, sample_orderbook, sample_trade): + """Test retrieving historical data""" + # Store some data first + await storage_manager.store_orderbook(sample_orderbook) + await storage_manager.store_trade(sample_trade) + + # Define time range + start_time = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0) + end_time = datetime.now(timezone.utc).replace(hour=23, minute=59, second=59, microsecond=999999) + + # Retrieve historical order books + orderbooks = await storage_manager.get_historical_orderbooks( + sample_orderbook.symbol, + sample_orderbook.exchange, + start_time, + end_time, + limit=10 + ) + + assert isinstance(orderbooks, list) + + # Retrieve historical trades + trades = await storage_manager.get_historical_trades( + sample_trade.symbol, + sample_trade.exchange, + start_time, + end_time, + limit=10 + ) + + assert isinstance(trades, list) + + +if __name__ == "__main__": + # Run a simple test + async def simple_test(): + manager = TimescaleManager() + await manager.initialize() + + # Test health check + is_healthy = await manager.health_check() + print(f"Health check: {'PASSED' if is_healthy else 'FAILED'}") + + # Test storage stats + stats = await manager.get_storage_stats() + print(f"Storage stats: {len(stats)} categories") + + await manager.close() + print("Simple test completed") + + asyncio.run(simple_test()) \ No newline at end of file
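
Follow-up note on the storage layer above: `TimescaleManager.initialize()` awaits `setup_database_schema()`, but that method is declared as a plain `def` that only wraps the real work in an inner coroutine, so the `await` in `initialize()` fails (and, when a loop is already running, the schema task is fired off without being awaited). A minimal async rewrite, reusing only names already introduced in this patch, could look like:

    async def setup_database_schema(self) -> None:
        """Create the market_data schema objects if they do not already exist."""
        if self._schema_initialized:
            return
        try:
            for query in DatabaseSchema.get_all_creation_queries():
                try:
                    await db_pool.execute_command(query)
                    logger.debug(f"Executed schema query: {query[:50]}...")
                except Exception as e:
                    # Some statements fail harmlessly when the objects already exist.
                    logger.warning(f"Schema query failed (continuing): {e}")
            self._schema_initialized = True
            logger.info("Database schema setup completed")
        except Exception as e:
            logger.error(f"Failed to setup database schema: {e}")
            raise StorageError(f"Schema setup failed: {e}", "SCHEMA_SETUP_ERROR")

This keeps `await self.setup_database_schema()` in `initialize()` valid and drops the event-loop detection logic entirely.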
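
In the same file, `cleanup_old_data()` builds its cutoff with `get_current_timestamp().replace(day=... - retention_days)`, which raises `ValueError` whenever the retention window crosses a month boundary (the day component becomes zero or negative). A `timedelta`-based cutoff avoids the calendar arithmetic; a short sketch using the same helpers as the patch:

    from datetime import timedelta

    cutoff_time = get_current_timestamp() - timedelta(days=retention_days)
    # The per-table DELETE loop that follows can remain unchanged.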
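
One practical note on the new tests: `COBY/test_integration.py` is intended to be run directly against a live TimescaleDB instance, while `COBY/tests/test_timescale_manager.py` uses `@pytest.mark.asyncio`, so pytest and pytest-asyncio need to be available (assumed to be covered by `COBY/requirements.txt`, whose contents are not shown in this hunk). Depending on the pytest-asyncio mode in use, the async `storage_manager` fixture may also need to be declared with `@pytest_asyncio.fixture` rather than plain `@pytest.fixture`.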