""" Unified Database Manager for TimescaleDB. Handles connection pooling, queries, and data persistence. """ import asyncio import asyncpg import logging import time from datetime import datetime, timedelta, timezone from typing import Dict, List, Optional, Tuple, Any from pathlib import Path import pandas as pd from .config import get_config from .unified_data_models import OHLCVCandle, OrderBookDataFrame, TradeEvent logger = logging.getLogger(__name__) class DatabaseConnectionManager: """ Manages database connection pool with health monitoring. Provides connection pooling and basic query methods. """ def __init__(self, config: Optional[Dict] = None): """ Initialize database connection manager. Args: config: Database configuration dictionary """ self.config = config or get_config() self.pool: Optional[asyncpg.Pool] = None self._connection_string = None self._health_check_task = None self._is_healthy = False # Connection statistics self.total_queries = 0 self.failed_queries = 0 self.total_query_time = 0.0 logger.info("DatabaseConnectionManager initialized") async def initialize(self) -> bool: """ Initialize connection pool. Returns: True if successful, False otherwise """ try: # Get database configuration db_config = self.config.get('database', {}) host = db_config.get('host', 'localhost') port = db_config.get('port', 5432) database = db_config.get('name', 'trading_data') user = db_config.get('user', 'postgres') password = db_config.get('password', 'postgres') pool_size = db_config.get('pool_size', 20) logger.info(f"Connecting to database: {host}:{port}/{database}") # Create connection pool self.pool = await asyncpg.create_pool( host=host, port=port, database=database, user=user, password=password, min_size=5, max_size=pool_size, command_timeout=60, server_settings={ 'jit': 'off', 'timezone': 'UTC', 'statement_timeout': '30s' } ) # Test connection async with self.pool.acquire() as conn: await conn.execute('SELECT 1') self._is_healthy = True logger.info("Database connection pool initialized successfully") # Start health monitoring self._health_check_task = asyncio.create_task(self._health_monitor()) return True except Exception as e: logger.error(f"Failed to initialize database connection: {e}") self._is_healthy = False return False async def close(self): """Close connection pool.""" try: if self._health_check_task: self._health_check_task.cancel() try: await self._health_check_task except asyncio.CancelledError: pass if self.pool: await self.pool.close() logger.info("Database connection pool closed") except Exception as e: logger.error(f"Error closing database connection: {e}") async def _health_monitor(self): """Background task to monitor connection health.""" while True: try: await asyncio.sleep(30) # Check every 30 seconds if self.pool: async with self.pool.acquire() as conn: await conn.execute('SELECT 1') self._is_healthy = True except asyncio.CancelledError: break except Exception as e: logger.error(f"Health check failed: {e}") self._is_healthy = False def is_healthy(self) -> bool: """Check if connection pool is healthy.""" return self._is_healthy and self.pool is not None async def execute(self, query: str, *args) -> str: """ Execute a query. 
    async def execute(self, query: str, *args) -> str:
        """
        Execute a query.

        Args:
            query: SQL query
            *args: Query parameters

        Returns:
            Query result status
        """
        start_time = time.time()
        try:
            async with self.pool.acquire() as conn:
                result = await conn.execute(query, *args)
                self.total_queries += 1
                self.total_query_time += (time.time() - start_time)
                return result
        except Exception as e:
            self.failed_queries += 1
            logger.error(f"Query execution failed: {e}")
            raise

    async def fetch(self, query: str, *args) -> List[asyncpg.Record]:
        """
        Fetch multiple rows.

        Args:
            query: SQL query
            *args: Query parameters

        Returns:
            List of records
        """
        start_time = time.time()
        try:
            async with self.pool.acquire() as conn:
                result = await conn.fetch(query, *args)
                self.total_queries += 1
                self.total_query_time += (time.time() - start_time)
                return result
        except Exception as e:
            self.failed_queries += 1
            logger.error(f"Query fetch failed: {e}")
            raise

    async def fetchrow(self, query: str, *args) -> Optional[asyncpg.Record]:
        """
        Fetch a single row.

        Args:
            query: SQL query
            *args: Query parameters

        Returns:
            Single record or None
        """
        start_time = time.time()
        try:
            async with self.pool.acquire() as conn:
                result = await conn.fetchrow(query, *args)
                self.total_queries += 1
                self.total_query_time += (time.time() - start_time)
                return result
        except Exception as e:
            self.failed_queries += 1
            logger.error(f"Query fetchrow failed: {e}")
            raise

    async def fetchval(self, query: str, *args) -> Any:
        """
        Fetch a single value.

        Args:
            query: SQL query
            *args: Query parameters

        Returns:
            Single value
        """
        start_time = time.time()
        try:
            async with self.pool.acquire() as conn:
                result = await conn.fetchval(query, *args)
                self.total_queries += 1
                self.total_query_time += (time.time() - start_time)
                return result
        except Exception as e:
            self.failed_queries += 1
            logger.error(f"Query fetchval failed: {e}")
            raise

    async def executemany(self, query: str, args_list: List[Tuple]) -> None:
        """
        Execute query with multiple parameter sets (batch insert).

        Args:
            query: SQL query
            args_list: List of parameter tuples
        """
        start_time = time.time()
        try:
            async with self.pool.acquire() as conn:
                await conn.executemany(query, args_list)
                self.total_queries += 1
                self.total_query_time += (time.time() - start_time)
        except Exception as e:
            self.failed_queries += 1
            logger.error(f"Query executemany failed: {e}")
            raise

    def get_stats(self) -> Dict[str, Any]:
        """Get connection pool statistics."""
        avg_query_time = (self.total_query_time / self.total_queries) if self.total_queries > 0 else 0

        stats = {
            'is_healthy': self._is_healthy,
            'total_queries': self.total_queries,
            'failed_queries': self.failed_queries,
            'success_rate': ((self.total_queries - self.failed_queries) / self.total_queries * 100) if self.total_queries > 0 else 0,
            'avg_query_time_ms': round(avg_query_time * 1000, 2),
            'total_query_time_seconds': round(self.total_query_time, 2)
        }

        if self.pool:
            stats.update({
                'pool_size': self.pool.get_size(),
                'pool_max_size': self.pool.get_max_size(),
                'pool_min_size': self.pool.get_min_size()
            })

        return stats
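
# Illustrative usage sketch. The helper below is an example only and is not
# called anywhere in this module; it assumes the `ohlcv_data` hypertable with
# the columns queried by UnifiedDatabaseQueryManager further down, and the
# sample candle values are made up.
async def _example_bulk_insert_ohlcv(db: DatabaseConnectionManager) -> None:
    """Sketch: batch-insert OHLCV candles with executemany()."""
    candles = [
        (datetime.now(timezone.utc), 'BTC/USDT', '1m',
         50000.0, 50100.0, 49950.0, 50050.0, 12.5, 340),
    ]
    await db.executemany(
        """
        INSERT INTO ohlcv_data
            (timestamp, symbol, timeframe, open_price, high_price,
             low_price, close_price, volume, trade_count)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
        """,
        candles,
    )
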
class UnifiedDatabaseQueryManager:
    """
    Manages database queries for unified storage system.

    Provides high-level query methods for OHLCV, order book, and trade data.
    """

    def __init__(self, connection_manager: DatabaseConnectionManager):
        """
        Initialize query manager.

        Args:
            connection_manager: Database connection manager
        """
        self.db = connection_manager
        logger.info("UnifiedDatabaseQueryManager initialized")

    async def query_ohlcv_data(
        self,
        symbol: str,
        timeframe: str,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        limit: int = 1000
    ) -> pd.DataFrame:
        """
        Query OHLCV data for a single timeframe.

        Args:
            symbol: Trading symbol
            timeframe: Timeframe (1s, 1m, etc.)
            start_time: Start timestamp (None = no limit)
            end_time: End timestamp (None = now)
            limit: Maximum number of candles

        Returns:
            DataFrame with OHLCV data
        """
        try:
            # Build query
            query = """
                SELECT
                    timestamp, symbol, timeframe,
                    open_price, high_price, low_price, close_price,
                    volume, trade_count,
                    rsi_14, macd, macd_signal, macd_histogram,
                    bb_upper, bb_middle, bb_lower,
                    ema_12, ema_26, sma_20
                FROM ohlcv_data
                WHERE symbol = $1 AND timeframe = $2
            """
            params = [symbol, timeframe]
            param_idx = 3

            if start_time:
                query += f" AND timestamp >= ${param_idx}"
                params.append(start_time)
                param_idx += 1

            if end_time:
                query += f" AND timestamp <= ${param_idx}"
                params.append(end_time)
                param_idx += 1

            query += f" ORDER BY timestamp DESC LIMIT ${param_idx}"
            params.append(limit)

            # Execute query
            rows = await self.db.fetch(query, *params)

            if not rows:
                return pd.DataFrame()

            # Convert to DataFrame
            df = pd.DataFrame([dict(row) for row in rows])
            df = df.sort_values('timestamp').reset_index(drop=True)

            logger.debug(f"Queried {len(df)} OHLCV rows for {symbol} {timeframe}")
            return df

        except Exception as e:
            logger.error(f"Error querying OHLCV data: {e}")
            return pd.DataFrame()

    async def query_multi_timeframe_ohlcv(
        self,
        symbol: str,
        timeframes: List[str],
        timestamp: Optional[datetime] = None,
        limit: int = 100
    ) -> Dict[str, pd.DataFrame]:
        """
        Query aligned multi-timeframe OHLCV data.

        Args:
            symbol: Trading symbol
            timeframes: List of timeframes
            timestamp: Target timestamp (None = latest)
            limit: Number of candles per timeframe

        Returns:
            Dictionary mapping timeframe to DataFrame
        """
        try:
            result = {}

            # Query each timeframe
            for timeframe in timeframes:
                if timestamp:
                    # Query around specific timestamp
                    start_time = timestamp - timedelta(hours=24)  # Get 24 hours of context
                    end_time = timestamp
                else:
                    # Query latest data
                    start_time = None
                    end_time = None

                df = await self.query_ohlcv_data(
                    symbol=symbol,
                    timeframe=timeframe,
                    start_time=start_time,
                    end_time=end_time,
                    limit=limit
                )
                result[timeframe] = df

            logger.debug(f"Queried multi-timeframe data for {symbol}: {list(result.keys())}")
            return result

        except Exception as e:
            logger.error(f"Error querying multi-timeframe OHLCV: {e}")
            return {}
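
    # Usage sketch (comment only; symbol and timeframe values are examples):
    #
    #     frames = await query_mgr.query_multi_timeframe_ohlcv(
    #         symbol='BTC/USDT',
    #         timeframes=['1s', '1m', '5m'],
    #         limit=100,
    #     )
    #     for tf, df in frames.items():
    #         if not df.empty:
    #             print(tf, len(df), df['close_price'].iloc[-1])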
    async def query_orderbook_snapshots(
        self,
        symbol: str,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        limit: int = 100
    ) -> List[Dict]:
        """
        Query order book snapshots.

        Args:
            symbol: Trading symbol
            start_time: Start timestamp
            end_time: End timestamp
            limit: Maximum number of snapshots

        Returns:
            List of order book snapshot dictionaries
        """
        try:
            query = """
                SELECT
                    timestamp, symbol, exchange,
                    bids, asks,
                    mid_price, spread, bid_volume, ask_volume,
                    sequence_id
                FROM order_book_snapshots
                WHERE symbol = $1
            """
            params = [symbol]
            param_idx = 2

            if start_time:
                query += f" AND timestamp >= ${param_idx}"
                params.append(start_time)
                param_idx += 1

            if end_time:
                query += f" AND timestamp <= ${param_idx}"
                params.append(end_time)
                param_idx += 1

            query += f" ORDER BY timestamp DESC LIMIT ${param_idx}"
            params.append(limit)

            rows = await self.db.fetch(query, *params)

            result = [dict(row) for row in rows]
            result.reverse()  # Oldest first

            logger.debug(f"Queried {len(result)} order book snapshots for {symbol}")
            return result

        except Exception as e:
            logger.error(f"Error querying order book snapshots: {e}")
            return []

    async def query_orderbook_aggregated(
        self,
        symbol: str,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        limit: int = 300
    ) -> pd.DataFrame:
        """
        Query 1s aggregated order book data.

        Args:
            symbol: Trading symbol
            start_time: Start timestamp
            end_time: End timestamp
            limit: Maximum number of entries

        Returns:
            DataFrame with aggregated order book data
        """
        try:
            query = """
                SELECT
                    timestamp, symbol, price_bucket,
                    bid_volume, ask_volume,
                    bid_count, ask_count,
                    imbalance
                FROM order_book_1s_agg
                WHERE symbol = $1
            """
            params = [symbol]
            param_idx = 2

            if start_time:
                query += f" AND timestamp >= ${param_idx}"
                params.append(start_time)
                param_idx += 1

            if end_time:
                query += f" AND timestamp <= ${param_idx}"
                params.append(end_time)
                param_idx += 1

            query += f" ORDER BY timestamp DESC, price_bucket LIMIT ${param_idx}"
            params.append(limit)

            rows = await self.db.fetch(query, *params)

            if not rows:
                return pd.DataFrame()

            df = pd.DataFrame([dict(row) for row in rows])
            df = df.sort_values(['timestamp', 'price_bucket']).reset_index(drop=True)

            logger.debug(f"Queried {len(df)} aggregated order book rows for {symbol}")
            return df

        except Exception as e:
            logger.error(f"Error querying aggregated order book: {e}")
            return pd.DataFrame()
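
    # Usage sketch (comment only): turn the long-format 1s aggregate into a
    # timestamp x price_bucket imbalance matrix. The column names match the
    # SELECT above; the pivot is just one possible consumption pattern.
    #
    #     agg = await query_mgr.query_orderbook_aggregated('BTC/USDT', limit=300)
    #     if not agg.empty:
    #         heatmap = agg.pivot_table(index='timestamp', columns='price_bucket',
    #                                   values='imbalance', aggfunc='last')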
    async def query_orderbook_imbalances(
        self,
        symbol: str,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        limit: int = 300
    ) -> pd.DataFrame:
        """
        Query multi-timeframe order book imbalances.

        Args:
            symbol: Trading symbol
            start_time: Start timestamp
            end_time: End timestamp
            limit: Maximum number of entries

        Returns:
            DataFrame with imbalance metrics
        """
        try:
            query = """
                SELECT
                    timestamp, symbol,
                    imbalance_1s, imbalance_5s, imbalance_15s, imbalance_60s,
                    volume_imbalance_1s, volume_imbalance_5s,
                    volume_imbalance_15s, volume_imbalance_60s,
                    price_range
                FROM order_book_imbalances
                WHERE symbol = $1
            """
            params = [symbol]
            param_idx = 2

            if start_time:
                query += f" AND timestamp >= ${param_idx}"
                params.append(start_time)
                param_idx += 1

            if end_time:
                query += f" AND timestamp <= ${param_idx}"
                params.append(end_time)
                param_idx += 1

            query += f" ORDER BY timestamp DESC LIMIT ${param_idx}"
            params.append(limit)

            rows = await self.db.fetch(query, *params)

            if not rows:
                return pd.DataFrame()

            df = pd.DataFrame([dict(row) for row in rows])
            df = df.sort_values('timestamp').reset_index(drop=True)

            logger.debug(f"Queried {len(df)} imbalance rows for {symbol}")
            return df

        except Exception as e:
            logger.error(f"Error querying order book imbalances: {e}")
            return pd.DataFrame()

    async def query_trades(
        self,
        symbol: str,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        limit: int = 1000
    ) -> List[Dict]:
        """
        Query trade events.

        Args:
            symbol: Trading symbol
            start_time: Start timestamp
            end_time: End timestamp
            limit: Maximum number of trades

        Returns:
            List of trade dictionaries
        """
        try:
            query = """
                SELECT
                    timestamp, symbol, exchange,
                    price, size, side,
                    trade_id, is_buyer_maker
                FROM trade_events
                WHERE symbol = $1
            """
            params = [symbol]
            param_idx = 2

            if start_time:
                query += f" AND timestamp >= ${param_idx}"
                params.append(start_time)
                param_idx += 1

            if end_time:
                query += f" AND timestamp <= ${param_idx}"
                params.append(end_time)
                param_idx += 1

            query += f" ORDER BY timestamp DESC LIMIT ${param_idx}"
            params.append(limit)

            rows = await self.db.fetch(query, *params)

            result = [dict(row) for row in rows]
            result.reverse()  # Oldest first

            logger.debug(f"Queried {len(result)} trades for {symbol}")
            return result

        except Exception as e:
            logger.error(f"Error querying trades: {e}")
            return []
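

# Illustrative end-to-end sketch. Not called by anything in this module;
# assumes a reachable TimescaleDB instance matching get_config(), and the
# symbol/timeframe values are examples only.
async def _example_session() -> None:
    """Sketch: open the pool, run a few reads, report stats, clean up."""
    db = DatabaseConnectionManager()
    if not await db.initialize():
        raise RuntimeError("Database unavailable")

    try:
        queries = UnifiedDatabaseQueryManager(db)

        candles = await queries.query_ohlcv_data('BTC/USDT', '1m', limit=500)
        trades = await queries.query_trades('BTC/USDT', limit=1000)
        logger.info(f"Fetched {len(candles)} candles and {len(trades)} trades")
        logger.info(f"Pool stats: {db.get_stats()}")
    finally:
        await db.close()


if __name__ == '__main__':
    asyncio.run(_example_session())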