Files
gogo2/core/unified_database_manager.py
Dobromir Popov e993bc2831 cache, pivots wip
2025-10-20 15:21:44 +03:00

657 lines
20 KiB
Python

"""
Unified Database Manager for TimescaleDB.
Handles connection pooling, queries, and data persistence.
"""
import asyncio
import asyncpg
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Tuple, Any
from pathlib import Path
import pandas as pd
from .config import get_config
from .unified_data_models import OHLCVCandle, OrderBookDataFrame, TradeEvent
logger = logging.getLogger(__name__)
class DatabaseConnectionManager:
    """
    Manages database connection pool with health monitoring.

    Wraps an asyncpg pool with:
      - initialize() / close() lifecycle management
      - a background task that pings the database every 30 seconds and
        updates the health flag
      - thin async query helpers (execute/fetch/fetchrow/fetchval/executemany)
        that record per-query statistics

    Statistics semantics: ``total_queries`` counts every attempt (success or
    failure) so that ``get_stats()['success_rate']`` always lies in [0, 100].
    """

    def __init__(self, config: Optional[Dict] = None):
        """
        Initialize database connection manager.

        Args:
            config: Database configuration dictionary; when None, the global
                configuration from get_config() is used.
        """
        self.config = config or get_config()
        self.pool: Optional[asyncpg.Pool] = None
        self._connection_string = None
        self._health_check_task = None
        self._is_healthy = False

        # Connection statistics: attempts, failures, cumulative latency.
        # total_queries counts *attempts* (including failures) so the
        # success-rate computation in get_stats() cannot go negative.
        self.total_queries = 0
        self.failed_queries = 0
        self.total_query_time = 0.0

        logger.info("DatabaseConnectionManager initialized")

    async def initialize(self) -> bool:
        """
        Create the connection pool and start health monitoring.

        Returns:
            True if the pool was created and a test query succeeded,
            False otherwise (the error is logged, not raised).
        """
        try:
            # Pull connection settings from the 'database' config section,
            # falling back to local-development defaults.
            db_config = self.config.get('database', {})
            host = db_config.get('host', 'localhost')
            port = db_config.get('port', 5432)
            database = db_config.get('name', 'trading_data')
            user = db_config.get('user', 'postgres')
            password = db_config.get('password', 'postgres')
            pool_size = db_config.get('pool_size', 20)

            logger.info(f"Connecting to database: {host}:{port}/{database}")

            # Create connection pool
            self.pool = await asyncpg.create_pool(
                host=host,
                port=port,
                database=database,
                user=user,
                password=password,
                min_size=5,
                max_size=pool_size,
                command_timeout=60,
                server_settings={
                    'jit': 'off',                # JIT rarely helps short OLTP-style queries
                    'timezone': 'UTC',
                    'statement_timeout': '30s'   # guard against runaway queries
                }
            )

            # Smoke-test the pool before declaring it healthy
            async with self.pool.acquire() as conn:
                await conn.execute('SELECT 1')

            self._is_healthy = True
            logger.info("Database connection pool initialized successfully")

            # Start background health monitoring
            self._health_check_task = asyncio.create_task(self._health_monitor())
            return True
        except Exception as e:
            logger.error(f"Failed to initialize database connection: {e}")
            self._is_healthy = False
            return False

    async def close(self):
        """Close the connection pool and stop the health-check task."""
        try:
            if self._health_check_task:
                self._health_check_task.cancel()
                try:
                    await self._health_check_task
                except asyncio.CancelledError:
                    pass
                self._health_check_task = None
            if self.pool:
                await self.pool.close()
                # Drop the reference so later query calls fail fast with a
                # clear "not initialized" error instead of using a closed pool.
                self.pool = None
            self._is_healthy = False
            logger.info("Database connection pool closed")
        except Exception as e:
            logger.error(f"Error closing database connection: {e}")

    async def _health_monitor(self):
        """Background task: ping the database every 30s and update the health flag."""
        while True:
            try:
                await asyncio.sleep(30)  # Check every 30 seconds
                if self.pool:
                    async with self.pool.acquire() as conn:
                        await conn.execute('SELECT 1')
                    self._is_healthy = True
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Health check failed: {e}")
                self._is_healthy = False

    def is_healthy(self) -> bool:
        """Check if connection pool is healthy."""
        return self._is_healthy and self.pool is not None

    def _require_pool(self) -> asyncpg.Pool:
        """Return the pool, raising a clear error when initialize() was not called."""
        if self.pool is None:
            raise RuntimeError("Database connection pool is not initialized")
        return self.pool

    async def _run(self, method: str, error_label: str, query: str, *args) -> Any:
        """
        Run a pooled-connection method, recording attempt statistics.

        Args:
            method: asyncpg connection method name ('execute', 'fetch', ...)
            error_label: Message prefix used when logging a failure
            query: SQL query
            *args: Query parameters (or the args list for executemany)

        Returns:
            Whatever the underlying asyncpg method returns.

        Raises:
            RuntimeError: If the pool has not been initialized.
            Exception: Re-raises any database error after logging it.
        """
        pool = self._require_pool()
        start_time = time.time()
        # Count the attempt up front so failed queries are included in the
        # total; the old code counted only successes, which made
        # success_rate = (total - failed) / total wrong (even negative).
        self.total_queries += 1
        try:
            async with pool.acquire() as conn:
                return await getattr(conn, method)(query, *args)
        except Exception as e:
            self.failed_queries += 1
            logger.error(f"{error_label}: {e}")
            raise
        finally:
            # Timing includes failed attempts as well.
            self.total_query_time += (time.time() - start_time)

    async def execute(self, query: str, *args) -> str:
        """
        Execute a query.

        Args:
            query: SQL query
            *args: Query parameters

        Returns:
            Query result status string
        """
        return await self._run('execute', 'Query execution failed', query, *args)

    async def fetch(self, query: str, *args) -> List[asyncpg.Record]:
        """
        Fetch multiple rows.

        Args:
            query: SQL query
            *args: Query parameters

        Returns:
            List of records
        """
        return await self._run('fetch', 'Query fetch failed', query, *args)

    async def fetchrow(self, query: str, *args) -> Optional[asyncpg.Record]:
        """
        Fetch a single row.

        Args:
            query: SQL query
            *args: Query parameters

        Returns:
            Single record or None
        """
        return await self._run('fetchrow', 'Query fetchrow failed', query, *args)

    async def fetchval(self, query: str, *args) -> Any:
        """
        Fetch a single value.

        Args:
            query: SQL query
            *args: Query parameters

        Returns:
            Single value
        """
        return await self._run('fetchval', 'Query fetchval failed', query, *args)

    async def executemany(self, query: str, args_list: List[Tuple]) -> None:
        """
        Execute query with multiple parameter sets (batch insert).

        Args:
            query: SQL query
            args_list: List of parameter tuples
        """
        await self._run('executemany', 'Query executemany failed', query, args_list)

    def get_stats(self) -> Dict[str, Any]:
        """
        Get connection pool statistics.

        Returns:
            Dict with health flag, attempt/failure counters, success rate
            (percent), average/total query time, and pool sizing info when
            the pool exists.
        """
        avg_query_time = (self.total_query_time / self.total_queries) if self.total_queries > 0 else 0
        stats = {
            'is_healthy': self._is_healthy,
            'total_queries': self.total_queries,
            'failed_queries': self.failed_queries,
            'success_rate': ((self.total_queries - self.failed_queries) / self.total_queries * 100) if self.total_queries > 0 else 0,
            'avg_query_time_ms': round(avg_query_time * 1000, 2),
            'total_query_time_seconds': round(self.total_query_time, 2)
        }
        if self.pool:
            stats.update({
                'pool_size': self.pool.get_size(),
                'pool_max_size': self.pool.get_max_size(),
                'pool_min_size': self.pool.get_min_size()
            })
        return stats
class UnifiedDatabaseQueryManager:
    """
    Manages database queries for unified storage system.

    Provides high-level query methods for OHLCV, order book, and trade data.
    All methods delegate to the shared DatabaseConnectionManager and return
    empty containers on error rather than raising.
    """

    def __init__(self, connection_manager: DatabaseConnectionManager):
        """
        Initialize query manager.

        Args:
            connection_manager: Database connection manager used to run all queries
        """
        # Shared connection manager; every query helper fetches through it.
        self.db = connection_manager
        logger.info("UnifiedDatabaseQueryManager initialized")
async def query_ohlcv_data(
self,
symbol: str,
timeframe: str,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 1000
) -> pd.DataFrame:
"""
Query OHLCV data for a single timeframe.
Args:
symbol: Trading symbol
timeframe: Timeframe (1s, 1m, etc.)
start_time: Start timestamp (None = no limit)
end_time: End timestamp (None = now)
limit: Maximum number of candles
Returns:
DataFrame with OHLCV data
"""
try:
# Build query
query = """
SELECT timestamp, symbol, timeframe,
open_price, high_price, low_price, close_price, volume, trade_count,
rsi_14, macd, macd_signal, macd_histogram,
bb_upper, bb_middle, bb_lower,
ema_12, ema_26, sma_20
FROM ohlcv_data
WHERE symbol = $1 AND timeframe = $2
"""
params = [symbol, timeframe]
param_idx = 3
if start_time:
query += f" AND timestamp >= ${param_idx}"
params.append(start_time)
param_idx += 1
if end_time:
query += f" AND timestamp <= ${param_idx}"
params.append(end_time)
param_idx += 1
query += f" ORDER BY timestamp DESC LIMIT ${param_idx}"
params.append(limit)
# Execute query
rows = await self.db.fetch(query, *params)
if not rows:
return pd.DataFrame()
# Convert to DataFrame
df = pd.DataFrame([dict(row) for row in rows])
df = df.sort_values('timestamp').reset_index(drop=True)
logger.debug(f"Queried {len(df)} OHLCV rows for {symbol} {timeframe}")
return df
except Exception as e:
logger.error(f"Error querying OHLCV data: {e}")
return pd.DataFrame()
async def query_multi_timeframe_ohlcv(
self,
symbol: str,
timeframes: List[str],
timestamp: Optional[datetime] = None,
limit: int = 100
) -> Dict[str, pd.DataFrame]:
"""
Query aligned multi-timeframe OHLCV data.
Args:
symbol: Trading symbol
timeframes: List of timeframes
timestamp: Target timestamp (None = latest)
limit: Number of candles per timeframe
Returns:
Dictionary mapping timeframe to DataFrame
"""
try:
result = {}
# Query each timeframe
for timeframe in timeframes:
if timestamp:
# Query around specific timestamp
start_time = timestamp - timedelta(hours=24) # Get 24 hours of context
end_time = timestamp
else:
# Query latest data
start_time = None
end_time = None
df = await self.query_ohlcv_data(
symbol=symbol,
timeframe=timeframe,
start_time=start_time,
end_time=end_time,
limit=limit
)
result[timeframe] = df
logger.debug(f"Queried multi-timeframe data for {symbol}: {list(result.keys())}")
return result
except Exception as e:
logger.error(f"Error querying multi-timeframe OHLCV: {e}")
return {}
async def query_orderbook_snapshots(
self,
symbol: str,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 100
) -> List[Dict]:
"""
Query order book snapshots.
Args:
symbol: Trading symbol
start_time: Start timestamp
end_time: End timestamp
limit: Maximum number of snapshots
Returns:
List of order book snapshot dictionaries
"""
try:
query = """
SELECT timestamp, symbol, exchange, bids, asks,
mid_price, spread, bid_volume, ask_volume, sequence_id
FROM order_book_snapshots
WHERE symbol = $1
"""
params = [symbol]
param_idx = 2
if start_time:
query += f" AND timestamp >= ${param_idx}"
params.append(start_time)
param_idx += 1
if end_time:
query += f" AND timestamp <= ${param_idx}"
params.append(end_time)
param_idx += 1
query += f" ORDER BY timestamp DESC LIMIT ${param_idx}"
params.append(limit)
rows = await self.db.fetch(query, *params)
result = [dict(row) for row in rows]
result.reverse() # Oldest first
logger.debug(f"Queried {len(result)} order book snapshots for {symbol}")
return result
except Exception as e:
logger.error(f"Error querying order book snapshots: {e}")
return []
async def query_orderbook_aggregated(
self,
symbol: str,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 300
) -> pd.DataFrame:
"""
Query 1s aggregated order book data.
Args:
symbol: Trading symbol
start_time: Start timestamp
end_time: End timestamp
limit: Maximum number of entries
Returns:
DataFrame with aggregated order book data
"""
try:
query = """
SELECT timestamp, symbol, price_bucket,
bid_volume, ask_volume, bid_count, ask_count, imbalance
FROM order_book_1s_agg
WHERE symbol = $1
"""
params = [symbol]
param_idx = 2
if start_time:
query += f" AND timestamp >= ${param_idx}"
params.append(start_time)
param_idx += 1
if end_time:
query += f" AND timestamp <= ${param_idx}"
params.append(end_time)
param_idx += 1
query += f" ORDER BY timestamp DESC, price_bucket LIMIT ${param_idx}"
params.append(limit)
rows = await self.db.fetch(query, *params)
if not rows:
return pd.DataFrame()
df = pd.DataFrame([dict(row) for row in rows])
df = df.sort_values(['timestamp', 'price_bucket']).reset_index(drop=True)
logger.debug(f"Queried {len(df)} aggregated order book rows for {symbol}")
return df
except Exception as e:
logger.error(f"Error querying aggregated order book: {e}")
return pd.DataFrame()
async def query_orderbook_imbalances(
self,
symbol: str,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 300
) -> pd.DataFrame:
"""
Query multi-timeframe order book imbalances.
Args:
symbol: Trading symbol
start_time: Start timestamp
end_time: End timestamp
limit: Maximum number of entries
Returns:
DataFrame with imbalance metrics
"""
try:
query = """
SELECT timestamp, symbol,
imbalance_1s, imbalance_5s, imbalance_15s, imbalance_60s,
volume_imbalance_1s, volume_imbalance_5s,
volume_imbalance_15s, volume_imbalance_60s,
price_range
FROM order_book_imbalances
WHERE symbol = $1
"""
params = [symbol]
param_idx = 2
if start_time:
query += f" AND timestamp >= ${param_idx}"
params.append(start_time)
param_idx += 1
if end_time:
query += f" AND timestamp <= ${param_idx}"
params.append(end_time)
param_idx += 1
query += f" ORDER BY timestamp DESC LIMIT ${param_idx}"
params.append(limit)
rows = await self.db.fetch(query, *params)
if not rows:
return pd.DataFrame()
df = pd.DataFrame([dict(row) for row in rows])
df = df.sort_values('timestamp').reset_index(drop=True)
logger.debug(f"Queried {len(df)} imbalance rows for {symbol}")
return df
except Exception as e:
logger.error(f"Error querying order book imbalances: {e}")
return pd.DataFrame()
async def query_trades(
self,
symbol: str,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 1000
) -> List[Dict]:
"""
Query trade events.
Args:
symbol: Trading symbol
start_time: Start timestamp
end_time: End timestamp
limit: Maximum number of trades
Returns:
List of trade dictionaries
"""
try:
query = """
SELECT timestamp, symbol, exchange, price, size, side, trade_id, is_buyer_maker
FROM trade_events
WHERE symbol = $1
"""
params = [symbol]
param_idx = 2
if start_time:
query += f" AND timestamp >= ${param_idx}"
params.append(start_time)
param_idx += 1
if end_time:
query += f" AND timestamp <= ${param_idx}"
params.append(end_time)
param_idx += 1
query += f" ORDER BY timestamp DESC LIMIT ${param_idx}"
params.append(limit)
rows = await self.db.fetch(query, *params)
result = [dict(row) for row in rows]
result.reverse() # Oldest first
logger.debug(f"Queried {len(result)} trades for {symbol}")
return result
except Exception as e:
logger.error(f"Error querying trades: {e}")
return []