cob integration scaffold
@@ -7,15 +7,22 @@
  - Create directory structure in `.\COBY` subfolder for the multi-exchange data aggregation system
  - Define base interfaces and data models for exchange connectors, data processing, and storage
  - Implement configuration management system with environment variable support
  - _Requirements: 1.1, 6.1, 7.3_

- [ ] 2. Implement TimescaleDB integration and database schema
  - Create TimescaleDB connection manager with connection pooling
  - Implement database schema creation with hypertables for time-series optimization
  - Write database operations for storing order book snapshots and trade events
  - Create database migration system for schema updates
  - _Requirements: 3.1, 3.2, 3.3, 3.4_

- [ ] 3. Create base exchange connector framework
  - Implement abstract base class for exchange WebSocket connectors
  - Create connection management with exponential backoff and circuit breaker patterns
13  COBY/connectors/__init__.py  Normal file
@@ -0,0 +1,13 @@
"""
Exchange connector implementations for the COBY system.
"""

from .base_connector import BaseExchangeConnector
from .connection_manager import ConnectionManager
from .circuit_breaker import CircuitBreaker

__all__ = [
    'BaseExchangeConnector',
    'ConnectionManager',
    'CircuitBreaker'
]
383  COBY/connectors/base_connector.py  Normal file
@@ -0,0 +1,383 @@
"""
Base exchange connector implementation with connection management and error handling.
"""

import asyncio
import json
import websockets
from typing import Dict, List, Optional, Callable, Any
from datetime import datetime, timezone

from ..interfaces.exchange_connector import ExchangeConnector
from ..models.core import ConnectionStatus, OrderBookSnapshot, TradeEvent
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ConnectionError, ValidationError
from ..utils.timing import get_current_timestamp
from .connection_manager import ConnectionManager
from .circuit_breaker import CircuitBreaker, CircuitBreakerOpenError

logger = get_logger(__name__)


class BaseExchangeConnector(ExchangeConnector):
    """
    Base implementation of exchange connector with common functionality.

    Provides:
    - WebSocket connection management
    - Exponential backoff retry logic
    - Circuit breaker pattern
    - Health monitoring
    - Message handling framework
    - Subscription management
    """

    def __init__(self, exchange_name: str, websocket_url: str):
        """
        Initialize base exchange connector.

        Args:
            exchange_name: Name of the exchange
            websocket_url: WebSocket URL for the exchange
        """
        super().__init__(exchange_name)

        self.websocket_url = websocket_url
        # Client-side connection, so the client protocol type is the correct annotation
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self.subscriptions: Dict[str, List[str]] = {}  # symbol -> [subscription_types]
        self.message_handlers: Dict[str, Callable] = {}

        # Connection management
        self.connection_manager = ConnectionManager(
            name=f"{exchange_name}_connector",
            max_retries=10,
            initial_delay=1.0,
            max_delay=300.0,
            health_check_interval=30
        )

        # Circuit breaker
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=5,
            recovery_timeout=60,
            expected_exception=Exception,
            name=f"{exchange_name}_circuit"
        )

        # Statistics
        self.message_count = 0
        self.error_count = 0
        self.last_message_time: Optional[datetime] = None

        # Setup callbacks
        self.connection_manager.on_connect = self._on_connect
        self.connection_manager.on_disconnect = self._on_disconnect
        self.connection_manager.on_error = self._on_error
        self.connection_manager.on_health_check = self._health_check

        # Message processing
        self._message_queue = asyncio.Queue(maxsize=10000)
        self._message_processor_task: Optional[asyncio.Task] = None

        logger.info(f"Base connector initialized for {exchange_name}")

    async def connect(self) -> bool:
        """Establish connection to the exchange WebSocket"""
        try:
            set_correlation_id()
            logger.info(f"Connecting to {self.exchange_name} at {self.websocket_url}")

            return await self.connection_manager.connect(self._establish_websocket_connection)

        except Exception as e:
            logger.error(f"Failed to connect to {self.exchange_name}: {e}")
            self._notify_status_callbacks(ConnectionStatus.ERROR)
            return False

    async def disconnect(self) -> None:
        """Disconnect from the exchange WebSocket"""
        try:
            set_correlation_id()
            logger.info(f"Disconnecting from {self.exchange_name}")

            await self.connection_manager.disconnect(self._close_websocket_connection)

        except Exception as e:
            logger.error(f"Error during disconnect from {self.exchange_name}: {e}")

    async def _establish_websocket_connection(self) -> None:
        """Establish WebSocket connection"""
        try:
            # Use circuit breaker for connection
            self.websocket = await self.circuit_breaker.call_async(
                websockets.connect,
                self.websocket_url,
                ping_interval=20,
                ping_timeout=10,
                close_timeout=10
            )

            logger.info(f"WebSocket connected to {self.exchange_name}")

            # Start message processing
            await self._start_message_processing()

        except CircuitBreakerOpenError as e:
            logger.error(f"Circuit breaker open for {self.exchange_name}: {e}")
            raise ConnectionError(f"Circuit breaker open: {e}", "CIRCUIT_BREAKER_OPEN")
        except Exception as e:
            logger.error(f"WebSocket connection failed for {self.exchange_name}: {e}")
            raise ConnectionError(f"WebSocket connection failed: {e}", "WEBSOCKET_CONNECT_FAILED")

    async def _close_websocket_connection(self) -> None:
        """Close WebSocket connection"""
        try:
            # Stop message processing
            await self._stop_message_processing()

            # Close WebSocket
            if self.websocket:
                await self.websocket.close()
                self.websocket = None

            logger.info(f"WebSocket disconnected from {self.exchange_name}")

        except Exception as e:
            logger.warning(f"Error closing WebSocket for {self.exchange_name}: {e}")

    async def _start_message_processing(self) -> None:
        """Start message processing tasks"""
        if self._message_processor_task:
            return

        # Start message processor
        self._message_processor_task = asyncio.create_task(self._message_processor())

        # Start message receiver
        asyncio.create_task(self._message_receiver())

        logger.debug(f"Message processing started for {self.exchange_name}")

    async def _stop_message_processing(self) -> None:
        """Stop message processing tasks"""
        if self._message_processor_task:
            self._message_processor_task.cancel()
            try:
                await self._message_processor_task
            except asyncio.CancelledError:
                pass
            self._message_processor_task = None

        logger.debug(f"Message processing stopped for {self.exchange_name}")

    async def _message_receiver(self) -> None:
        """Receive messages from WebSocket"""
        try:
            while self.websocket and not self.websocket.closed:
                try:
                    message = await asyncio.wait_for(self.websocket.recv(), timeout=30.0)

                    # Queue message for processing
                    try:
                        self._message_queue.put_nowait(message)
                    except asyncio.QueueFull:
                        logger.warning(f"Message queue full for {self.exchange_name}, dropping message")

                except asyncio.TimeoutError:
                    # Send ping to keep connection alive
                    if self.websocket:
                        await self.websocket.ping()
                except websockets.exceptions.ConnectionClosed:
                    logger.warning(f"WebSocket connection closed for {self.exchange_name}")
                    break
                except Exception as e:
                    logger.error(f"Error receiving message from {self.exchange_name}: {e}")
                    self.error_count += 1
                    break

        except Exception as e:
            logger.error(f"Message receiver error for {self.exchange_name}: {e}")
        finally:
            # Mark as disconnected
            self.connection_manager.is_connected = False

    async def _message_processor(self) -> None:
        """Process messages from the queue"""
        while True:
            try:
                # Get message from queue
                message = await self._message_queue.get()

                # Process message
                await self._process_message(message)

                # Update statistics
                self.message_count += 1
                self.last_message_time = get_current_timestamp()

                # Mark task as done
                self._message_queue.task_done()

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error processing message for {self.exchange_name}: {e}")
                self.error_count += 1

    async def _process_message(self, message: str) -> None:
        """
        Process incoming WebSocket message.

        Args:
            message: Raw message string
        """
        try:
            # Parse JSON message
            data = json.loads(message)

            # Determine message type and route to appropriate handler
            message_type = self._get_message_type(data)

            if message_type in self.message_handlers:
                await self.message_handlers[message_type](data)
            else:
                logger.debug(f"Unhandled message type '{message_type}' from {self.exchange_name}")

        except json.JSONDecodeError as e:
            logger.warning(f"Invalid JSON message from {self.exchange_name}: {e}")
        except Exception as e:
            logger.error(f"Error processing message from {self.exchange_name}: {e}")

    def _get_message_type(self, data: Dict) -> str:
        """
        Determine message type from message data.
        Override in subclasses for exchange-specific logic.

        Args:
            data: Parsed message data

        Returns:
            str: Message type identifier
        """
        # Default implementation - override in subclasses
        return data.get('type', 'unknown')

    async def _send_message(self, message: Dict) -> bool:
        """
        Send message to WebSocket.

        Args:
            message: Message to send

        Returns:
            bool: True if sent successfully, False otherwise
        """
        try:
            if not self.websocket or self.websocket.closed:
                logger.warning(f"Cannot send message to {self.exchange_name}: not connected")
                return False

            message_str = json.dumps(message)
            await self.websocket.send(message_str)

            logger.debug(f"Sent message to {self.exchange_name}: {message_str[:100]}...")
            return True

        except Exception as e:
            logger.error(f"Error sending message to {self.exchange_name}: {e}")
            return False

    # Callback handlers
    async def _on_connect(self) -> None:
        """Handle successful connection"""
        self._notify_status_callbacks(ConnectionStatus.CONNECTED)

        # Resubscribe to all previous subscriptions
        await self._resubscribe_all()

    async def _on_disconnect(self) -> None:
        """Handle disconnection"""
        self._notify_status_callbacks(ConnectionStatus.DISCONNECTED)

    async def _on_error(self, error: Exception) -> None:
        """Handle connection error"""
        logger.error(f"Connection error for {self.exchange_name}: {error}")
        self._notify_status_callbacks(ConnectionStatus.ERROR)

    async def _health_check(self) -> bool:
        """Perform health check"""
        try:
            if not self.websocket or self.websocket.closed:
                return False

            # Check if we've received messages recently
            if self.last_message_time:
                time_since_last_message = (get_current_timestamp() - self.last_message_time).total_seconds()
                if time_since_last_message > 60:  # No messages for 60 seconds
                    logger.warning(f"No messages received from {self.exchange_name} for {time_since_last_message}s")
                    return False

            # Send ping
            await self.websocket.ping()
            return True

        except Exception as e:
            logger.error(f"Health check failed for {self.exchange_name}: {e}")
            return False

    async def _resubscribe_all(self) -> None:
        """Resubscribe to all previous subscriptions after reconnection"""
        for symbol, subscription_types in self.subscriptions.items():
            for sub_type in subscription_types:
                try:
                    if sub_type == 'orderbook':
                        await self.subscribe_orderbook(symbol)
                    elif sub_type == 'trades':
                        await self.subscribe_trades(symbol)
                except Exception as e:
                    logger.error(f"Failed to resubscribe to {sub_type} for {symbol}: {e}")

    # Abstract methods that must be implemented by subclasses
    async def subscribe_orderbook(self, symbol: str) -> None:
        """Subscribe to order book updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement subscribe_orderbook")

    async def subscribe_trades(self, symbol: str) -> None:
        """Subscribe to trade updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement subscribe_trades")

    async def unsubscribe_orderbook(self, symbol: str) -> None:
        """Unsubscribe from order book updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement unsubscribe_orderbook")

    async def unsubscribe_trades(self, symbol: str) -> None:
        """Unsubscribe from trade updates - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement unsubscribe_trades")

    async def get_symbols(self) -> List[str]:
        """Get available symbols - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement get_symbols")

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol format - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement normalize_symbol")

    async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
        """Get order book snapshot - must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement get_orderbook_snapshot")

    # Utility methods
    def get_stats(self) -> Dict[str, Any]:
        """Get connector statistics"""
        return {
            'exchange': self.exchange_name,
            'connection_status': self.get_connection_status().value,
            'is_connected': self.is_connected,
            'message_count': self.message_count,
            'error_count': self.error_count,
            'last_message_time': self.last_message_time.isoformat() if self.last_message_time else None,
            'subscriptions': dict(self.subscriptions),
            'connection_manager': self.connection_manager.get_stats(),
            'circuit_breaker': self.circuit_breaker.get_stats(),
            'queue_size': self._message_queue.qsize()
        }
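Editor's note: a minimal sketch (not part of the commit) of how a concrete exchange connector might fill in the abstract hooks of BaseExchangeConnector. The class name, endpoint URL, and subscribe message shapes are illustrative assumptions, not the project's actual exchange implementations; only two of the required hooks are shown.

# Hypothetical subclass sketch -- endpoint and channel names are assumptions for illustration only.
from COBY.connectors.base_connector import BaseExchangeConnector


class ExampleConnector(BaseExchangeConnector):
    """Illustrative connector showing how the subscription hooks plug into the base framework."""

    def __init__(self):
        super().__init__("example", "wss://example.invalid/ws")  # placeholder endpoint

    async def subscribe_orderbook(self, symbol: str) -> None:
        # Record the subscription so _resubscribe_all() can replay it after a reconnect
        self.subscriptions.setdefault(symbol, []).append('orderbook')
        await self._send_message({"op": "subscribe", "channel": "orderbook",
                                  "symbol": self.normalize_symbol(symbol)})

    async def subscribe_trades(self, symbol: str) -> None:
        self.subscriptions.setdefault(symbol, []).append('trades')
        await self._send_message({"op": "subscribe", "channel": "trades",
                                  "symbol": self.normalize_symbol(symbol)})

    def normalize_symbol(self, symbol: str) -> str:
        # e.g. "BTC/USDT" -> "BTCUSDT"
        return symbol.replace("/", "").upper()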
206  COBY/connectors/circuit_breaker.py  Normal file
@@ -0,0 +1,206 @@
"""
Circuit breaker pattern implementation for exchange connections.
"""

import time
from enum import Enum
from typing import Optional, Callable, Any
from ..utils.logging import get_logger

logger = get_logger(__name__)


class CircuitState(Enum):
    """Circuit breaker states"""
    CLOSED = "closed"        # Normal operation
    OPEN = "open"            # Circuit is open, calls fail fast
    HALF_OPEN = "half_open"  # Testing if service is back


class CircuitBreaker:
    """
    Circuit breaker to prevent cascading failures in exchange connections.

    States:
    - CLOSED: Normal operation, requests pass through
    - OPEN: Circuit is open, requests fail immediately
    - HALF_OPEN: Testing if service is back, limited requests allowed
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        expected_exception: type = Exception,
        name: str = "CircuitBreaker"
    ):
        """
        Initialize circuit breaker.

        Args:
            failure_threshold: Number of failures before opening circuit
            recovery_timeout: Time in seconds before attempting recovery
            expected_exception: Exception type that triggers circuit breaker
            name: Name for logging purposes
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.name = name

        # State tracking
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time: Optional[float] = None
        self._next_attempt_time: Optional[float] = None

        logger.info(f"Circuit breaker '{name}' initialized with threshold={failure_threshold}")

    @property
    def state(self) -> CircuitState:
        """Get current circuit state"""
        return self._state

    @property
    def failure_count(self) -> int:
        """Get current failure count"""
        return self._failure_count

    def _should_attempt_reset(self) -> bool:
        """Check if we should attempt to reset the circuit"""
        if self._state != CircuitState.OPEN:
            return False

        if self._next_attempt_time is None:
            return False

        return time.time() >= self._next_attempt_time

    def _on_success(self) -> None:
        """Handle successful operation"""
        if self._state == CircuitState.HALF_OPEN:
            logger.info(f"Circuit breaker '{self.name}' reset to CLOSED after successful test")
            self._state = CircuitState.CLOSED

        self._failure_count = 0
        self._last_failure_time = None
        self._next_attempt_time = None

    def _on_failure(self) -> None:
        """Handle failed operation"""
        self._failure_count += 1
        self._last_failure_time = time.time()

        if self._state == CircuitState.HALF_OPEN:
            # Failed during test, go back to OPEN
            logger.warning(f"Circuit breaker '{self.name}' failed during test, returning to OPEN")
            self._state = CircuitState.OPEN
            self._next_attempt_time = time.time() + self.recovery_timeout
        elif self._failure_count >= self.failure_threshold:
            # Too many failures, open the circuit
            logger.error(
                f"Circuit breaker '{self.name}' OPENED after {self._failure_count} failures"
            )
            self._state = CircuitState.OPEN
            self._next_attempt_time = time.time() + self.recovery_timeout

    def call(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute function with circuit breaker protection.

        Args:
            func: Function to execute
            *args: Function arguments
            **kwargs: Function keyword arguments

        Returns:
            Function result

        Raises:
            CircuitBreakerOpenError: When circuit is open
            Original exception: When function fails
        """
        # Check if we should attempt reset
        if self._should_attempt_reset():
            logger.info(f"Circuit breaker '{self.name}' attempting reset to HALF_OPEN")
            self._state = CircuitState.HALF_OPEN

        # Fail fast if circuit is open
        if self._state == CircuitState.OPEN:
            raise CircuitBreakerOpenError(
                f"Circuit breaker '{self.name}' is OPEN. "
                f"Next attempt in {self._next_attempt_time - time.time():.1f}s"
            )

        try:
            # Execute the function
            result = func(*args, **kwargs)
            self._on_success()
            return result

        except self.expected_exception as e:
            self._on_failure()
            raise e

    async def call_async(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute async function with circuit breaker protection.

        Args:
            func: Async function to execute
            *args: Function arguments
            **kwargs: Function keyword arguments

        Returns:
            Function result

        Raises:
            CircuitBreakerOpenError: When circuit is open
            Original exception: When function fails
        """
        # Check if we should attempt reset
        if self._should_attempt_reset():
            logger.info(f"Circuit breaker '{self.name}' attempting reset to HALF_OPEN")
            self._state = CircuitState.HALF_OPEN

        # Fail fast if circuit is open
        if self._state == CircuitState.OPEN:
            raise CircuitBreakerOpenError(
                f"Circuit breaker '{self.name}' is OPEN. "
                f"Next attempt in {self._next_attempt_time - time.time():.1f}s"
            )

        try:
            # Execute the async function
            result = await func(*args, **kwargs)
            self._on_success()
            return result

        except self.expected_exception as e:
            self._on_failure()
            raise e

    def reset(self) -> None:
        """Manually reset the circuit breaker"""
        logger.info(f"Circuit breaker '{self.name}' manually reset")
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time = None
        self._next_attempt_time = None

    def get_stats(self) -> dict:
        """Get circuit breaker statistics"""
        return {
            'name': self.name,
            'state': self._state.value,
            'failure_count': self._failure_count,
            'failure_threshold': self.failure_threshold,
            'last_failure_time': self._last_failure_time,
            'next_attempt_time': self._next_attempt_time,
            'recovery_timeout': self.recovery_timeout
        }


class CircuitBreakerOpenError(Exception):
    """Exception raised when circuit breaker is open"""
    pass
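Editor's note: a short usage sketch (not part of the commit) of the CLOSED -> OPEN fail-fast flow described in the docstrings above. fetch_ticker is a stand-in for any flaky async call; the import path assumes the COBY directory is on the Python path.

# Usage sketch: wrap a failing async call so repeated failures trip the breaker.
import asyncio
from COBY.connectors.circuit_breaker import CircuitBreaker, CircuitBreakerOpenError


async def fetch_ticker() -> dict:
    # Hypothetical network call that always fails, to exercise the breaker
    raise RuntimeError("exchange unreachable")


async def main() -> None:
    breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=30, name="demo")
    for _ in range(5):
        try:
            await breaker.call_async(fetch_ticker)
        except CircuitBreakerOpenError:
            print("breaker OPEN - failing fast instead of calling the exchange")
        except RuntimeError:
            print(f"call failed ({breaker.failure_count}/{breaker.failure_threshold})")

asyncio.run(main())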
271  COBY/connectors/connection_manager.py  Normal file
@@ -0,0 +1,271 @@
"""
Connection management with exponential backoff and retry logic.
"""

import asyncio
import random
from typing import Optional, Callable, Any
from ..utils.logging import get_logger
from ..utils.exceptions import ConnectionError

logger = get_logger(__name__)


class ExponentialBackoff:
    """Exponential backoff strategy for connection retries"""

    def __init__(
        self,
        initial_delay: float = 1.0,
        max_delay: float = 300.0,
        multiplier: float = 2.0,
        jitter: bool = True
    ):
        """
        Initialize exponential backoff.

        Args:
            initial_delay: Initial delay in seconds
            max_delay: Maximum delay in seconds
            multiplier: Backoff multiplier
            jitter: Whether to add random jitter
        """
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.multiplier = multiplier
        self.jitter = jitter
        self.current_delay = initial_delay
        self.attempt_count = 0

    def get_delay(self) -> float:
        """Get next delay value"""
        delay = min(self.current_delay, self.max_delay)

        # Add jitter to prevent thundering herd
        if self.jitter:
            delay = delay * (0.5 + random.random() * 0.5)

        # Update for next attempt
        self.current_delay *= self.multiplier
        self.attempt_count += 1

        return delay

    def reset(self) -> None:
        """Reset backoff to initial state"""
        self.current_delay = self.initial_delay
        self.attempt_count = 0


class ConnectionManager:
    """
    Manages connection lifecycle with retry logic and health monitoring.
    """

    def __init__(
        self,
        name: str,
        max_retries: int = 10,
        initial_delay: float = 1.0,
        max_delay: float = 300.0,
        health_check_interval: int = 30
    ):
        """
        Initialize connection manager.

        Args:
            name: Connection name for logging
            max_retries: Maximum number of retry attempts
            initial_delay: Initial retry delay in seconds
            max_delay: Maximum retry delay in seconds
            health_check_interval: Health check interval in seconds
        """
        self.name = name
        self.max_retries = max_retries
        self.health_check_interval = health_check_interval

        self.backoff = ExponentialBackoff(initial_delay, max_delay)
        self.is_connected = False
        self.connection_attempts = 0
        self.last_error: Optional[Exception] = None
        self.health_check_task: Optional[asyncio.Task] = None

        # Callbacks
        self.on_connect: Optional[Callable] = None
        self.on_disconnect: Optional[Callable] = None
        self.on_error: Optional[Callable] = None
        self.on_health_check: Optional[Callable] = None

        logger.info(f"Connection manager '{name}' initialized")

    async def connect(self, connect_func: Callable) -> bool:
        """
        Attempt to establish connection with retry logic.

        Args:
            connect_func: Async function that establishes the connection

        Returns:
            bool: True if connection successful, False otherwise
        """
        self.connection_attempts = 0
        self.backoff.reset()

        while self.connection_attempts < self.max_retries:
            try:
                logger.info(f"Attempting to connect '{self.name}' (attempt {self.connection_attempts + 1})")

                # Attempt connection
                await connect_func()

                # Connection successful
                self.is_connected = True
                self.connection_attempts = 0
                self.last_error = None
                self.backoff.reset()

                logger.info(f"Connection '{self.name}' established successfully")

                # Start health check
                await self._start_health_check()

                # Notify success
                if self.on_connect:
                    try:
                        await self.on_connect()
                    except Exception as e:
                        logger.warning(f"Error in connect callback: {e}")

                return True

            except Exception as e:
                self.connection_attempts += 1
                self.last_error = e

                logger.warning(
                    f"Connection '{self.name}' failed (attempt {self.connection_attempts}): {e}"
                )

                # Notify error
                if self.on_error:
                    try:
                        await self.on_error(e)
                    except Exception as callback_error:
                        logger.warning(f"Error in error callback: {callback_error}")

                # Check if we should retry
                if self.connection_attempts >= self.max_retries:
                    logger.error(f"Connection '{self.name}' failed after {self.max_retries} attempts")
                    break

                # Wait before retry
                delay = self.backoff.get_delay()
                logger.info(f"Retrying connection '{self.name}' in {delay:.1f} seconds")
                await asyncio.sleep(delay)

        self.is_connected = False
        return False

    async def disconnect(self, disconnect_func: Optional[Callable] = None) -> None:
        """
        Disconnect and cleanup.

        Args:
            disconnect_func: Optional async function to handle disconnection
        """
        logger.info(f"Disconnecting '{self.name}'")

        # Stop health check
        await self._stop_health_check()

        # Execute disconnect function
        if disconnect_func:
            try:
                await disconnect_func()
            except Exception as e:
                logger.warning(f"Error during disconnect: {e}")

        self.is_connected = False

        # Notify disconnect
        if self.on_disconnect:
            try:
                await self.on_disconnect()
            except Exception as e:
                logger.warning(f"Error in disconnect callback: {e}")

        logger.info(f"Connection '{self.name}' disconnected")

    async def reconnect(self, connect_func: Callable, disconnect_func: Optional[Callable] = None) -> bool:
        """
        Reconnect by disconnecting first then connecting.

        Args:
            connect_func: Async function that establishes the connection
            disconnect_func: Optional async function to handle disconnection

        Returns:
            bool: True if reconnection successful, False otherwise
        """
        logger.info(f"Reconnecting '{self.name}'")

        # Disconnect first
        await self.disconnect(disconnect_func)

        # Wait a bit before reconnecting
        await asyncio.sleep(1.0)

        # Attempt to connect
        return await self.connect(connect_func)

    async def _start_health_check(self) -> None:
        """Start periodic health check"""
        if self.health_check_task:
            return

        self.health_check_task = asyncio.create_task(self._health_check_loop())
        logger.debug(f"Health check started for '{self.name}'")

    async def _stop_health_check(self) -> None:
        """Stop health check"""
        if self.health_check_task:
            self.health_check_task.cancel()
            try:
                await self.health_check_task
            except asyncio.CancelledError:
                pass
            self.health_check_task = None
            logger.debug(f"Health check stopped for '{self.name}'")

    async def _health_check_loop(self) -> None:
        """Health check loop"""
        while self.is_connected:
            try:
                await asyncio.sleep(self.health_check_interval)

                if self.on_health_check:
                    is_healthy = await self.on_health_check()
                    if not is_healthy:
                        logger.warning(f"Health check failed for '{self.name}'")
                        self.is_connected = False
                        break

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Health check error for '{self.name}': {e}")
                self.is_connected = False
                break

    def get_stats(self) -> dict:
        """Get connection statistics"""
        return {
            'name': self.name,
            'is_connected': self.is_connected,
            'connection_attempts': self.connection_attempts,
            'max_retries': self.max_retries,
            'current_delay': self.backoff.current_delay,
            'backoff_attempts': self.backoff.attempt_count,
            'last_error': str(self.last_error) if self.last_error else None,
            'health_check_active': self.health_check_task is not None
        }
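Editor's note: a small sketch (not part of the commit) of the retry/callback flow that ConnectionManager drives. open_feed and on_connected are stand-in assumptions for whatever coroutine actually opens the exchange connection; the import path assumes the COBY directory is importable.

# Usage sketch: retry a connect coroutine with exponential backoff and a connect callback.
import asyncio
from COBY.connectors.connection_manager import ConnectionManager


async def open_feed() -> None:
    # Hypothetical connect function; raising here would exercise the backoff/retry path
    print("opening feed...")


async def on_connected() -> None:
    print("feed is up")


async def main() -> None:
    manager = ConnectionManager(name="demo_feed", max_retries=3, initial_delay=0.5, max_delay=5.0)
    manager.on_connect = on_connected
    connected = await manager.connect(open_feed)
    print("connected:", connected, manager.get_stats())
    await manager.disconnect()

asyncio.run(main())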
34  COBY/requirements.txt  Normal file
@@ -0,0 +1,34 @@
# Core dependencies for COBY system
asyncpg>=0.29.0            # PostgreSQL/TimescaleDB async driver
redis>=5.0.0               # Redis client
websockets>=12.0           # WebSocket client library
aiohttp>=3.9.0             # Async HTTP client/server
fastapi>=0.104.0           # API framework
uvicorn>=0.24.0            # ASGI server
pydantic>=2.5.0            # Data validation
python-multipart>=0.0.6    # Form data parsing

# Data processing
pandas>=2.1.0              # Data manipulation
numpy>=1.24.0              # Numerical computing
scipy>=1.11.0              # Scientific computing

# Utilities
python-dotenv>=1.0.0       # Environment variable loading
structlog>=23.2.0          # Structured logging
click>=8.1.0               # CLI framework
rich>=13.7.0               # Rich text and beautiful formatting

# Development dependencies
pytest>=7.4.0              # Testing framework
pytest-asyncio>=0.21.0     # Async testing
pytest-cov>=4.1.0          # Coverage reporting
black>=23.11.0             # Code formatting
isort>=5.12.0              # Import sorting
flake8>=6.1.0              # Linting
mypy>=1.7.0                # Type checking

# Optional dependencies for enhanced features
prometheus-client>=0.19.0  # Metrics collection
grafana-api>=1.0.3         # Grafana integration
psutil>=5.9.0              # System monitoring
11  COBY/storage/__init__.py  Normal file
@@ -0,0 +1,11 @@
"""
Storage layer for the COBY system.
"""

from .timescale_manager import TimescaleManager
from .connection_pool import DatabaseConnectionPool

__all__ = [
    'TimescaleManager',
    'DatabaseConnectionPool'
]
140  COBY/storage/connection_pool.py  Normal file
@@ -0,0 +1,140 @@
"""
Database connection pool management for TimescaleDB.
"""

import asyncio
import asyncpg
from typing import Optional, Dict, Any
from contextlib import asynccontextmanager
from ..config import config
from ..utils.logging import get_logger
from ..utils.exceptions import StorageError

logger = get_logger(__name__)


class DatabaseConnectionPool:
    """Manages database connection pool for TimescaleDB"""

    def __init__(self):
        self._pool: Optional[asyncpg.Pool] = None
        self._is_initialized = False

    async def initialize(self) -> None:
        """Initialize the connection pool"""
        if self._is_initialized:
            return

        try:
            # Build connection string
            dsn = (
                f"postgresql://{config.database.user}:{config.database.password}"
                f"@{config.database.host}:{config.database.port}/{config.database.name}"
            )

            # Create connection pool
            self._pool = await asyncpg.create_pool(
                dsn,
                min_size=5,
                max_size=config.database.pool_size,
                max_queries=50000,
                max_inactive_connection_lifetime=300,
                command_timeout=config.database.pool_timeout,
                server_settings={
                    'search_path': config.database.schema,
                    'timezone': 'UTC'
                }
            )

            self._is_initialized = True
            logger.info(f"Database connection pool initialized with {config.database.pool_size} connections")

            # Test connection
            await self.health_check()

        except Exception as e:
            logger.error(f"Failed to initialize database connection pool: {e}")
            raise StorageError(f"Database connection failed: {e}", "DB_INIT_ERROR")

    async def close(self) -> None:
        """Close the connection pool"""
        if self._pool:
            await self._pool.close()
            self._pool = None
            self._is_initialized = False
            logger.info("Database connection pool closed")

    @asynccontextmanager
    async def get_connection(self):
        """Get a database connection from the pool"""
        if not self._is_initialized:
            await self.initialize()

        if not self._pool:
            raise StorageError("Connection pool not initialized", "POOL_NOT_READY")

        async with self._pool.acquire() as connection:
            try:
                yield connection
            except Exception as e:
                logger.error(f"Database operation failed: {e}")
                raise

    @asynccontextmanager
    async def get_transaction(self):
        """Get a database transaction"""
        async with self.get_connection() as conn:
            async with conn.transaction():
                yield conn

    async def execute_query(self, query: str, *args) -> Any:
        """Execute a query and return results"""
        async with self.get_connection() as conn:
            return await conn.fetch(query, *args)

    async def execute_command(self, command: str, *args) -> str:
        """Execute a command and return status"""
        async with self.get_connection() as conn:
            return await conn.execute(command, *args)

    async def execute_many(self, command: str, args_list) -> None:
        """Execute a command multiple times with different arguments"""
        async with self.get_connection() as conn:
            await conn.executemany(command, args_list)

    async def health_check(self) -> bool:
        """Check database health"""
        try:
            async with self.get_connection() as conn:
                result = await conn.fetchval("SELECT 1")
                if result == 1:
                    logger.debug("Database health check passed")
                    return True
                else:
                    logger.warning("Database health check returned unexpected result")
                    return False
        except Exception as e:
            logger.error(f"Database health check failed: {e}")
            return False

    async def get_pool_stats(self) -> Dict[str, Any]:
        """Get connection pool statistics"""
        if not self._pool:
            return {}

        return {
            'size': self._pool.get_size(),
            'min_size': self._pool.get_min_size(),
            'max_size': self._pool.get_max_size(),
            'idle_size': self._pool.get_idle_size(),
            'is_closing': self._pool.is_closing()
        }

    @property
    def is_initialized(self) -> bool:
        """Check if pool is initialized"""
        return self._is_initialized


# Global connection pool instance
db_pool = DatabaseConnectionPool()
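Editor's note: a usage sketch (not part of the commit) of the pooled-connection and transaction context managers above. It assumes database credentials are available through COBY's config module and that the market_data schema from schema.py has already been created; the demo metric values are placeholders.

# Usage sketch: run a read through the pool and a write inside a transaction.
import asyncio
from COBY.storage.connection_pool import db_pool


async def main() -> None:
    await db_pool.initialize()   # reads host/port/user/etc. from COBY's config

    # Plain read through the pool
    rows = await db_pool.execute_query("SELECT 1 AS ok")
    print(rows[0]['ok'])

    # Writes grouped in a single transaction via the async context manager
    async with db_pool.get_transaction() as conn:
        await conn.execute(
            "INSERT INTO market_data.system_metrics (metric_name, timestamp, value) VALUES ($1, NOW(), $2)",
            "demo_metric", 1.0
        )

    await db_pool.close()

asyncio.run(main())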
271  COBY/storage/migrations.py  Normal file
@@ -0,0 +1,271 @@
"""
Database migration system for schema updates.
"""

from typing import List, Dict, Any
from datetime import datetime
from ..utils.logging import get_logger
from ..utils.exceptions import StorageError
from .connection_pool import db_pool

logger = get_logger(__name__)


class Migration:
    """Base class for database migrations"""

    def __init__(self, version: str, description: str):
        self.version = version
        self.description = description

    async def up(self) -> None:
        """Apply the migration"""
        raise NotImplementedError

    async def down(self) -> None:
        """Rollback the migration"""
        raise NotImplementedError


class MigrationManager:
    """Manages database schema migrations"""

    def __init__(self):
        self.migrations: List[Migration] = []

    def register_migration(self, migration: Migration) -> None:
        """Register a migration"""
        self.migrations.append(migration)
        # Sort by version
        self.migrations.sort(key=lambda m: m.version)

    async def initialize_migration_table(self) -> None:
        """Create migration tracking table"""
        query = """
        CREATE TABLE IF NOT EXISTS market_data.schema_migrations (
            version VARCHAR(50) PRIMARY KEY,
            description TEXT NOT NULL,
            applied_at TIMESTAMPTZ DEFAULT NOW()
        );
        """

        await db_pool.execute_command(query)
        logger.debug("Migration table initialized")

    async def get_applied_migrations(self) -> List[str]:
        """Get list of applied migration versions"""
        try:
            query = "SELECT version FROM market_data.schema_migrations ORDER BY version"
            rows = await db_pool.execute_query(query)
            return [row['version'] for row in rows]
        except Exception:
            # Table might not exist yet
            return []

    async def apply_migration(self, migration: Migration) -> bool:
        """Apply a single migration"""
        try:
            logger.info(f"Applying migration {migration.version}: {migration.description}")

            async with db_pool.get_transaction() as conn:
                # Apply the migration
                await migration.up()

                # Record the migration
                await conn.execute(
                    "INSERT INTO market_data.schema_migrations (version, description) VALUES ($1, $2)",
                    migration.version,
                    migration.description
                )

            logger.info(f"Migration {migration.version} applied successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to apply migration {migration.version}: {e}")
            return False

    async def rollback_migration(self, migration: Migration) -> bool:
        """Rollback a single migration"""
        try:
            logger.info(f"Rolling back migration {migration.version}: {migration.description}")

            async with db_pool.get_transaction() as conn:
                # Rollback the migration
                await migration.down()

                # Remove the migration record
                await conn.execute(
                    "DELETE FROM market_data.schema_migrations WHERE version = $1",
                    migration.version
                )

            logger.info(f"Migration {migration.version} rolled back successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to rollback migration {migration.version}: {e}")
            return False

    async def migrate_up(self, target_version: str = None) -> bool:
        """Apply all pending migrations up to target version"""
        try:
            await self.initialize_migration_table()
            applied_migrations = await self.get_applied_migrations()

            pending_migrations = [
                m for m in self.migrations
                if m.version not in applied_migrations
            ]

            if target_version:
                pending_migrations = [
                    m for m in pending_migrations
                    if m.version <= target_version
                ]

            if not pending_migrations:
                logger.info("No pending migrations to apply")
                return True

            logger.info(f"Applying {len(pending_migrations)} pending migrations")

            for migration in pending_migrations:
                if not await self.apply_migration(migration):
                    return False

            logger.info("All migrations applied successfully")
            return True

        except Exception as e:
            logger.error(f"Migration failed: {e}")
            return False

    async def migrate_down(self, target_version: str) -> bool:
        """Rollback migrations down to target version"""
        try:
            applied_migrations = await self.get_applied_migrations()

            migrations_to_rollback = [
                m for m in reversed(self.migrations)
                if m.version in applied_migrations and m.version > target_version
            ]

            if not migrations_to_rollback:
                logger.info("No migrations to rollback")
                return True

            logger.info(f"Rolling back {len(migrations_to_rollback)} migrations")

            for migration in migrations_to_rollback:
                if not await self.rollback_migration(migration):
                    return False

            logger.info("All migrations rolled back successfully")
            return True

        except Exception as e:
            logger.error(f"Migration rollback failed: {e}")
            return False

    async def get_migration_status(self) -> Dict[str, Any]:
        """Get current migration status"""
        try:
            applied_migrations = await self.get_applied_migrations()

            status = {
                'total_migrations': len(self.migrations),
                'applied_migrations': len(applied_migrations),
                'pending_migrations': len(self.migrations) - len(applied_migrations),
                'current_version': applied_migrations[-1] if applied_migrations else None,
                'latest_version': self.migrations[-1].version if self.migrations else None,
                'migrations': []
            }

            for migration in self.migrations:
                status['migrations'].append({
                    'version': migration.version,
                    'description': migration.description,
                    'applied': migration.version in applied_migrations
                })

            return status

        except Exception as e:
            logger.error(f"Failed to get migration status: {e}")
            return {}


# Example migrations
class InitialSchemaMigration(Migration):
    """Initial schema creation migration"""

    def __init__(self):
        super().__init__("001", "Create initial schema and tables")

    async def up(self) -> None:
        """Create initial schema"""
        from .schema import DatabaseSchema

        queries = DatabaseSchema.get_all_creation_queries()
        for query in queries:
            await db_pool.execute_command(query)

    async def down(self) -> None:
        """Drop initial schema"""
        # Drop tables in reverse order
        tables = [
            'system_metrics',
            'exchange_status',
            'ohlcv_data',
            'heatmap_data',
            'trade_events',
            'order_book_snapshots'
        ]

        for table in tables:
            await db_pool.execute_command(f"DROP TABLE IF EXISTS market_data.{table} CASCADE")


class AddIndexesMigration(Migration):
    """Add performance indexes migration"""

    def __init__(self):
        super().__init__("002", "Add performance indexes")

    async def up(self) -> None:
        """Add indexes"""
        from .schema import DatabaseSchema

        queries = DatabaseSchema.get_index_creation_queries()
        for query in queries:
            await db_pool.execute_command(query)

    async def down(self) -> None:
        """Drop indexes"""
        indexes = [
            'idx_order_book_symbol_exchange',
            'idx_order_book_timestamp',
            'idx_trade_events_symbol_exchange',
            'idx_trade_events_timestamp',
            'idx_trade_events_price',
            'idx_heatmap_symbol_bucket',
            'idx_heatmap_timestamp',
            'idx_ohlcv_symbol_timeframe',
            'idx_ohlcv_timestamp',
            'idx_exchange_status_exchange',
            'idx_exchange_status_timestamp',
            'idx_system_metrics_name',
            'idx_system_metrics_timestamp'
        ]

        for index in indexes:
            await db_pool.execute_command(f"DROP INDEX IF EXISTS market_data.{index}")


# Global migration manager
migration_manager = MigrationManager()

# Register default migrations
migration_manager.register_migration(InitialSchemaMigration())
migration_manager.register_migration(AddIndexesMigration())
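Editor's note: a usage sketch (not part of the commit) of running the two registered migrations. It assumes a reachable TimescaleDB instance configured via COBY's config module and that the COBY package is importable; the printed summary keys come from get_migration_status() above.

# Usage sketch: apply the registered migrations and report status.
import asyncio
from COBY.storage.connection_pool import db_pool
from COBY.storage.migrations import migration_manager


async def main() -> None:
    await db_pool.initialize()
    if await migration_manager.migrate_up():   # applies 001 and 002 if not yet recorded
        status = await migration_manager.get_migration_status()
        print(f"schema at version {status['current_version']} "
              f"({status['applied_migrations']}/{status['total_migrations']} applied)")
    await db_pool.close()

asyncio.run(main())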
256
COBY/storage/schema.py
Normal file
256
COBY/storage/schema.py
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
"""
|
||||||
|
Database schema management for TimescaleDB.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseSchema:
|
||||||
|
"""Manages database schema creation and migrations"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_schema_creation_queries() -> List[str]:
|
||||||
|
"""Get list of queries to create the database schema"""
|
||||||
|
return [
|
||||||
|
# Create TimescaleDB extension
|
||||||
|
"CREATE EXTENSION IF NOT EXISTS timescaledb;",
|
||||||
|
|
||||||
|
# Create schema
|
||||||
|
"CREATE SCHEMA IF NOT EXISTS market_data;",
|
||||||
|
|
||||||
|
# Order book snapshots table
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS market_data.order_book_snapshots (
|
||||||
|
id BIGSERIAL,
|
||||||
|
symbol VARCHAR(20) NOT NULL,
|
||||||
|
exchange VARCHAR(20) NOT NULL,
|
||||||
|
timestamp TIMESTAMPTZ NOT NULL,
|
||||||
|
bids JSONB NOT NULL,
|
||||||
|
asks JSONB NOT NULL,
|
||||||
|
sequence_id BIGINT,
|
||||||
|
mid_price DECIMAL(20,8),
|
||||||
|
spread DECIMAL(20,8),
|
||||||
|
bid_volume DECIMAL(30,8),
|
||||||
|
ask_volume DECIMAL(30,8),
|
||||||
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
PRIMARY KEY (timestamp, symbol, exchange)
|
||||||
|
);
|
||||||
|
""",
|
||||||
|
|
||||||
|
# Trade events table
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS market_data.trade_events (
|
||||||
|
id BIGSERIAL,
|
||||||
|
symbol VARCHAR(20) NOT NULL,
|
||||||
|
exchange VARCHAR(20) NOT NULL,
|
||||||
|
timestamp TIMESTAMPTZ NOT NULL,
|
||||||
|
price DECIMAL(20,8) NOT NULL,
|
||||||
|
size DECIMAL(30,8) NOT NULL,
|
||||||
|
side VARCHAR(4) NOT NULL,
|
||||||
|
trade_id VARCHAR(100) NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
PRIMARY KEY (timestamp, symbol, exchange, trade_id)
|
||||||
|
);
|
||||||
|
""",
|
||||||
|
|
||||||
|
# Aggregated heatmap data table
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS market_data.heatmap_data (
|
||||||
|
symbol VARCHAR(20) NOT NULL,
|
||||||
|
timestamp TIMESTAMPTZ NOT NULL,
|
||||||
|
bucket_size DECIMAL(10,2) NOT NULL,
|
||||||
|
price_bucket DECIMAL(20,8) NOT NULL,
|
||||||
|
volume DECIMAL(30,8) NOT NULL,
|
||||||
|
side VARCHAR(3) NOT NULL,
|
||||||
|
exchange_count INTEGER NOT NULL,
|
||||||
|
exchanges JSONB,
|
||||||
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
PRIMARY KEY (timestamp, symbol, bucket_size, price_bucket, side)
|
||||||
|
);
|
||||||
|
""",
|
||||||
|
|
||||||
|
# OHLCV data table
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS market_data.ohlcv_data (
|
||||||
|
symbol VARCHAR(20) NOT NULL,
|
||||||
|
timestamp TIMESTAMPTZ NOT NULL,
|
||||||
|
timeframe VARCHAR(10) NOT NULL,
|
||||||
|
open_price DECIMAL(20,8) NOT NULL,
|
||||||
|
high_price DECIMAL(20,8) NOT NULL,
|
||||||
|
low_price DECIMAL(20,8) NOT NULL,
|
||||||
|
close_price DECIMAL(20,8) NOT NULL,
|
||||||
|
volume DECIMAL(30,8) NOT NULL,
|
||||||
|
trade_count INTEGER,
|
||||||
|
vwap DECIMAL(20,8),
|
||||||
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
PRIMARY KEY (timestamp, symbol, timeframe)
|
||||||
|
);
|
||||||
|
""",
|
||||||
|
|
||||||
|
            # Exchange status tracking table
            """
            CREATE TABLE IF NOT EXISTS market_data.exchange_status (
                exchange VARCHAR(20) NOT NULL,
                timestamp TIMESTAMPTZ NOT NULL,
                status VARCHAR(20) NOT NULL,
                last_message_time TIMESTAMPTZ,
                error_message TEXT,
                connection_count INTEGER DEFAULT 0,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                PRIMARY KEY (timestamp, exchange)
            );
            """,

            # System metrics table
            """
            CREATE TABLE IF NOT EXISTS market_data.system_metrics (
                metric_name VARCHAR(50) NOT NULL,
                timestamp TIMESTAMPTZ NOT NULL,
                value DECIMAL(20,8) NOT NULL,
                labels JSONB,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                PRIMARY KEY (timestamp, metric_name)
            );
            """
        ]

    @staticmethod
    def get_hypertable_creation_queries() -> List[str]:
        """Get queries to create hypertables"""
        return [
            "SELECT create_hypertable('market_data.order_book_snapshots', 'timestamp', if_not_exists => TRUE);",
            "SELECT create_hypertable('market_data.trade_events', 'timestamp', if_not_exists => TRUE);",
            "SELECT create_hypertable('market_data.heatmap_data', 'timestamp', if_not_exists => TRUE);",
            "SELECT create_hypertable('market_data.ohlcv_data', 'timestamp', if_not_exists => TRUE);",
            "SELECT create_hypertable('market_data.exchange_status', 'timestamp', if_not_exists => TRUE);",
            "SELECT create_hypertable('market_data.system_metrics', 'timestamp', if_not_exists => TRUE);"
        ]

    @staticmethod
    def get_index_creation_queries() -> List[str]:
        """Get queries to create indexes"""
        return [
            # Order book indexes
            "CREATE INDEX IF NOT EXISTS idx_order_book_symbol_exchange ON market_data.order_book_snapshots (symbol, exchange, timestamp DESC);",
            "CREATE INDEX IF NOT EXISTS idx_order_book_timestamp ON market_data.order_book_snapshots (timestamp DESC);",

            # Trade events indexes
            "CREATE INDEX IF NOT EXISTS idx_trade_events_symbol_exchange ON market_data.trade_events (symbol, exchange, timestamp DESC);",
            "CREATE INDEX IF NOT EXISTS idx_trade_events_timestamp ON market_data.trade_events (timestamp DESC);",
            "CREATE INDEX IF NOT EXISTS idx_trade_events_price ON market_data.trade_events (symbol, price, timestamp DESC);",

            # Heatmap data indexes
            "CREATE INDEX IF NOT EXISTS idx_heatmap_symbol_bucket ON market_data.heatmap_data (symbol, bucket_size, timestamp DESC);",
            "CREATE INDEX IF NOT EXISTS idx_heatmap_timestamp ON market_data.heatmap_data (timestamp DESC);",

            # OHLCV data indexes
            "CREATE INDEX IF NOT EXISTS idx_ohlcv_symbol_timeframe ON market_data.ohlcv_data (symbol, timeframe, timestamp DESC);",
            "CREATE INDEX IF NOT EXISTS idx_ohlcv_timestamp ON market_data.ohlcv_data (timestamp DESC);",

            # Exchange status indexes
            "CREATE INDEX IF NOT EXISTS idx_exchange_status_exchange ON market_data.exchange_status (exchange, timestamp DESC);",
            "CREATE INDEX IF NOT EXISTS idx_exchange_status_timestamp ON market_data.exchange_status (timestamp DESC);",

            # System metrics indexes
            "CREATE INDEX IF NOT EXISTS idx_system_metrics_name ON market_data.system_metrics (metric_name, timestamp DESC);",
            "CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON market_data.system_metrics (timestamp DESC);"
        ]

    @staticmethod
    def get_retention_policy_queries() -> List[str]:
        """Get queries to create retention policies"""
        return [
            "SELECT add_retention_policy('market_data.order_book_snapshots', INTERVAL '90 days', if_not_exists => TRUE);",
            "SELECT add_retention_policy('market_data.trade_events', INTERVAL '90 days', if_not_exists => TRUE);",
            "SELECT add_retention_policy('market_data.heatmap_data', INTERVAL '90 days', if_not_exists => TRUE);",
            "SELECT add_retention_policy('market_data.ohlcv_data', INTERVAL '365 days', if_not_exists => TRUE);",
            "SELECT add_retention_policy('market_data.exchange_status', INTERVAL '30 days', if_not_exists => TRUE);",
            "SELECT add_retention_policy('market_data.system_metrics', INTERVAL '30 days', if_not_exists => TRUE);"
        ]
    @staticmethod
    def get_continuous_aggregate_queries() -> List[str]:
        """Get queries to create continuous aggregates"""
        return [
            # Hourly OHLCV aggregate
            # NOTE: vwap below is a plain average of trade prices, not a true
            # volume-weighted average (sum(price * size) / sum(size)).
            """
            CREATE MATERIALIZED VIEW IF NOT EXISTS market_data.hourly_ohlcv
            WITH (timescaledb.continuous) AS
            SELECT
                symbol,
                exchange,
                time_bucket('1 hour', timestamp) AS hour,
                first(price, timestamp) AS open_price,
                max(price) AS high_price,
                min(price) AS low_price,
                last(price, timestamp) AS close_price,
                sum(size) AS volume,
                count(*) AS trade_count,
                avg(price) AS vwap
            FROM market_data.trade_events
            GROUP BY symbol, exchange, hour
            WITH NO DATA;
            """,

            # Add refresh policy for continuous aggregate
            """
            SELECT add_continuous_aggregate_policy('market_data.hourly_ohlcv',
                start_offset => INTERVAL '3 hours',
                end_offset => INTERVAL '1 hour',
                schedule_interval => INTERVAL '1 hour',
                if_not_exists => TRUE);
            """
        ]
    @staticmethod
    def get_view_creation_queries() -> List[str]:
        """Get queries to create views"""
        return [
            # Latest order books view
            """
            CREATE OR REPLACE VIEW market_data.latest_order_books AS
            SELECT DISTINCT ON (symbol, exchange)
                symbol,
                exchange,
                timestamp,
                bids,
                asks,
                mid_price,
                spread,
                bid_volume,
                ask_volume
            FROM market_data.order_book_snapshots
            ORDER BY symbol, exchange, timestamp DESC;
            """,

            # Latest heatmaps view
            """
            CREATE OR REPLACE VIEW market_data.latest_heatmaps AS
            SELECT DISTINCT ON (symbol, bucket_size, price_bucket, side)
                symbol,
                bucket_size,
                price_bucket,
                side,
                timestamp,
                volume,
                exchange_count,
                exchanges
            FROM market_data.heatmap_data
            ORDER BY symbol, bucket_size, price_bucket, side, timestamp DESC;
            """
        ]

    @staticmethod
    def get_all_creation_queries() -> List[str]:
        """Get all schema creation queries in order"""
        queries = []
        queries.extend(DatabaseSchema.get_schema_creation_queries())
        queries.extend(DatabaseSchema.get_hypertable_creation_queries())
        queries.extend(DatabaseSchema.get_index_creation_queries())
        queries.extend(DatabaseSchema.get_retention_policy_queries())
        queries.extend(DatabaseSchema.get_continuous_aggregate_queries())
        queries.extend(DatabaseSchema.get_view_creation_queries())
        return queries
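As a rough usage sketch (assuming asyncpg is the driver behind the connection pool, which the $1-style placeholders elsewhere in this commit suggest), the queries above could also be applied directly against a database; the DSN is a placeholder:

import asyncio
import asyncpg

from COBY.storage.schema import DatabaseSchema


async def apply_schema(dsn: str) -> None:
    # Apply every schema statement in order; objects that already exist are
    # tolerated, mirroring how TimescaleManager.setup_database_schema treats failures.
    conn = await asyncpg.connect(dsn)
    try:
        for query in DatabaseSchema.get_all_creation_queries():
            try:
                await conn.execute(query)
            except asyncpg.PostgresError as exc:
                print(f"Skipped: {exc}")
    finally:
        await conn.close()


if __name__ == "__main__":
    asyncio.run(apply_schema("postgresql://user:password@localhost:5432/market_data"))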
604
COBY/storage/timescale_manager.py
Normal file
@ -0,0 +1,604 @@
"""
TimescaleDB storage manager implementation.
"""

import json
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Any
from ..interfaces.storage_manager import StorageManager
from ..models.core import OrderBookSnapshot, TradeEvent, HeatmapData, SystemMetrics, PriceLevel
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import StorageError, ValidationError
from ..utils.timing import get_current_timestamp
from .connection_pool import db_pool
from .schema import DatabaseSchema

logger = get_logger(__name__)


class TimescaleManager(StorageManager):
    """TimescaleDB implementation of StorageManager interface"""

    def __init__(self):
        self._schema_initialized = False

    async def initialize(self) -> None:
        """Initialize the storage manager"""
        await db_pool.initialize()
        await self.setup_database_schema()
        logger.info("TimescaleDB storage manager initialized")

    async def close(self) -> None:
        """Close the storage manager"""
        await db_pool.close()
        logger.info("TimescaleDB storage manager closed")

    async def setup_database_schema(self) -> None:
        """Set up database schema and tables.

        Declared as a coroutine so initialize() can await it directly instead of
        juggling event loops.
        """
        if self._schema_initialized:
            return

        try:
            queries = DatabaseSchema.get_all_creation_queries()

            for query in queries:
                try:
                    await db_pool.execute_command(query)
                    logger.debug(f"Executed schema query: {query[:50]}...")
                except Exception as e:
                    # Log but continue - some queries might fail if objects already exist
                    logger.warning(f"Schema query failed (continuing): {e}")

            self._schema_initialized = True
            logger.info("Database schema setup completed")

        except Exception as e:
            logger.error(f"Failed to setup database schema: {e}")
            raise StorageError(f"Schema setup failed: {e}", "SCHEMA_SETUP_ERROR")
    async def store_orderbook(self, data: OrderBookSnapshot) -> bool:
        """Store order book snapshot to database"""
        try:
            set_correlation_id()

            # Convert price levels to JSON
            bids_json = json.dumps([
                {"price": float(level.price), "size": float(level.size), "count": level.count}
                for level in data.bids
            ])
            asks_json = json.dumps([
                {"price": float(level.price), "size": float(level.size), "count": level.count}
                for level in data.asks
            ])

            query = """
                INSERT INTO market_data.order_book_snapshots
                (symbol, exchange, timestamp, bids, asks, sequence_id, mid_price, spread, bid_volume, ask_volume)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
            """

            await db_pool.execute_command(
                query,
                data.symbol,
                data.exchange,
                data.timestamp,
                bids_json,
                asks_json,
                data.sequence_id,
                float(data.mid_price) if data.mid_price else None,
                float(data.spread) if data.spread else None,
                float(data.bid_volume),
                float(data.ask_volume)
            )

            logger.debug(f"Stored order book: {data.symbol}@{data.exchange}")
            return True

        except Exception as e:
            logger.error(f"Failed to store order book: {e}")
            return False

    async def store_trade(self, data: TradeEvent) -> bool:
        """Store trade event to database"""
        try:
            set_correlation_id()

            query = """
                INSERT INTO market_data.trade_events
                (symbol, exchange, timestamp, price, size, side, trade_id)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
            """

            await db_pool.execute_command(
                query,
                data.symbol,
                data.exchange,
                data.timestamp,
                float(data.price),
                float(data.size),
                data.side,
                data.trade_id
            )

            logger.debug(f"Stored trade: {data.symbol}@{data.exchange} - {data.trade_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to store trade: {e}")
            return False

    async def store_heatmap(self, data: HeatmapData) -> bool:
        """Store heatmap data to database"""
        try:
            set_correlation_id()

            # Store each heatmap point
            for point in data.data:
                query = """
                    INSERT INTO market_data.heatmap_data
                    (symbol, timestamp, bucket_size, price_bucket, volume, side, exchange_count, exchanges)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                    ON CONFLICT (timestamp, symbol, bucket_size, price_bucket, side)
                    DO UPDATE SET
                        volume = EXCLUDED.volume,
                        exchange_count = EXCLUDED.exchange_count,
                        exchanges = EXCLUDED.exchanges
                """

                await db_pool.execute_command(
                    query,
                    data.symbol,
                    data.timestamp,
                    float(data.bucket_size),
                    float(point.price),
                    float(point.volume),
                    point.side,
                    1,  # exchange_count - will be updated by aggregation
                    json.dumps([])  # exchanges - will be updated by aggregation
                )

            logger.debug(f"Stored heatmap: {data.symbol} with {len(data.data)} points")
            return True

        except Exception as e:
            logger.error(f"Failed to store heatmap: {e}")
            return False

    async def store_metrics(self, data: SystemMetrics) -> bool:
        """Store system metrics to database"""
        try:
            set_correlation_id()

            # Store multiple metrics
            metrics = [
                ('cpu_usage', data.cpu_usage),
                ('memory_usage', data.memory_usage),
                ('disk_usage', data.disk_usage),
                ('database_connections', data.database_connections),
                ('redis_connections', data.redis_connections),
                ('active_websockets', data.active_websockets),
                ('messages_per_second', data.messages_per_second),
                ('processing_latency', data.processing_latency)
            ]

            query = """
                INSERT INTO market_data.system_metrics
                (metric_name, timestamp, value, labels)
                VALUES ($1, $2, $3, $4)
            """

            for metric_name, value in metrics:
                await db_pool.execute_command(
                    query,
                    metric_name,
                    data.timestamp,
                    float(value),
                    json.dumps(data.network_io)
                )

            logger.debug("Stored system metrics")
            return True

        except Exception as e:
            logger.error(f"Failed to store metrics: {e}")
            return False

    async def get_historical_orderbooks(self, symbol: str, exchange: str,
                                        start: datetime, end: datetime,
                                        limit: Optional[int] = None) -> List[OrderBookSnapshot]:
        """Retrieve historical order book data"""
        try:
            query = """
                SELECT symbol, exchange, timestamp, bids, asks, sequence_id, mid_price, spread
                FROM market_data.order_book_snapshots
                WHERE symbol = $1 AND exchange = $2 AND timestamp >= $3 AND timestamp <= $4
                ORDER BY timestamp DESC
            """

            if limit:
                query += f" LIMIT {limit}"

            rows = await db_pool.execute_query(query, symbol, exchange, start, end)

            orderbooks = []
            for row in rows:
                # Parse JSON bid/ask data
                bids_data = json.loads(row['bids'])
                asks_data = json.loads(row['asks'])

                bids = [PriceLevel(price=b['price'], size=b['size'], count=b.get('count'))
                        for b in bids_data]
                asks = [PriceLevel(price=a['price'], size=a['size'], count=a.get('count'))
                        for a in asks_data]

                orderbook = OrderBookSnapshot(
                    symbol=row['symbol'],
                    exchange=row['exchange'],
                    timestamp=row['timestamp'],
                    bids=bids,
                    asks=asks,
                    sequence_id=row['sequence_id']
                )
                orderbooks.append(orderbook)

            logger.debug(f"Retrieved {len(orderbooks)} historical order books")
            return orderbooks

        except Exception as e:
            logger.error(f"Failed to get historical order books: {e}")
            return []

    async def get_historical_trades(self, symbol: str, exchange: str,
                                    start: datetime, end: datetime,
                                    limit: Optional[int] = None) -> List[TradeEvent]:
        """Retrieve historical trade data"""
        try:
            query = """
                SELECT symbol, exchange, timestamp, price, size, side, trade_id
                FROM market_data.trade_events
                WHERE symbol = $1 AND exchange = $2 AND timestamp >= $3 AND timestamp <= $4
                ORDER BY timestamp DESC
            """

            if limit:
                query += f" LIMIT {limit}"

            rows = await db_pool.execute_query(query, symbol, exchange, start, end)

            trades = []
            for row in rows:
                trade = TradeEvent(
                    symbol=row['symbol'],
                    exchange=row['exchange'],
                    timestamp=row['timestamp'],
                    price=float(row['price']),
                    size=float(row['size']),
                    side=row['side'],
                    trade_id=row['trade_id']
                )
                trades.append(trade)

            logger.debug(f"Retrieved {len(trades)} historical trades")
            return trades

        except Exception as e:
            logger.error(f"Failed to get historical trades: {e}")
            return []

    async def get_latest_orderbook(self, symbol: str, exchange: str) -> Optional[OrderBookSnapshot]:
        """Get latest order book snapshot"""
        try:
            query = """
                SELECT symbol, exchange, timestamp, bids, asks, sequence_id
                FROM market_data.order_book_snapshots
                WHERE symbol = $1 AND exchange = $2
                ORDER BY timestamp DESC
                LIMIT 1
            """

            rows = await db_pool.execute_query(query, symbol, exchange)

            if not rows:
                return None

            row = rows[0]
            bids_data = json.loads(row['bids'])
            asks_data = json.loads(row['asks'])

            bids = [PriceLevel(price=b['price'], size=b['size'], count=b.get('count'))
                    for b in bids_data]
            asks = [PriceLevel(price=a['price'], size=a['size'], count=a.get('count'))
                    for a in asks_data]

            return OrderBookSnapshot(
                symbol=row['symbol'],
                exchange=row['exchange'],
                timestamp=row['timestamp'],
                bids=bids,
                asks=asks,
                sequence_id=row['sequence_id']
            )

        except Exception as e:
            logger.error(f"Failed to get latest order book: {e}")
            return None

    async def get_latest_heatmap(self, symbol: str, bucket_size: float) -> Optional[HeatmapData]:
        """Get latest heatmap data"""
        try:
            query = """
                SELECT price_bucket, volume, side, timestamp
                FROM market_data.heatmap_data
                WHERE symbol = $1 AND bucket_size = $2
                AND timestamp = (
                    SELECT MAX(timestamp)
                    FROM market_data.heatmap_data
                    WHERE symbol = $1 AND bucket_size = $2
                )
                ORDER BY price_bucket
            """

            rows = await db_pool.execute_query(query, symbol, bucket_size)

            if not rows:
                return None

            from ..models.core import HeatmapPoint
            heatmap = HeatmapData(
                symbol=symbol,
                timestamp=rows[0]['timestamp'],
                bucket_size=bucket_size
            )

            # Calculate max volume for intensity
            max_volume = max(float(row['volume']) for row in rows)

            for row in rows:
                volume = float(row['volume'])
                intensity = volume / max_volume if max_volume > 0 else 0.0

                point = HeatmapPoint(
                    price=float(row['price_bucket']),
                    volume=volume,
                    intensity=intensity,
                    side=row['side']
                )
                heatmap.data.append(point)

            return heatmap

        except Exception as e:
            logger.error(f"Failed to get latest heatmap: {e}")
            return None

    async def get_ohlcv_data(self, symbol: str, exchange: str, timeframe: str,
                             start: datetime, end: datetime) -> List[Dict[str, Any]]:
        """Get OHLCV candlestick data"""
        try:
            query = """
                SELECT timestamp, open_price, high_price, low_price, close_price, volume, trade_count, vwap
                FROM market_data.ohlcv_data
                WHERE symbol = $1 AND exchange = $2 AND timeframe = $3
                AND timestamp >= $4 AND timestamp <= $5
                ORDER BY timestamp
            """

            rows = await db_pool.execute_query(query, symbol, exchange, timeframe, start, end)

            ohlcv_data = []
            for row in rows:
                ohlcv_data.append({
                    'timestamp': row['timestamp'],
                    'open': float(row['open_price']),
                    'high': float(row['high_price']),
                    'low': float(row['low_price']),
                    'close': float(row['close_price']),
                    'volume': float(row['volume']),
                    'trade_count': row['trade_count'],
                    'vwap': float(row['vwap']) if row['vwap'] else None
                })

            logger.debug(f"Retrieved {len(ohlcv_data)} OHLCV records")
            return ohlcv_data

        except Exception as e:
            logger.error(f"Failed to get OHLCV data: {e}")
            return []

    async def batch_store_orderbooks(self, data: List[OrderBookSnapshot]) -> int:
        """Store multiple order book snapshots in batch"""
        if not data:
            return 0

        try:
            set_correlation_id()

            # Prepare batch data
            batch_data = []
            for orderbook in data:
                bids_json = json.dumps([
                    {"price": float(level.price), "size": float(level.size), "count": level.count}
                    for level in orderbook.bids
                ])
                asks_json = json.dumps([
                    {"price": float(level.price), "size": float(level.size), "count": level.count}
                    for level in orderbook.asks
                ])

                batch_data.append((
                    orderbook.symbol,
                    orderbook.exchange,
                    orderbook.timestamp,
                    bids_json,
                    asks_json,
                    orderbook.sequence_id,
                    float(orderbook.mid_price) if orderbook.mid_price else None,
                    float(orderbook.spread) if orderbook.spread else None,
                    float(orderbook.bid_volume),
                    float(orderbook.ask_volume)
                ))

            query = """
                INSERT INTO market_data.order_book_snapshots
                (symbol, exchange, timestamp, bids, asks, sequence_id, mid_price, spread, bid_volume, ask_volume)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
            """

            await db_pool.execute_many(query, batch_data)

            logger.debug(f"Batch stored {len(data)} order books")
            return len(data)

        except Exception as e:
            logger.error(f"Failed to batch store order books: {e}")
            return 0

    async def batch_store_trades(self, data: List[TradeEvent]) -> int:
        """Store multiple trade events in batch"""
        if not data:
            return 0

        try:
            set_correlation_id()

            # Prepare batch data
            batch_data = [
                (trade.symbol, trade.exchange, trade.timestamp, float(trade.price),
                 float(trade.size), trade.side, trade.trade_id)
                for trade in data
            ]

            query = """
                INSERT INTO market_data.trade_events
                (symbol, exchange, timestamp, price, size, side, trade_id)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
            """

            await db_pool.execute_many(query, batch_data)

            logger.debug(f"Batch stored {len(data)} trades")
            return len(data)

        except Exception as e:
            logger.error(f"Failed to batch store trades: {e}")
            return 0
    async def cleanup_old_data(self, retention_days: int) -> int:
        """Clean up old data based on retention policy"""
        try:
            # Subtract the retention window as a timedelta; naive day arithmetic
            # breaks at month boundaries.
            cutoff_time = get_current_timestamp() - timedelta(days=retention_days)

            tables = [
                'order_book_snapshots',
                'trade_events',
                'heatmap_data',
                'exchange_status',
                'system_metrics'
            ]

            total_deleted = 0
            for table in tables:
                query = f"""
                    DELETE FROM market_data.{table}
                    WHERE timestamp < $1
                """

                result = await db_pool.execute_command(query, cutoff_time)
                # Extract number from result like "DELETE 1234"
                deleted = int(result.split()[-1]) if result.split()[-1].isdigit() else 0
                total_deleted += deleted

                logger.debug(f"Cleaned up {deleted} records from {table}")

            logger.info(f"Cleaned up {total_deleted} total records older than {retention_days} days")
            return total_deleted

        except Exception as e:
            logger.error(f"Failed to cleanup old data: {e}")
            return 0
    async def get_storage_stats(self) -> Dict[str, Any]:
        """Get storage statistics"""
        try:
            stats = {}

            # Table sizes
            size_query = """
                SELECT
                    schemaname,
                    tablename,
                    pg_size_pretty(pg_total_relation_size(schemaname||'.'||tablename)) as size,
                    pg_total_relation_size(schemaname||'.'||tablename) as size_bytes
                FROM pg_tables
                WHERE schemaname = 'market_data'
                ORDER BY size_bytes DESC
            """

            size_rows = await db_pool.execute_query(size_query)
            stats['table_sizes'] = [
                {
                    'table': row['tablename'],
                    'size': row['size'],
                    'size_bytes': row['size_bytes']
                }
                for row in size_rows
            ]

            # Record counts
            tables = ['order_book_snapshots', 'trade_events', 'heatmap_data',
                      'ohlcv_data', 'exchange_status', 'system_metrics']

            record_counts = {}
            for table in tables:
                count_query = f"SELECT COUNT(*) as count FROM market_data.{table}"
                count_rows = await db_pool.execute_query(count_query)
                record_counts[table] = count_rows[0]['count'] if count_rows else 0

            stats['record_counts'] = record_counts

            # Connection pool stats
            stats['connection_pool'] = await db_pool.get_pool_stats()

            return stats

        except Exception as e:
            logger.error(f"Failed to get storage stats: {e}")
            return {}

    async def health_check(self) -> bool:
        """Check storage system health"""
        try:
            # Check database connection
            if not await db_pool.health_check():
                return False

            # Check if tables exist
            query = """
                SELECT COUNT(*) as count
                FROM information_schema.tables
                WHERE table_schema = 'market_data'
            """

            rows = await db_pool.execute_query(query)
            table_count = rows[0]['count'] if rows else 0

            if table_count < 6:  # We expect 6 main tables
                logger.warning(f"Expected 6 tables, found {table_count}")
                return False

            logger.debug("Storage health check passed")
            return True

        except Exception as e:
            logger.error(f"Storage health check failed: {e}")
            return False
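A minimal usage sketch of the manager above (the symbol, prices, and sizes are arbitrary; it assumes the configured database is reachable):

import asyncio
from datetime import datetime, timezone

from COBY.storage.timescale_manager import TimescaleManager
from COBY.models.core import OrderBookSnapshot, PriceLevel


async def main() -> None:
    manager = TimescaleManager()
    await manager.initialize()  # opens the pool and applies the schema

    snapshot = OrderBookSnapshot(
        symbol="BTCUSDT",
        exchange="binance",
        timestamp=datetime.now(timezone.utc),
        bids=[PriceLevel(price=50000.0, size=1.0)],
        asks=[PriceLevel(price=50001.0, size=1.0)],
        sequence_id=1,
    )
    stored = await manager.store_orderbook(snapshot)
    latest = await manager.get_latest_orderbook("BTCUSDT", "binance")
    print(f"stored={stored}, latest timestamp={latest.timestamp if latest else None}")

    await manager.close()


asyncio.run(main())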
274
COBY/test_integration.py
Normal file
@ -0,0 +1,274 @@
#!/usr/bin/env python3
"""
Integration test script for COBY system components.
Run this to test the TimescaleDB integration and basic functionality.
"""

import asyncio
import sys
from datetime import datetime, timezone
from pathlib import Path

# Add the repository root to the path so COBY resolves as a package
# (its modules use relative imports such as `from ..utils.logging import ...`)
sys.path.insert(0, str(Path(__file__).parent.parent))

from COBY.config import config
from COBY.storage.timescale_manager import TimescaleManager
from COBY.models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from COBY.utils.logging import setup_logging, get_logger

# Setup logging
setup_logging(level='INFO', console_output=True)
logger = get_logger(__name__)
async def test_database_connection():
    """Test basic database connectivity"""
    logger.info("🔌 Testing database connection...")

    try:
        manager = TimescaleManager()
        await manager.initialize()

        # Test health check
        is_healthy = await manager.health_check()
        if is_healthy:
            logger.info("✅ Database connection: HEALTHY")
        else:
            logger.error("❌ Database connection: UNHEALTHY")
            return False

        # Test storage stats
        stats = await manager.get_storage_stats()
        logger.info(f"📊 Found {len(stats.get('table_sizes', []))} tables")

        for table_info in stats.get('table_sizes', []):
            logger.info(f"  📋 {table_info['table']}: {table_info['size']}")

        await manager.close()
        return True

    except Exception as e:
        logger.error(f"❌ Database test failed: {e}")
        return False


async def test_data_storage():
    """Test storing and retrieving data"""
    logger.info("💾 Testing data storage operations...")

    try:
        manager = TimescaleManager()
        await manager.initialize()

        # Create test order book
        test_orderbook = OrderBookSnapshot(
            symbol="BTCUSDT",
            exchange="test_exchange",
            timestamp=datetime.now(timezone.utc),
            bids=[
                PriceLevel(price=50000.0, size=1.5, count=3),
                PriceLevel(price=49999.0, size=2.0, count=5)
            ],
            asks=[
                PriceLevel(price=50001.0, size=1.0, count=2),
                PriceLevel(price=50002.0, size=1.5, count=4)
            ],
            sequence_id=12345
        )

        # Test storing order book
        result = await manager.store_orderbook(test_orderbook)
        if result:
            logger.info("✅ Order book storage: SUCCESS")
        else:
            logger.error("❌ Order book storage: FAILED")
            return False

        # Test retrieving order book
        retrieved = await manager.get_latest_orderbook("BTCUSDT", "test_exchange")
        if retrieved:
            logger.info(f"✅ Order book retrieval: SUCCESS (mid_price: {retrieved.mid_price})")
        else:
            logger.error("❌ Order book retrieval: FAILED")
            return False

        # Create test trade
        test_trade = TradeEvent(
            symbol="BTCUSDT",
            exchange="test_exchange",
            timestamp=datetime.now(timezone.utc),
            price=50000.5,
            size=0.1,
            side="buy",
            trade_id="test_trade_123"
        )

        # Test storing trade
        result = await manager.store_trade(test_trade)
        if result:
            logger.info("✅ Trade storage: SUCCESS")
        else:
            logger.error("❌ Trade storage: FAILED")
            return False

        await manager.close()
        return True

    except Exception as e:
        logger.error(f"❌ Data storage test failed: {e}")
        return False


async def test_batch_operations():
    """Test batch storage operations"""
    logger.info("📦 Testing batch operations...")

    try:
        manager = TimescaleManager()
        await manager.initialize()

        # Create batch of order books
        orderbooks = []
        for i in range(5):
            orderbook = OrderBookSnapshot(
                symbol="ETHUSDT",
                exchange="test_exchange",
                timestamp=datetime.now(timezone.utc),
                bids=[PriceLevel(price=3000.0 + i, size=1.0)],
                asks=[PriceLevel(price=3001.0 + i, size=1.0)],
                sequence_id=i
            )
            orderbooks.append(orderbook)

        # Test batch storage
        result = await manager.batch_store_orderbooks(orderbooks)
        if result == 5:
            logger.info(f"✅ Batch order book storage: SUCCESS ({result} records)")
        else:
            logger.error(f"❌ Batch order book storage: PARTIAL ({result}/5 records)")
            return False

        # Create batch of trades
        trades = []
        for i in range(10):
            trade = TradeEvent(
                symbol="ETHUSDT",
                exchange="test_exchange",
                timestamp=datetime.now(timezone.utc),
                price=3000.0 + (i * 0.1),
                size=0.05,
                side="buy" if i % 2 == 0 else "sell",
                trade_id=f"batch_trade_{i}"
            )
            trades.append(trade)

        # Test batch trade storage
        result = await manager.batch_store_trades(trades)
        if result == 10:
            logger.info(f"✅ Batch trade storage: SUCCESS ({result} records)")
        else:
            logger.error(f"❌ Batch trade storage: PARTIAL ({result}/10 records)")
            return False

        await manager.close()
        return True

    except Exception as e:
        logger.error(f"❌ Batch operations test failed: {e}")
        return False


async def test_configuration():
    """Test configuration system"""
    logger.info("⚙️ Testing configuration system...")

    try:
        # Test database configuration
        db_url = config.get_database_url()
        logger.info(f"✅ Database URL: {db_url.replace(config.database.password, '***')}")

        # Test Redis configuration
        redis_url = config.get_redis_url()
        logger.info(f"✅ Redis URL: {redis_url.replace(config.redis.password, '***')}")

        # Test bucket sizes
        btc_bucket = config.get_bucket_size('BTCUSDT')
        eth_bucket = config.get_bucket_size('ETHUSDT')
        logger.info(f"✅ Bucket sizes: BTC=${btc_bucket}, ETH=${eth_bucket}")

        # Test configuration dict
        config_dict = config.to_dict()
        logger.info(f"✅ Configuration loaded: {len(config_dict)} sections")

        return True

    except Exception as e:
        logger.error(f"❌ Configuration test failed: {e}")
        return False


async def run_all_tests():
    """Run all integration tests"""
    logger.info("🚀 Starting COBY Integration Tests")
    logger.info("=" * 50)

    tests = [
        ("Configuration", test_configuration),
        ("Database Connection", test_database_connection),
        ("Data Storage", test_data_storage),
        ("Batch Operations", test_batch_operations)
    ]

    results = []

    for test_name, test_func in tests:
        logger.info(f"\n🧪 Running {test_name} test...")
        try:
            result = await test_func()
            results.append((test_name, result))
            if result:
                logger.info(f"✅ {test_name}: PASSED")
            else:
                logger.error(f"❌ {test_name}: FAILED")
        except Exception as e:
            logger.error(f"❌ {test_name}: ERROR - {e}")
            results.append((test_name, False))

    # Summary
    logger.info("\n" + "=" * 50)
    logger.info("📋 TEST SUMMARY")
    logger.info("=" * 50)

    passed = sum(1 for _, result in results if result)
    total = len(results)

    for test_name, result in results:
        status = "✅ PASSED" if result else "❌ FAILED"
        logger.info(f"{test_name:20} {status}")

    logger.info(f"\nOverall: {passed}/{total} tests passed")

    if passed == total:
        logger.info("🎉 All tests passed! System is ready.")
        return True
    else:
        logger.error("⚠️ Some tests failed. Check configuration and database connection.")
        return False


if __name__ == "__main__":
    print("COBY Integration Test Suite")
    print("=" * 30)

    # Run tests
    success = asyncio.run(run_all_tests())

    if success:
        print("\n🎉 Integration tests completed successfully!")
        print("The system is ready for the next development phase.")
        sys.exit(0)
    else:
        print("\n❌ Integration tests failed!")
        print("Please check the logs and fix any issues before proceeding.")
        sys.exit(1)
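The script exits with code 0 when every test passes and 1 otherwise, so it can gate a CI step. Assuming the database and Redis settings in the config module point at reachable services, it is run from the repository root with `python COBY/test_integration.py`.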
3
COBY/tests/__init__.py
Normal file
@ -0,0 +1,3 @@
"""
Test suite for the COBY system.
"""
192
COBY/tests/test_timescale_manager.py
Normal file
@ -0,0 +1,192 @@
"""
Tests for TimescaleDB storage manager.
"""

import pytest
import pytest_asyncio
import asyncio
from datetime import datetime, timezone
from ..storage.timescale_manager import TimescaleManager
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..config import config


# Async fixtures need the pytest-asyncio decorator (unless asyncio_mode=auto is
# configured); a plain @pytest.fixture would hand the tests an un-awaited async generator.
@pytest_asyncio.fixture
async def storage_manager():
    """Create and initialize storage manager for testing"""
    manager = TimescaleManager()
    await manager.initialize()
    yield manager
    await manager.close()
@pytest.fixture
def sample_orderbook():
    """Create sample order book for testing"""
    return OrderBookSnapshot(
        symbol="BTCUSDT",
        exchange="binance",
        timestamp=datetime.now(timezone.utc),
        bids=[
            PriceLevel(price=50000.0, size=1.5, count=3),
            PriceLevel(price=49999.0, size=2.0, count=5)
        ],
        asks=[
            PriceLevel(price=50001.0, size=1.0, count=2),
            PriceLevel(price=50002.0, size=1.5, count=4)
        ],
        sequence_id=12345
    )


@pytest.fixture
def sample_trade():
    """Create sample trade for testing"""
    return TradeEvent(
        symbol="BTCUSDT",
        exchange="binance",
        timestamp=datetime.now(timezone.utc),
        price=50000.5,
        size=0.1,
        side="buy",
        trade_id="test_trade_123"
    )


class TestTimescaleManager:
    """Test cases for TimescaleManager"""

    @pytest.mark.asyncio
    async def test_health_check(self, storage_manager):
        """Test storage health check"""
        is_healthy = await storage_manager.health_check()
        assert is_healthy is True

    @pytest.mark.asyncio
    async def test_store_orderbook(self, storage_manager, sample_orderbook):
        """Test storing order book snapshot"""
        result = await storage_manager.store_orderbook(sample_orderbook)
        assert result is True

    @pytest.mark.asyncio
    async def test_store_trade(self, storage_manager, sample_trade):
        """Test storing trade event"""
        result = await storage_manager.store_trade(sample_trade)
        assert result is True

    @pytest.mark.asyncio
    async def test_get_latest_orderbook(self, storage_manager, sample_orderbook):
        """Test retrieving latest order book"""
        # Store the order book first
        await storage_manager.store_orderbook(sample_orderbook)

        # Retrieve it
        retrieved = await storage_manager.get_latest_orderbook(
            sample_orderbook.symbol,
            sample_orderbook.exchange
        )

        assert retrieved is not None
        assert retrieved.symbol == sample_orderbook.symbol
        assert retrieved.exchange == sample_orderbook.exchange
        assert len(retrieved.bids) == len(sample_orderbook.bids)
        assert len(retrieved.asks) == len(sample_orderbook.asks)

    @pytest.mark.asyncio
    async def test_batch_store_orderbooks(self, storage_manager):
        """Test batch storing order books"""
        orderbooks = []
        for i in range(5):
            orderbook = OrderBookSnapshot(
                symbol="ETHUSDT",
                exchange="binance",
                timestamp=datetime.now(timezone.utc),
                bids=[PriceLevel(price=3000.0 + i, size=1.0)],
                asks=[PriceLevel(price=3001.0 + i, size=1.0)],
                sequence_id=i
            )
            orderbooks.append(orderbook)

        result = await storage_manager.batch_store_orderbooks(orderbooks)
        assert result == 5

    @pytest.mark.asyncio
    async def test_batch_store_trades(self, storage_manager):
        """Test batch storing trades"""
        trades = []
        for i in range(5):
            trade = TradeEvent(
                symbol="ETHUSDT",
                exchange="binance",
                timestamp=datetime.now(timezone.utc),
                price=3000.0 + i,
                size=0.1,
                side="buy" if i % 2 == 0 else "sell",
                trade_id=f"test_trade_{i}"
            )
            trades.append(trade)

        result = await storage_manager.batch_store_trades(trades)
        assert result == 5

    @pytest.mark.asyncio
    async def test_get_storage_stats(self, storage_manager):
        """Test getting storage statistics"""
        stats = await storage_manager.get_storage_stats()

        assert isinstance(stats, dict)
        assert 'table_sizes' in stats
        assert 'record_counts' in stats
        assert 'connection_pool' in stats

    @pytest.mark.asyncio
    async def test_historical_data_retrieval(self, storage_manager, sample_orderbook, sample_trade):
        """Test retrieving historical data"""
        # Store some data first
        await storage_manager.store_orderbook(sample_orderbook)
        await storage_manager.store_trade(sample_trade)

        # Define time range
        start_time = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
        end_time = datetime.now(timezone.utc).replace(hour=23, minute=59, second=59, microsecond=999999)

        # Retrieve historical order books
        orderbooks = await storage_manager.get_historical_orderbooks(
            sample_orderbook.symbol,
            sample_orderbook.exchange,
            start_time,
            end_time,
            limit=10
        )

        assert isinstance(orderbooks, list)

        # Retrieve historical trades
        trades = await storage_manager.get_historical_trades(
            sample_trade.symbol,
            sample_trade.exchange,
            start_time,
            end_time,
            limit=10
        )

        assert isinstance(trades, list)


if __name__ == "__main__":
    # Run a simple test
    async def simple_test():
        manager = TimescaleManager()
        await manager.initialize()

        # Test health check
        is_healthy = await manager.health_check()
        print(f"Health check: {'PASSED' if is_healthy else 'FAILED'}")

        # Test storage stats
        stats = await manager.get_storage_stats()
        print(f"Storage stats: {len(stats)} categories")

        await manager.close()
        print("Simple test completed")

    asyncio.run(simple_test())
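Assuming pytest and pytest-asyncio are installed and a TimescaleDB instance is reachable with the configured credentials, the suite runs from the repository root with `pytest COBY/tests/test_timescale_manager.py -v`; the `__main__` block above offers a quicker connectivity check without pytest.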