cache, pivots wip

Dobromir Popov
2025-10-20 15:21:44 +03:00
parent ba8813f04f
commit e993bc2831
9 changed files with 1630 additions and 99 deletions


@@ -1894,20 +1894,24 @@ class DataProvider:
# Extract pivot points from all Williams levels
for level_key, level_data in pivot_levels.items():
if level_data and hasattr(level_data, 'swing_points') and level_data.swing_points:
# Get prices from swing points
level_prices = [sp.price for sp in level_data.swing_points]
# Update overall price bounds
price_max = max(price_max, max(level_prices))
price_min = min(price_min, min(level_prices))
# Extract support and resistance levels
if hasattr(level_data, 'support_levels') and level_data.support_levels:
support_levels.extend(level_data.support_levels)
if hasattr(level_data, 'resistance_levels') and level_data.resistance_levels:
resistance_levels.extend(level_data.resistance_levels)
if level_data and hasattr(level_data, 'pivot_points') and level_data.pivot_points:
# Separate pivot points into support and resistance based on type
for pivot in level_data.pivot_points:
price = getattr(pivot, 'price', None)
pivot_type = getattr(pivot, 'pivot_type', 'low')
if price is None:
continue
# Update price bounds
price_max = max(price_max, price)
price_min = min(price_min, price)
# Add to appropriate level list based on pivot type
if pivot_type.lower() == 'high':
resistance_levels.append(price)
else: # 'low'
support_levels.append(price)
# Remove duplicates and sort
support_levels = sorted(list(set(support_levels)))
@@ -1949,6 +1953,8 @@ class DataProvider:
volume_min=float(volume_min),
pivot_support_levels=support_levels,
pivot_resistance_levels=resistance_levels,
pivot_context=pivot_context,
created_timestamp=datetime.now(),
data_period_start=monthly_data['timestamp'].min(),
data_period_end=monthly_data['timestamp'].max(),
total_candles_analyzed=len(monthly_data)
@@ -2242,6 +2248,44 @@ class DataProvider:
"""Get pivot bounds for a symbol"""
return self.pivot_bounds.get(symbol)
def get_williams_pivot_levels(self, symbol: str) -> Dict[int, Any]:
"""Get Williams Market Structure pivot levels with full trend analysis
Returns:
Dictionary mapping level (1-5) to TrendLevel objects containing:
- pivot_points: List of PivotPoint objects with timestamps and prices
- trend_direction: 'up', 'down', or 'sideways'
- trend_strength: 0.0 to 1.0
"""
try:
if symbol not in self.williams_structure:
logger.warning(f"Williams structure not initialized for {symbol}")
return {}
# Calculate fresh pivot points from current cached data
df_1m = self.get_historical_data(symbol, '1m', limit=2000)
if df_1m is None or len(df_1m) < 100:
logger.warning(f"Insufficient 1m data for Williams pivot calculation: {symbol}")
return {}
# Build OHLCV array with timestamp (epoch seconds) as the first column
ohlcv_array = df_1m[['open', 'high', 'low', 'close', 'volume']].copy()
timestamps = df_1m.index.astype(np.int64) // 10**9  # ns -> s
ohlcv_array.insert(0, 'timestamp', timestamps)
ohlcv_array = ohlcv_array.to_numpy()
# Calculate recursive pivot points
williams = self.williams_structure[symbol]
pivot_levels = williams.calculate_recursive_pivot_points(ohlcv_array)
logger.debug(f"Retrieved Williams pivot levels for {symbol}: {len(pivot_levels)} levels")
return pivot_levels
except Exception as e:
logger.error(f"Error getting Williams pivot levels for {symbol}: {e}")
return {}
def get_pivot_normalized_features(self, symbol: str, df: pd.DataFrame) -> Optional[pd.DataFrame]:
"""Get dataframe with pivot-normalized features"""
try:

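For reference, a minimal usage sketch of the new get_williams_pivot_levels() API, assuming an already-constructed DataProvider instance named provider and the ETH/USDT symbol (both illustrative); attribute names follow the docstring above:

pivot_levels = provider.get_williams_pivot_levels('ETH/USDT')
for level, trend in pivot_levels.items():
    # Each TrendLevel carries pivot_points, trend_direction and trend_strength
    print(f"Level {level}: {trend.trend_direction} (strength {trend.trend_strength:.2f})")
    for pivot in trend.pivot_points[-3:]:
        print(f"  {pivot.pivot_type} pivot at {pivot.price}")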

@@ -0,0 +1,547 @@
"""
Data Cache Manager for unified storage system.
Provides low-latency in-memory caching for real-time data access.
"""
import time
import logging
from collections import deque
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Deque, Any
from threading import Lock
import pandas as pd
from .unified_data_models import OHLCVCandle, OrderBookDataFrame, TradeEvent
logger = logging.getLogger(__name__)
class DataCacheManager:
"""
Manages in-memory cache for real-time data.
Provides <10ms latency for latest data access.
Cache Structure:
- OHLCV: Last 5 minutes per symbol per timeframe
- Order book: Last 5 minutes per symbol
- Imbalances: Last 5 minutes per symbol
- Trades: Last 5 minutes per symbol
"""
def __init__(self, cache_duration_seconds: int = 300):
"""
Initialize cache manager.
Args:
cache_duration_seconds: Duration to keep data in cache (default 5 minutes)
"""
self.cache_duration = cache_duration_seconds
self.cache_duration_td = timedelta(seconds=cache_duration_seconds)
# In-memory storage with thread-safe locks
self.lock = Lock()
# OHLCV cache: {symbol: {timeframe: deque of candles}}
self.ohlcv_cache: Dict[str, Dict[str, Deque[Dict]]] = {}
# Order book cache: {symbol: deque of snapshots}
self.orderbook_cache: Dict[str, Deque[Dict]] = {}
# Imbalance cache: {symbol: deque of imbalance metrics}
self.imbalance_cache: Dict[str, Deque[Dict]] = {}
# Trade cache: {symbol: deque of trades}
self.trade_cache: Dict[str, Deque[Dict]] = {}
# Cache statistics
self.cache_hits = 0
self.cache_misses = 0
self.total_inserts = 0
self.total_evictions = 0
# Last eviction time
self.last_eviction = datetime.now()
self.eviction_interval = timedelta(seconds=10) # Evict every 10 seconds
logger.info(f"DataCacheManager initialized with {cache_duration_seconds}s cache duration")
def add_ohlcv_candle(self, symbol: str, timeframe: str, candle: Dict):
"""
Add OHLCV candle to cache.
Args:
symbol: Trading symbol
timeframe: Timeframe (1s, 1m, etc.)
candle: Candle dictionary with OHLCV data
"""
with self.lock:
try:
# Initialize symbol cache if needed
if symbol not in self.ohlcv_cache:
self.ohlcv_cache[symbol] = {}
# Initialize timeframe cache if needed
if timeframe not in self.ohlcv_cache[symbol]:
# Calculate max items for this timeframe
max_items = self._calculate_max_items(timeframe)
self.ohlcv_cache[symbol][timeframe] = deque(maxlen=max_items)
# Add candle with timestamp
candle_with_ts = candle.copy()
if 'timestamp' not in candle_with_ts:
candle_with_ts['timestamp'] = datetime.now()
self.ohlcv_cache[symbol][timeframe].append(candle_with_ts)
self.total_inserts += 1
logger.debug(f"Added OHLCV candle to cache: {symbol} {timeframe}")
except Exception as e:
logger.error(f"Error adding OHLCV candle to cache: {e}")
def add_orderbook_snapshot(self, symbol: str, snapshot: Dict):
"""
Add order book snapshot to cache.
Args:
symbol: Trading symbol
snapshot: Order book snapshot dictionary
"""
with self.lock:
try:
# Initialize symbol cache if needed
if symbol not in self.orderbook_cache:
# 5 minutes at ~1 snapshot per second = 300 snapshots
self.orderbook_cache[symbol] = deque(maxlen=300)
# Add snapshot with timestamp
snapshot_with_ts = snapshot.copy()
if 'timestamp' not in snapshot_with_ts:
snapshot_with_ts['timestamp'] = datetime.now()
self.orderbook_cache[symbol].append(snapshot_with_ts)
self.total_inserts += 1
logger.debug(f"Added order book snapshot to cache: {symbol}")
except Exception as e:
logger.error(f"Error adding order book snapshot to cache: {e}")
def add_imbalance_data(self, symbol: str, imbalance: Dict):
"""
Add imbalance metrics to cache.
Args:
symbol: Trading symbol
imbalance: Imbalance metrics dictionary
"""
with self.lock:
try:
# Initialize symbol cache if needed
if symbol not in self.imbalance_cache:
# 5 minutes at 1 per second = 300 entries
self.imbalance_cache[symbol] = deque(maxlen=300)
# Add imbalance with timestamp
imbalance_with_ts = imbalance.copy()
if 'timestamp' not in imbalance_with_ts:
imbalance_with_ts['timestamp'] = datetime.now()
self.imbalance_cache[symbol].append(imbalance_with_ts)
self.total_inserts += 1
logger.debug(f"Added imbalance data to cache: {symbol}")
except Exception as e:
logger.error(f"Error adding imbalance data to cache: {e}")
def add_trade(self, symbol: str, trade: Dict):
"""
Add trade event to cache.
Args:
symbol: Trading symbol
trade: Trade event dictionary
"""
with self.lock:
try:
# Initialize symbol cache if needed
if symbol not in self.trade_cache:
# 5 minutes at ~10 trades per second = 3000 trades
self.trade_cache[symbol] = deque(maxlen=3000)
# Add trade with timestamp
trade_with_ts = trade.copy()
if 'timestamp' not in trade_with_ts:
trade_with_ts['timestamp'] = datetime.now()
self.trade_cache[symbol].append(trade_with_ts)
self.total_inserts += 1
logger.debug(f"Added trade to cache: {symbol}")
except Exception as e:
logger.error(f"Error adding trade to cache: {e}")
def get_latest_ohlcv(self, symbol: str, timeframe: str, limit: int = 100) -> List[Dict]:
"""
Get latest OHLCV candles from cache.
Args:
symbol: Trading symbol
timeframe: Timeframe
limit: Maximum number of candles to return
Returns:
List of candle dictionaries (most recent last)
"""
start_time = time.time()
with self.lock:
try:
# Check if symbol and timeframe exist in cache
if symbol not in self.ohlcv_cache or timeframe not in self.ohlcv_cache[symbol]:
self.cache_misses += 1
return []
# Get candles
candles = list(self.ohlcv_cache[symbol][timeframe])
# Return last N candles
result = candles[-limit:] if len(candles) > limit else candles
self.cache_hits += 1
latency_ms = (time.time() - start_time) * 1000
logger.debug(f"Retrieved {len(result)} OHLCV candles from cache in {latency_ms:.2f}ms")
return result
except Exception as e:
logger.error(f"Error getting OHLCV from cache: {e}")
self.cache_misses += 1
return []
def get_latest_orderbook(self, symbol: str) -> Optional[Dict]:
"""
Get latest order book snapshot from cache.
Args:
symbol: Trading symbol
Returns:
Latest order book snapshot or None
"""
start_time = time.time()
with self.lock:
try:
# Check if symbol exists in cache
if symbol not in self.orderbook_cache or not self.orderbook_cache[symbol]:
self.cache_misses += 1
return None
# Get latest snapshot
result = self.orderbook_cache[symbol][-1].copy()
self.cache_hits += 1
latency_ms = (time.time() - start_time) * 1000
logger.debug(f"Retrieved order book from cache in {latency_ms:.2f}ms")
return result
except Exception as e:
logger.error(f"Error getting order book from cache: {e}")
self.cache_misses += 1
return None
def get_latest_imbalances(self, symbol: str, limit: int = 60) -> List[Dict]:
"""
Get latest imbalance metrics from cache.
Args:
symbol: Trading symbol
limit: Maximum number of entries to return
Returns:
List of imbalance dictionaries (most recent last)
"""
start_time = time.time()
with self.lock:
try:
# Check if symbol exists in cache
if symbol not in self.imbalance_cache:
self.cache_misses += 1
return []
# Get imbalances
imbalances = list(self.imbalance_cache[symbol])
# Return last N entries
result = imbalances[-limit:] if len(imbalances) > limit else imbalances
self.cache_hits += 1
latency_ms = (time.time() - start_time) * 1000
logger.debug(f"Retrieved {len(result)} imbalances from cache in {latency_ms:.2f}ms")
return result
except Exception as e:
logger.error(f"Error getting imbalances from cache: {e}")
self.cache_misses += 1
return []
def get_latest_trades(self, symbol: str, limit: int = 100) -> List[Dict]:
"""
Get latest trades from cache.
Args:
symbol: Trading symbol
limit: Maximum number of trades to return
Returns:
List of trade dictionaries (most recent last)
"""
start_time = time.time()
with self.lock:
try:
# Check if symbol exists in cache
if symbol not in self.trade_cache:
self.cache_misses += 1
return []
# Get trades
trades = list(self.trade_cache[symbol])
# Return last N trades
result = trades[-limit:] if len(trades) > limit else trades
self.cache_hits += 1
latency_ms = (time.time() - start_time) * 1000
logger.debug(f"Retrieved {len(result)} trades from cache in {latency_ms:.2f}ms")
return result
except Exception as e:
logger.error(f"Error getting trades from cache: {e}")
self.cache_misses += 1
return []
def get_ohlcv_dataframe(self, symbol: str, timeframe: str, limit: int = 100) -> pd.DataFrame:
"""
Get OHLCV data as pandas DataFrame.
Args:
symbol: Trading symbol
timeframe: Timeframe
limit: Maximum number of candles
Returns:
DataFrame with OHLCV data
"""
candles = self.get_latest_ohlcv(symbol, timeframe, limit)
if not candles:
return pd.DataFrame()
return pd.DataFrame(candles)
def evict_old_data(self):
"""Remove data older than cache duration."""
with self.lock:
try:
now = datetime.now()
cutoff_time = now - self.cache_duration_td
eviction_count = 0
# Evict old OHLCV data
for symbol in list(self.ohlcv_cache.keys()):
for timeframe in list(self.ohlcv_cache[symbol].keys()):
cache = self.ohlcv_cache[symbol][timeframe]
# Remove old entries
while cache and cache[0]['timestamp'] < cutoff_time:
cache.popleft()
eviction_count += 1
# Remove empty timeframe caches
if not cache:
del self.ohlcv_cache[symbol][timeframe]
# Remove empty symbol caches
if not self.ohlcv_cache[symbol]:
del self.ohlcv_cache[symbol]
# Evict old order book data
for symbol in list(self.orderbook_cache.keys()):
cache = self.orderbook_cache[symbol]
while cache and cache[0]['timestamp'] < cutoff_time:
cache.popleft()
eviction_count += 1
if not cache:
del self.orderbook_cache[symbol]
# Evict old imbalance data
for symbol in list(self.imbalance_cache.keys()):
cache = self.imbalance_cache[symbol]
while cache and cache[0]['timestamp'] < cutoff_time:
cache.popleft()
eviction_count += 1
if not cache:
del self.imbalance_cache[symbol]
# Evict old trade data
for symbol in list(self.trade_cache.keys()):
cache = self.trade_cache[symbol]
while cache and cache[0]['timestamp'] < cutoff_time:
cache.popleft()
eviction_count += 1
if not cache:
del self.trade_cache[symbol]
self.total_evictions += eviction_count
self.last_eviction = now
if eviction_count > 0:
logger.debug(f"Evicted {eviction_count} old entries from cache")
except Exception as e:
logger.error(f"Error evicting old data: {e}")
def auto_evict_if_needed(self):
"""Automatically evict old data if interval has passed."""
now = datetime.now()
if now - self.last_eviction >= self.eviction_interval:
self.evict_old_data()
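# Usage note (assumption, not shown in this commit): callers are expected to invoke
# auto_evict_if_needed() from the ingestion path, e.g. after each add_* call;
# the last_eviction / eviction_interval check above limits the actual eviction
# pass to at most once every 10 seconds.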
def clear_cache(self, symbol: Optional[str] = None):
"""
Clear cache data.
Args:
symbol: Symbol to clear (None = clear all)
"""
with self.lock:
if symbol:
# Clear specific symbol
if symbol in self.ohlcv_cache:
del self.ohlcv_cache[symbol]
if symbol in self.orderbook_cache:
del self.orderbook_cache[symbol]
if symbol in self.imbalance_cache:
del self.imbalance_cache[symbol]
if symbol in self.trade_cache:
del self.trade_cache[symbol]
logger.info(f"Cleared cache for symbol: {symbol}")
else:
# Clear all
self.ohlcv_cache.clear()
self.orderbook_cache.clear()
self.imbalance_cache.clear()
self.trade_cache.clear()
logger.info("Cleared all cache data")
def get_cache_stats(self) -> Dict[str, Any]:
"""Get cache statistics."""
with self.lock:
# Calculate cache sizes
ohlcv_count = sum(
sum(len(tf_cache) for tf_cache in symbol_cache.values())
for symbol_cache in self.ohlcv_cache.values()
)
orderbook_count = sum(len(cache) for cache in self.orderbook_cache.values())
imbalance_count = sum(len(cache) for cache in self.imbalance_cache.values())
trade_count = sum(len(cache) for cache in self.trade_cache.values())
total_requests = self.cache_hits + self.cache_misses
hit_rate = (self.cache_hits / total_requests * 100) if total_requests > 0 else 0
return {
'cache_duration_seconds': self.cache_duration,
'ohlcv_entries': ohlcv_count,
'orderbook_entries': orderbook_count,
'imbalance_entries': imbalance_count,
'trade_entries': trade_count,
'total_entries': ohlcv_count + orderbook_count + imbalance_count + trade_count,
'cache_hits': self.cache_hits,
'cache_misses': self.cache_misses,
'hit_rate_percent': round(hit_rate, 2),
'total_inserts': self.total_inserts,
'total_evictions': self.total_evictions,
'last_eviction': self.last_eviction.isoformat(),
'symbols_cached': {
'ohlcv': list(self.ohlcv_cache.keys()),
'orderbook': list(self.orderbook_cache.keys()),
'imbalance': list(self.imbalance_cache.keys()),
'trade': list(self.trade_cache.keys())
}
}
def _calculate_max_items(self, timeframe: str) -> int:
"""
Calculate maximum cache items for a timeframe.
Args:
timeframe: Timeframe string
Returns:
Maximum number of items to cache
"""
# Timeframe to seconds mapping
timeframe_seconds = {
'1s': 1,
'1m': 60,
'5m': 300,
'15m': 900,
'30m': 1800,
'1h': 3600,
'4h': 14400,
'1d': 86400
}
seconds = timeframe_seconds.get(timeframe, 60)
# Calculate how many candles fit in cache duration
max_items = self.cache_duration // seconds
# Ensure at least 10 items
return max(10, max_items)
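# Worked example with the default 300s cache: '1s' -> 300 items, '1m' -> 300 // 60 = 5
# (clamped up to the 10-item minimum); '1h' and larger likewise fall back to 10.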
def get_cache_summary(self) -> str:
"""Get human-readable cache summary."""
stats = self.get_cache_stats()
summary = f"""
Cache Summary:
--------------
Duration: {stats['cache_duration_seconds']}s
Total Entries: {stats['total_entries']}
- OHLCV: {stats['ohlcv_entries']}
- Order Book: {stats['orderbook_entries']}
- Imbalances: {stats['imbalance_entries']}
- Trades: {stats['trade_entries']}
Performance:
- Cache Hits: {stats['cache_hits']}
- Cache Misses: {stats['cache_misses']}
- Hit Rate: {stats['hit_rate_percent']}%
- Total Inserts: {stats['total_inserts']}
- Total Evictions: {stats['total_evictions']}
Last Eviction: {stats['last_eviction']}
"""
return summary

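A minimal end-to-end sketch of the cache manager (symbol and candle values are purely illustrative):

from datetime import datetime

cache = DataCacheManager(cache_duration_seconds=300)

# Ingest one 1s candle; a missing 'timestamp' would be filled with datetime.now()
cache.add_ohlcv_candle('ETH/USDT', '1s', {
    'open': 2500.0, 'high': 2501.0, 'low': 2499.5, 'close': 2500.5,
    'volume': 12.3, 'timestamp': datetime.now(),
})

latest = cache.get_latest_ohlcv('ETH/USDT', '1s', limit=10)   # list of dicts
df = cache.get_ohlcv_dataframe('ETH/USDT', '1s', limit=10)    # pandas DataFrame
cache.auto_evict_if_needed()
print(cache.get_cache_summary())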

@@ -0,0 +1,656 @@
"""
Unified Database Manager for TimescaleDB.
Handles connection pooling, queries, and data persistence.
"""
import asyncio
import asyncpg
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Tuple, Any
from pathlib import Path
import pandas as pd
from .config import get_config
from .unified_data_models import OHLCVCandle, OrderBookDataFrame, TradeEvent
logger = logging.getLogger(__name__)
class DatabaseConnectionManager:
"""
Manages database connection pool with health monitoring.
Provides connection pooling and basic query methods.
"""
def __init__(self, config: Optional[Dict] = None):
"""
Initialize database connection manager.
Args:
config: Database configuration dictionary
"""
self.config = config or get_config()
self.pool: Optional[asyncpg.Pool] = None
self._connection_string = None
self._health_check_task = None
self._is_healthy = False
# Connection statistics
self.total_queries = 0
self.failed_queries = 0
self.total_query_time = 0.0
logger.info("DatabaseConnectionManager initialized")
async def initialize(self) -> bool:
"""
Initialize connection pool.
Returns:
True if successful, False otherwise
"""
try:
# Get database configuration
db_config = self.config.get('database', {})
host = db_config.get('host', 'localhost')
port = db_config.get('port', 5432)
database = db_config.get('name', 'trading_data')
user = db_config.get('user', 'postgres')
password = db_config.get('password', 'postgres')
pool_size = db_config.get('pool_size', 20)
logger.info(f"Connecting to database: {host}:{port}/{database}")
# Create connection pool
self.pool = await asyncpg.create_pool(
host=host,
port=port,
database=database,
user=user,
password=password,
min_size=5,
max_size=pool_size,
command_timeout=60,
server_settings={
'jit': 'off',
'timezone': 'UTC',
'statement_timeout': '30s'
}
)
# Test connection
async with self.pool.acquire() as conn:
await conn.execute('SELECT 1')
self._is_healthy = True
logger.info("Database connection pool initialized successfully")
# Start health monitoring
self._health_check_task = asyncio.create_task(self._health_monitor())
return True
except Exception as e:
logger.error(f"Failed to initialize database connection: {e}")
self._is_healthy = False
return False
async def close(self):
"""Close connection pool."""
try:
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
if self.pool:
await self.pool.close()
logger.info("Database connection pool closed")
except Exception as e:
logger.error(f"Error closing database connection: {e}")
async def _health_monitor(self):
"""Background task to monitor connection health."""
while True:
try:
await asyncio.sleep(30) # Check every 30 seconds
if self.pool:
async with self.pool.acquire() as conn:
await conn.execute('SELECT 1')
self._is_healthy = True
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Health check failed: {e}")
self._is_healthy = False
def is_healthy(self) -> bool:
"""Check if connection pool is healthy."""
return self._is_healthy and self.pool is not None
async def execute(self, query: str, *args) -> str:
"""
Execute a query.
Args:
query: SQL query
*args: Query parameters
Returns:
Query result status
"""
start_time = time.time()
try:
async with self.pool.acquire() as conn:
result = await conn.execute(query, *args)
self.total_queries += 1
self.total_query_time += (time.time() - start_time)
return result
except Exception as e:
self.failed_queries += 1
logger.error(f"Query execution failed: {e}")
raise
async def fetch(self, query: str, *args) -> List[asyncpg.Record]:
"""
Fetch multiple rows.
Args:
query: SQL query
*args: Query parameters
Returns:
List of records
"""
start_time = time.time()
try:
async with self.pool.acquire() as conn:
result = await conn.fetch(query, *args)
self.total_queries += 1
self.total_query_time += (time.time() - start_time)
return result
except Exception as e:
self.failed_queries += 1
logger.error(f"Query fetch failed: {e}")
raise
async def fetchrow(self, query: str, *args) -> Optional[asyncpg.Record]:
"""
Fetch a single row.
Args:
query: SQL query
*args: Query parameters
Returns:
Single record or None
"""
start_time = time.time()
try:
async with self.pool.acquire() as conn:
result = await conn.fetchrow(query, *args)
self.total_queries += 1
self.total_query_time += (time.time() - start_time)
return result
except Exception as e:
self.failed_queries += 1
logger.error(f"Query fetchrow failed: {e}")
raise
async def fetchval(self, query: str, *args) -> Any:
"""
Fetch a single value.
Args:
query: SQL query
*args: Query parameters
Returns:
Single value
"""
start_time = time.time()
try:
async with self.pool.acquire() as conn:
result = await conn.fetchval(query, *args)
self.total_queries += 1
self.total_query_time += (time.time() - start_time)
return result
except Exception as e:
self.failed_queries += 1
logger.error(f"Query fetchval failed: {e}")
raise
async def executemany(self, query: str, args_list: List[Tuple]) -> None:
"""
Execute query with multiple parameter sets (batch insert).
Args:
query: SQL query
args_list: List of parameter tuples
"""
start_time = time.time()
try:
async with self.pool.acquire() as conn:
await conn.executemany(query, args_list)
self.total_queries += 1
self.total_query_time += (time.time() - start_time)
except Exception as e:
self.failed_queries += 1
logger.error(f"Query executemany failed: {e}")
raise
def get_stats(self) -> Dict[str, Any]:
"""Get connection pool statistics."""
avg_query_time = (self.total_query_time / self.total_queries) if self.total_queries > 0 else 0
stats = {
'is_healthy': self._is_healthy,
'total_queries': self.total_queries,
'failed_queries': self.failed_queries,
'success_rate': ((self.total_queries - self.failed_queries) / self.total_queries * 100) if self.total_queries > 0 else 0,
'avg_query_time_ms': round(avg_query_time * 1000, 2),
'total_query_time_seconds': round(self.total_query_time, 2)
}
if self.pool:
stats.update({
'pool_size': self.pool.get_size(),
'pool_max_size': self.pool.get_max_size(),
'pool_min_size': self.pool.get_min_size()
})
return stats
class UnifiedDatabaseQueryManager:
"""
Manages database queries for unified storage system.
Provides high-level query methods for OHLCV, order book, and trade data.
"""
def __init__(self, connection_manager: DatabaseConnectionManager):
"""
Initialize query manager.
Args:
connection_manager: Database connection manager
"""
self.db = connection_manager
logger.info("UnifiedDatabaseQueryManager initialized")
async def query_ohlcv_data(
self,
symbol: str,
timeframe: str,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 1000
) -> pd.DataFrame:
"""
Query OHLCV data for a single timeframe.
Args:
symbol: Trading symbol
timeframe: Timeframe (1s, 1m, etc.)
start_time: Start timestamp (None = no limit)
end_time: End timestamp (None = now)
limit: Maximum number of candles
Returns:
DataFrame with OHLCV data
"""
try:
# Build query
query = """
SELECT timestamp, symbol, timeframe,
open_price, high_price, low_price, close_price, volume, trade_count,
rsi_14, macd, macd_signal, macd_histogram,
bb_upper, bb_middle, bb_lower,
ema_12, ema_26, sma_20
FROM ohlcv_data
WHERE symbol = $1 AND timeframe = $2
"""
params = [symbol, timeframe]
param_idx = 3
if start_time:
query += f" AND timestamp >= ${param_idx}"
params.append(start_time)
param_idx += 1
if end_time:
query += f" AND timestamp <= ${param_idx}"
params.append(end_time)
param_idx += 1
query += f" ORDER BY timestamp DESC LIMIT ${param_idx}"
params.append(limit)
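# Note: placeholders are numbered dynamically ($3, $4, ...), so entries must be
# appended to params in exactly the order the optional clauses are added above.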
# Execute query
rows = await self.db.fetch(query, *params)
if not rows:
return pd.DataFrame()
# Convert to DataFrame
df = pd.DataFrame([dict(row) for row in rows])
df = df.sort_values('timestamp').reset_index(drop=True)
logger.debug(f"Queried {len(df)} OHLCV rows for {symbol} {timeframe}")
return df
except Exception as e:
logger.error(f"Error querying OHLCV data: {e}")
return pd.DataFrame()
async def query_multi_timeframe_ohlcv(
self,
symbol: str,
timeframes: List[str],
timestamp: Optional[datetime] = None,
limit: int = 100
) -> Dict[str, pd.DataFrame]:
"""
Query aligned multi-timeframe OHLCV data.
Args:
symbol: Trading symbol
timeframes: List of timeframes
timestamp: Target timestamp (None = latest)
limit: Number of candles per timeframe
Returns:
Dictionary mapping timeframe to DataFrame
"""
try:
result = {}
# Query each timeframe
for timeframe in timeframes:
if timestamp:
# Query around specific timestamp
start_time = timestamp - timedelta(hours=24) # Get 24 hours of context
end_time = timestamp
else:
# Query latest data
start_time = None
end_time = None
df = await self.query_ohlcv_data(
symbol=symbol,
timeframe=timeframe,
start_time=start_time,
end_time=end_time,
limit=limit
)
result[timeframe] = df
logger.debug(f"Queried multi-timeframe data for {symbol}: {list(result.keys())}")
return result
except Exception as e:
logger.error(f"Error querying multi-timeframe OHLCV: {e}")
return {}
async def query_orderbook_snapshots(
self,
symbol: str,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 100
) -> List[Dict]:
"""
Query order book snapshots.
Args:
symbol: Trading symbol
start_time: Start timestamp
end_time: End timestamp
limit: Maximum number of snapshots
Returns:
List of order book snapshot dictionaries
"""
try:
query = """
SELECT timestamp, symbol, exchange, bids, asks,
mid_price, spread, bid_volume, ask_volume, sequence_id
FROM order_book_snapshots
WHERE symbol = $1
"""
params = [symbol]
param_idx = 2
if start_time:
query += f" AND timestamp >= ${param_idx}"
params.append(start_time)
param_idx += 1
if end_time:
query += f" AND timestamp <= ${param_idx}"
params.append(end_time)
param_idx += 1
query += f" ORDER BY timestamp DESC LIMIT ${param_idx}"
params.append(limit)
rows = await self.db.fetch(query, *params)
result = [dict(row) for row in rows]
result.reverse() # Oldest first
logger.debug(f"Queried {len(result)} order book snapshots for {symbol}")
return result
except Exception as e:
logger.error(f"Error querying order book snapshots: {e}")
return []
async def query_orderbook_aggregated(
self,
symbol: str,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 300
) -> pd.DataFrame:
"""
Query 1s aggregated order book data.
Args:
symbol: Trading symbol
start_time: Start timestamp
end_time: End timestamp
limit: Maximum number of entries
Returns:
DataFrame with aggregated order book data
"""
try:
query = """
SELECT timestamp, symbol, price_bucket,
bid_volume, ask_volume, bid_count, ask_count, imbalance
FROM order_book_1s_agg
WHERE symbol = $1
"""
params = [symbol]
param_idx = 2
if start_time:
query += f" AND timestamp >= ${param_idx}"
params.append(start_time)
param_idx += 1
if end_time:
query += f" AND timestamp <= ${param_idx}"
params.append(end_time)
param_idx += 1
query += f" ORDER BY timestamp DESC, price_bucket LIMIT ${param_idx}"
params.append(limit)
rows = await self.db.fetch(query, *params)
if not rows:
return pd.DataFrame()
df = pd.DataFrame([dict(row) for row in rows])
df = df.sort_values(['timestamp', 'price_bucket']).reset_index(drop=True)
logger.debug(f"Queried {len(df)} aggregated order book rows for {symbol}")
return df
except Exception as e:
logger.error(f"Error querying aggregated order book: {e}")
return pd.DataFrame()
async def query_orderbook_imbalances(
self,
symbol: str,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 300
) -> pd.DataFrame:
"""
Query multi-timeframe order book imbalances.
Args:
symbol: Trading symbol
start_time: Start timestamp
end_time: End timestamp
limit: Maximum number of entries
Returns:
DataFrame with imbalance metrics
"""
try:
query = """
SELECT timestamp, symbol,
imbalance_1s, imbalance_5s, imbalance_15s, imbalance_60s,
volume_imbalance_1s, volume_imbalance_5s,
volume_imbalance_15s, volume_imbalance_60s,
price_range
FROM order_book_imbalances
WHERE symbol = $1
"""
params = [symbol]
param_idx = 2
if start_time:
query += f" AND timestamp >= ${param_idx}"
params.append(start_time)
param_idx += 1
if end_time:
query += f" AND timestamp <= ${param_idx}"
params.append(end_time)
param_idx += 1
query += f" ORDER BY timestamp DESC LIMIT ${param_idx}"
params.append(limit)
rows = await self.db.fetch(query, *params)
if not rows:
return pd.DataFrame()
df = pd.DataFrame([dict(row) for row in rows])
df = df.sort_values('timestamp').reset_index(drop=True)
logger.debug(f"Queried {len(df)} imbalance rows for {symbol}")
return df
except Exception as e:
logger.error(f"Error querying order book imbalances: {e}")
return pd.DataFrame()
async def query_trades(
self,
symbol: str,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 1000
) -> List[Dict]:
"""
Query trade events.
Args:
symbol: Trading symbol
start_time: Start timestamp
end_time: End timestamp
limit: Maximum number of trades
Returns:
List of trade dictionaries
"""
try:
query = """
SELECT timestamp, symbol, exchange, price, size, side, trade_id, is_buyer_maker
FROM trade_events
WHERE symbol = $1
"""
params = [symbol]
param_idx = 2
if start_time:
query += f" AND timestamp >= ${param_idx}"
params.append(start_time)
param_idx += 1
if end_time:
query += f" AND timestamp <= ${param_idx}"
params.append(end_time)
param_idx += 1
query += f" ORDER BY timestamp DESC LIMIT ${param_idx}"
params.append(limit)
rows = await self.db.fetch(query, *params)
result = [dict(row) for row in rows]
result.reverse() # Oldest first
logger.debug(f"Queried {len(result)} trades for {symbol}")
return result
except Exception as e:
logger.error(f"Error querying trades: {e}")
return []
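Finally, a sketch tying the connection manager and query manager together (config values, symbol and timeframes are illustrative; the keys mirror those read in initialize()):

import asyncio

async def main():
    config = {'database': {'host': 'localhost', 'port': 5432, 'name': 'trading_data',
                           'user': 'postgres', 'password': 'postgres', 'pool_size': 20}}
    db = DatabaseConnectionManager(config)
    if not await db.initialize():
        return
    queries = UnifiedDatabaseQueryManager(db)
    # Latest 500 1m candles (with indicator columns) as a DataFrame
    df_1m = await queries.query_ohlcv_data('ETH/USDT', '1m', limit=500)
    # Aligned view across several timeframes, plus recent trades and imbalances
    frames = await queries.query_multi_timeframe_ohlcv('ETH/USDT', ['1s', '1m', '1h'], limit=100)
    trades = await queries.query_trades('ETH/USDT', limit=200)
    imbalances = await queries.query_orderbook_imbalances('ETH/USDT', limit=300)
    print(len(df_1m), len(trades), len(imbalances), db.get_stats()['success_rate'])
    await db.close()

asyncio.run(main())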