# gogo2/core/data_provider.py
"""
Multi-Timeframe, Multi-Symbol Data Provider
This module consolidates all data functionality including:
- Historical data fetching from Binance API
- Real-time data streaming via WebSocket
- Multi-timeframe candle generation
- Caching and data management
- Technical indicators calculation
"""
import asyncio
import json
import logging
import os
import time
import websockets
import requests
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
import ta
from threading import Thread, Lock
from collections import deque
from .config import get_config
logger = logging.getLogger(__name__)
class DataProvider:
"""Unified data provider for historical and real-time market data"""
def __init__(self, symbols: List[str] = None, timeframes: List[str] = None):
"""Initialize the data provider"""
self.config = get_config()
self.symbols = symbols or self.config.symbols
self.timeframes = timeframes or self.config.timeframes
# Data storage
self.historical_data = {} # {symbol: {timeframe: DataFrame}}
self.real_time_data = {} # {symbol: {timeframe: deque}}
self.current_prices = {} # {symbol: float}
# Real-time processing
self.websocket_tasks = {}
self.is_streaming = False
self.data_lock = Lock()
# Cache settings
self.cache_enabled = self.config.data.get('cache_enabled', True)
self.cache_dir = Path(self.config.data.get('cache_dir', 'cache'))
self.cache_dir.mkdir(parents=True, exist_ok=True)
        # Timeframe conversion (seconds per bar; '1s' included to match the API support below)
        self.timeframe_seconds = {
            '1s': 1, '1m': 60, '5m': 300, '15m': 900, '30m': 1800,
            '1h': 3600, '4h': 14400, '1d': 86400
        }
logger.info(f"DataProvider initialized for symbols: {self.symbols}")
logger.info(f"Timeframes: {self.timeframes}")
def get_historical_data(self, symbol: str, timeframe: str, limit: int = 1000,
refresh: bool = False) -> Optional[pd.DataFrame]:
"""Get historical OHLCV data for a symbol and timeframe"""
try:
# Check cache first
if not refresh and self.cache_enabled:
cached_data = self._load_from_cache(symbol, timeframe)
if cached_data is not None and len(cached_data) >= limit * 0.8:
logger.info(f"Using cached data for {symbol} {timeframe}")
return cached_data.tail(limit)
# Fetch from API
logger.info(f"Fetching historical data for {symbol} {timeframe}")
df = self._fetch_from_binance(symbol, timeframe, limit)
if df is not None and not df.empty:
# Add technical indicators
df = self._add_technical_indicators(df)
# Cache the data
if self.cache_enabled:
self._save_to_cache(df, symbol, timeframe)
# Store in memory
if symbol not in self.historical_data:
self.historical_data[symbol] = {}
self.historical_data[symbol][timeframe] = df
return df
logger.warning(f"No data received for {symbol} {timeframe}")
return None
except Exception as e:
logger.error(f"Error fetching historical data for {symbol} {timeframe}: {e}")
return None
def _fetch_from_binance(self, symbol: str, timeframe: str, limit: int) -> Optional[pd.DataFrame]:
"""Fetch data from Binance API"""
try:
# Convert symbol format
binance_symbol = symbol.replace('/', '').upper()
# Convert timeframe (now includes 1s support)
timeframe_map = {
'1s': '1s', '1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m',
'1h': '1h', '4h': '4h', '1d': '1d'
}
binance_timeframe = timeframe_map.get(timeframe, '1h')
# API request
url = "https://api.binance.com/api/v3/klines"
params = {
'symbol': binance_symbol,
'interval': binance_timeframe,
'limit': limit
}
            response = requests.get(url, params=params, timeout=30)
response.raise_for_status()
data = response.json()
# Convert to DataFrame
df = pd.DataFrame(data, columns=[
'timestamp', 'open', 'high', 'low', 'close', 'volume',
'close_time', 'quote_volume', 'trades', 'taker_buy_base',
'taker_buy_quote', 'ignore'
])
# Process columns
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
for col in ['open', 'high', 'low', 'close', 'volume']:
df[col] = df[col].astype(float)
# Keep only OHLCV columns
df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
df = df.sort_values('timestamp').reset_index(drop=True)
logger.info(f"Fetched {len(df)} candles for {symbol} {timeframe}")
return df
except Exception as e:
logger.error(f"Error fetching from Binance API: {e}")
return None
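    # For reference, each raw kline returned by /api/v3/klines is a 12-element
    # array, which is why twelve columns are named above before trimming to OHLCV:
    #   [open_time_ms, open, high, low, close, volume,
    #    close_time_ms, quote_volume, trades, taker_buy_base,
    #    taker_buy_quote, ignore]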
def _add_technical_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
"""Add comprehensive technical indicators for multi-timeframe analysis"""
try:
df = df.copy()
# Ensure we have enough data for indicators
if len(df) < 50:
logger.warning(f"Insufficient data for comprehensive indicators: {len(df)} rows")
return self._add_basic_indicators(df)
# === TREND INDICATORS ===
# Moving averages (multiple timeframes)
df['sma_10'] = ta.trend.sma_indicator(df['close'], window=10)
df['sma_20'] = ta.trend.sma_indicator(df['close'], window=20)
df['sma_50'] = ta.trend.sma_indicator(df['close'], window=50)
df['ema_12'] = ta.trend.ema_indicator(df['close'], window=12)
df['ema_26'] = ta.trend.ema_indicator(df['close'], window=26)
df['ema_50'] = ta.trend.ema_indicator(df['close'], window=50)
# MACD family
macd = ta.trend.MACD(df['close'])
df['macd'] = macd.macd()
df['macd_signal'] = macd.macd_signal()
df['macd_histogram'] = macd.macd_diff()
# ADX (Average Directional Index)
adx = ta.trend.ADXIndicator(df['high'], df['low'], df['close'])
df['adx'] = adx.adx()
df['adx_pos'] = adx.adx_pos()
df['adx_neg'] = adx.adx_neg()
# Parabolic SAR
psar = ta.trend.PSARIndicator(df['high'], df['low'], df['close'])
df['psar'] = psar.psar()
# === MOMENTUM INDICATORS ===
# RSI (multiple periods)
df['rsi_14'] = ta.momentum.rsi(df['close'], window=14)
df['rsi_7'] = ta.momentum.rsi(df['close'], window=7)
df['rsi_21'] = ta.momentum.rsi(df['close'], window=21)
# Stochastic Oscillator
stoch = ta.momentum.StochasticOscillator(df['high'], df['low'], df['close'])
df['stoch_k'] = stoch.stoch()
df['stoch_d'] = stoch.stoch_signal()
# Williams %R
df['williams_r'] = ta.momentum.williams_r(df['high'], df['low'], df['close'])
# Ultimate Oscillator (instead of CCI which isn't available)
df['ultimate_osc'] = ta.momentum.ultimate_oscillator(df['high'], df['low'], df['close'])
# === VOLATILITY INDICATORS ===
# Bollinger Bands
bollinger = ta.volatility.BollingerBands(df['close'])
df['bb_upper'] = bollinger.bollinger_hband()
df['bb_lower'] = bollinger.bollinger_lband()
df['bb_middle'] = bollinger.bollinger_mavg()
df['bb_width'] = (df['bb_upper'] - df['bb_lower']) / df['bb_middle']
df['bb_percent'] = (df['close'] - df['bb_lower']) / (df['bb_upper'] - df['bb_lower'])
# Average True Range
df['atr'] = ta.volatility.average_true_range(df['high'], df['low'], df['close'])
# Keltner Channels
keltner = ta.volatility.KeltnerChannel(df['high'], df['low'], df['close'])
df['keltner_upper'] = keltner.keltner_channel_hband()
df['keltner_lower'] = keltner.keltner_channel_lband()
df['keltner_middle'] = keltner.keltner_channel_mband()
# === VOLUME INDICATORS ===
# Volume moving averages
df['volume_sma_10'] = df['volume'].rolling(window=10).mean()
df['volume_sma_20'] = df['volume'].rolling(window=20).mean()
df['volume_sma_50'] = df['volume'].rolling(window=50).mean()
# On Balance Volume
df['obv'] = ta.volume.on_balance_volume(df['close'], df['volume'])
# Volume Price Trend
df['vpt'] = ta.volume.volume_price_trend(df['close'], df['volume'])
# Money Flow Index
df['mfi'] = ta.volume.money_flow_index(df['high'], df['low'], df['close'], df['volume'])
# Accumulation/Distribution Line
df['ad_line'] = ta.volume.acc_dist_index(df['high'], df['low'], df['close'], df['volume'])
            # Volume Weighted Average Price (cumulative over the fetched window, not session-anchored)
            df['vwap'] = (df['close'] * df['volume']).cumsum() / df['volume'].cumsum()
# === PRICE ACTION INDICATORS ===
# Price position relative to range
df['price_position'] = (df['close'] - df['low']) / (df['high'] - df['low'])
# True Range (use ATR calculation for true range)
df['true_range'] = df['atr'] # ATR is based on true range, so use it directly
# Rate of Change
df['roc'] = ta.momentum.roc(df['close'], window=10)
# === CUSTOM INDICATORS ===
# Trend strength (combination of multiple trend indicators)
df['trend_strength'] = (
(df['close'] > df['sma_20']).astype(int) +
(df['sma_10'] > df['sma_20']).astype(int) +
(df['macd'] > df['macd_signal']).astype(int) +
(df['adx'] > 25).astype(int)
) / 4.0
# Momentum composite
            df['momentum_composite'] = (
                (df['rsi_14'] / 100) +
                (df['stoch_k'] / 100) +             # stoch_k already spans 0-100
                ((df['williams_r'] + 100) / 100)    # williams_r spans -100..0
            ) / 3.0
# Volatility regime
df['volatility_regime'] = (df['atr'] / df['close']).rolling(window=20).rank(pct=True)
# === FILL NaN VALUES ===
# Forward fill first, then backward fill, then zero fill
df = df.ffill().bfill().fillna(0)
logger.debug(f"Added {len([col for col in df.columns if col not in ['timestamp', 'open', 'high', 'low', 'close', 'volume']])} technical indicators")
return df
except Exception as e:
logger.error(f"Error adding comprehensive technical indicators: {e}")
# Fallback to basic indicators
return self._add_basic_indicators(df)
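    # Worked example for the trend_strength composite above: with close > sma_20,
    # sma_10 > sma_20, and macd > macd_signal but adx <= 25, three of the four
    # binary conditions hold, so trend_strength = 3 / 4.0 = 0.75.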
def _add_basic_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
"""Add basic indicators for small datasets"""
try:
df = df.copy()
# Basic moving averages
if len(df) >= 20:
df['sma_20'] = ta.trend.sma_indicator(df['close'], window=20)
df['ema_12'] = ta.trend.ema_indicator(df['close'], window=12)
# Basic RSI
if len(df) >= 14:
df['rsi_14'] = ta.momentum.rsi(df['close'], window=14)
# Basic volume indicators
if len(df) >= 10:
df['volume_sma_10'] = df['volume'].rolling(window=10).mean()
# Basic price action
df['price_position'] = (df['close'] - df['low']) / (df['high'] - df['low'])
df['price_position'] = df['price_position'].fillna(0.5) # Default to middle
# Fill NaN values
df = df.ffill().bfill().fillna(0)
return df
except Exception as e:
logger.error(f"Error adding basic indicators: {e}")
return df
def _load_from_cache(self, symbol: str, timeframe: str) -> Optional[pd.DataFrame]:
"""Load data from cache"""
try:
cache_file = self.cache_dir / f"{symbol.replace('/', '')}_{timeframe}.parquet"
if cache_file.exists():
# Check if cache is recent (less than 1 hour old)
cache_age = time.time() - cache_file.stat().st_mtime
if cache_age < 3600: # 1 hour
df = pd.read_parquet(cache_file)
logger.debug(f"Loaded {len(df)} rows from cache for {symbol} {timeframe}")
return df
else:
logger.debug(f"Cache for {symbol} {timeframe} is too old ({cache_age/3600:.1f}h)")
return None
except Exception as e:
logger.warning(f"Error loading cache for {symbol} {timeframe}: {e}")
return None
def _save_to_cache(self, df: pd.DataFrame, symbol: str, timeframe: str):
"""Save data to cache"""
try:
cache_file = self.cache_dir / f"{symbol.replace('/', '')}_{timeframe}.parquet"
df.to_parquet(cache_file, index=False)
logger.debug(f"Saved {len(df)} rows to cache for {symbol} {timeframe}")
except Exception as e:
logger.warning(f"Error saving cache for {symbol} {timeframe}: {e}")
async def start_real_time_streaming(self):
"""Start real-time data streaming for all symbols"""
if self.is_streaming:
logger.warning("Real-time streaming already active")
return
self.is_streaming = True
logger.info("Starting real-time data streaming")
# Start WebSocket for each symbol
for symbol in self.symbols:
task = asyncio.create_task(self._websocket_stream(symbol))
self.websocket_tasks[symbol] = task
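    # Usage sketch (illustrative, not part of the provider): the streaming
    # coroutines must run inside an asyncio event loop, e.g.
    #
    #     provider = DataProvider()
    #     asyncio.run(provider.start_real_time_streaming())
    #
    # Note that asyncio.run() returns once the tasks are created and then tears
    # the loop down; a long-lived application should keep its own loop running.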
async def stop_real_time_streaming(self):
"""Stop real-time data streaming"""
if not self.is_streaming:
return
logger.info("Stopping real-time data streaming")
self.is_streaming = False
# Cancel all WebSocket tasks
for symbol, task in self.websocket_tasks.items():
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
self.websocket_tasks.clear()
async def _websocket_stream(self, symbol: str):
"""WebSocket stream for a single symbol"""
binance_symbol = symbol.replace('/', '').lower()
url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@ticker"
while self.is_streaming:
try:
async with websockets.connect(url) as websocket:
logger.info(f"WebSocket connected for {symbol}")
async for message in websocket:
if not self.is_streaming:
break
try:
data = json.loads(message)
await self._process_tick(symbol, data)
except Exception as e:
logger.warning(f"Error processing tick for {symbol}: {e}")
except Exception as e:
logger.error(f"WebSocket error for {symbol}: {e}")
if self.is_streaming:
logger.info(f"Reconnecting WebSocket for {symbol} in 5 seconds...")
await asyncio.sleep(5)
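    # For reference, a trimmed 24hr-ticker payload from this stream looks roughly
    # like the following (field meanings per Binance's public stream docs):
    #   {"e": "24hrTicker", "E": 1716540000000, "s": "ETHUSDT",
    #    "c": "3100.25", "v": "125000.5", ...}
    # _process_tick below reads 'c' (last price) and 'v' (rolling 24h volume).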
async def _process_tick(self, symbol: str, tick_data: Dict):
"""Process a single tick and update candles"""
try:
            price = float(tick_data.get('c', 0))   # Last trade price
            volume = float(tick_data.get('v', 0))  # Rolling 24h base-asset volume
timestamp = pd.Timestamp.now()
# Update current price
with self.data_lock:
self.current_prices[symbol] = price
# Initialize real-time data structure if needed
if symbol not in self.real_time_data:
self.real_time_data[symbol] = {}
for tf in self.timeframes:
self.real_time_data[symbol][tf] = deque(maxlen=1000)
# Create tick record
tick = {
'timestamp': timestamp,
'price': price,
'volume': volume
}
# Update all timeframes
for timeframe in self.timeframes:
self._update_candle(symbol, timeframe, tick)
except Exception as e:
logger.error(f"Error processing tick for {symbol}: {e}")
def _update_candle(self, symbol: str, timeframe: str, tick: Dict):
"""Update candle for specific timeframe"""
try:
timeframe_secs = self.timeframe_seconds.get(timeframe, 3600)
current_time = tick['timestamp']
# Calculate candle start time
candle_start = current_time.floor(f'{timeframe_secs}s')
# Get current candle queue
candle_queue = self.real_time_data[symbol][timeframe]
# Check if we need a new candle
if not candle_queue or candle_queue[-1]['timestamp'] != candle_start:
# Create new candle
new_candle = {
'timestamp': candle_start,
'open': tick['price'],
'high': tick['price'],
'low': tick['price'],
'close': tick['price'],
'volume': tick['volume']
}
candle_queue.append(new_candle)
            else:
                # Update existing candle
                current_candle = candle_queue[-1]
                current_candle['high'] = max(current_candle['high'], tick['price'])
                current_candle['low'] = min(current_candle['low'], tick['price'])
                current_candle['close'] = tick['price']
                # 'volume' is the ticker's rolling 24h figure, so summing it per
                # tick would inflate the candle; keep the latest reading instead
                current_candle['volume'] = tick['volume']
except Exception as e:
logger.error(f"Error updating candle for {symbol} {timeframe}: {e}")
def get_latest_candles(self, symbol: str, timeframe: str, limit: int = 100) -> pd.DataFrame:
"""Get the latest candles combining historical and real-time data"""
try:
# Get historical data
historical_df = self.get_historical_data(symbol, timeframe, limit=limit)
# Get real-time data
with self.data_lock:
if symbol in self.real_time_data and timeframe in self.real_time_data[symbol]:
real_time_candles = list(self.real_time_data[symbol][timeframe])
if real_time_candles:
# Convert to DataFrame
rt_df = pd.DataFrame(real_time_candles)
if historical_df is not None:
# Combine historical and real-time
# Remove overlapping candles from historical data
if not rt_df.empty:
cutoff_time = rt_df['timestamp'].min()
historical_df = historical_df[historical_df['timestamp'] < cutoff_time]
# Concatenate
combined_df = pd.concat([historical_df, rt_df], ignore_index=True)
else:
combined_df = rt_df
return combined_df.tail(limit)
# Return just historical data if no real-time data
return historical_df.tail(limit) if historical_df is not None else pd.DataFrame()
except Exception as e:
logger.error(f"Error getting latest candles for {symbol} {timeframe}: {e}")
return pd.DataFrame()
def get_current_price(self, symbol: str) -> Optional[float]:
"""Get current price for a symbol from latest candle"""
try:
# Try to get from 1s candle first (most recent)
for tf in ['1s', '1m', '5m', '1h']:
df = self.get_latest_candles(symbol, tf, limit=1)
if df is not None and not df.empty:
return float(df.iloc[-1]['close'])
            # Fallback to any cached historical data
            tf = self.timeframes[0]
            if (symbol in self.historical_data and tf in self.historical_data[symbol]
                    and not self.historical_data[symbol][tf].empty):
                return float(self.historical_data[symbol][tf].iloc[-1]['close'])
logger.warning(f"No price data available for {symbol}")
return None
except Exception as e:
logger.error(f"Error getting current price for {symbol}: {e}")
return None
def get_price_at_index(self, symbol: str, index: int, timeframe: str = '1m') -> Optional[float]:
"""Get price at specific index for backtesting"""
try:
key = f"{symbol}_{timeframe}"
if key in self.historical_data:
df = self.historical_data[key]
if 0 <= index < len(df):
return float(df.iloc[index]['close'])
return None
except Exception as e:
logger.error(f"Error getting price at index {index}: {e}")
return None
def get_feature_matrix(self, symbol: str, timeframes: List[str] = None,
window_size: int = 20) -> Optional[np.ndarray]:
"""
Get comprehensive feature matrix for multiple timeframes with technical indicators
Returns:
np.ndarray: Shape (n_timeframes, window_size, n_features)
Each timeframe becomes a separate channel for CNN
"""
try:
if timeframes is None:
timeframes = self.timeframes
feature_channels = []
common_feature_names = None
# First pass: determine common features across all timeframes
timeframe_features = {}
for tf in timeframes:
logger.debug(f"Processing timeframe {tf} for {symbol}")
df = self.get_latest_candles(symbol, tf, limit=window_size + 100)
if df is None or len(df) < window_size:
logger.warning(f"Insufficient data for {symbol} {tf}: {len(df) if df is not None else 0} rows")
continue
# Get feature columns
basic_cols = ['open', 'high', 'low', 'close', 'volume']
indicator_cols = [col for col in df.columns
if col not in basic_cols + ['timestamp'] and not col.startswith('unnamed')]
selected_features = self._select_cnn_features(df, basic_cols, indicator_cols)
timeframe_features[tf] = (df, selected_features)
if common_feature_names is None:
common_feature_names = set(selected_features)
else:
common_feature_names = common_feature_names.intersection(set(selected_features))
if not common_feature_names:
logger.error(f"No common features found across timeframes for {symbol}")
return None
# Convert to sorted list for consistent ordering
common_feature_names = sorted(list(common_feature_names))
logger.info(f"Using {len(common_feature_names)} common features: {common_feature_names}")
# Second pass: create feature channels with common features
for tf in timeframes:
if tf not in timeframe_features:
continue
df, _ = timeframe_features[tf]
# Use only common features
try:
tf_features = self._normalize_features(df[common_feature_names].tail(window_size))
if tf_features is not None and len(tf_features) == window_size:
feature_channels.append(tf_features.values)
logger.debug(f"Added {len(common_feature_names)} features for {tf}")
else:
logger.warning(f"Feature normalization failed for {tf}")
except Exception as e:
logger.error(f"Error processing features for {tf}: {e}")
continue
if not feature_channels:
logger.error(f"No valid feature channels created for {symbol}")
return None
# Verify all channels have the same shape
shapes = [channel.shape for channel in feature_channels]
if len(set(shapes)) > 1:
logger.error(f"Shape mismatch in feature channels: {shapes}")
return None
# Stack all timeframe channels
feature_matrix = np.stack(feature_channels, axis=0)
logger.info(f"Created feature matrix for {symbol}: {feature_matrix.shape} "
f"({len(feature_channels)} timeframes, {window_size} steps, {len(common_feature_names)} features)")
return feature_matrix
except Exception as e:
logger.error(f"Error creating feature matrix for {symbol}: {e}")
import traceback
logger.error(traceback.format_exc())
return None
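    # Shape example (illustrative numbers): with timeframes ['1m', '1h', '4h'],
    # window_size=20, and roughly 40 common features surviving the intersection,
    # get_feature_matrix returns an array of shape (3, 20, 40) -- one CNN channel
    # per timeframe.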
def _select_cnn_features(self, df: pd.DataFrame, basic_cols: List[str], indicator_cols: List[str]) -> List[str]:
"""Select the most important features for CNN training"""
try:
selected = []
# Always include basic OHLCV (normalized)
selected.extend(basic_cols)
# Priority indicators (most informative for CNNs)
priority_indicators = [
# Trend indicators
'sma_10', 'sma_20', 'sma_50', 'ema_12', 'ema_26', 'ema_50',
'macd', 'macd_signal', 'macd_histogram',
'adx', 'adx_pos', 'adx_neg', 'psar',
# Momentum indicators
'rsi_14', 'rsi_7', 'rsi_21',
'stoch_k', 'stoch_d', 'williams_r', 'ultimate_osc',
# Volatility indicators
'bb_upper', 'bb_lower', 'bb_middle', 'bb_width', 'bb_percent',
'atr', 'keltner_upper', 'keltner_lower', 'keltner_middle',
# Volume indicators
'volume_sma_10', 'volume_sma_20', 'obv', 'vpt', 'mfi', 'ad_line', 'vwap',
# Price action
'price_position', 'true_range', 'roc',
# Custom composites
'trend_strength', 'momentum_composite', 'volatility_regime'
]
# Add available priority indicators
for indicator in priority_indicators:
if indicator in indicator_cols:
selected.append(indicator)
# Add any other technical indicators not in priority list (limit to avoid curse of dimensionality)
remaining_indicators = [col for col in indicator_cols if col not in selected]
if remaining_indicators:
# Limit to 10 additional indicators
selected.extend(remaining_indicators[:10])
# Verify all selected features exist in dataframe
final_selected = [col for col in selected if col in df.columns]
logger.debug(f"Selected {len(final_selected)} features from {len(df.columns)} available columns")
return final_selected
except Exception as e:
logger.error(f"Error selecting CNN features: {e}")
return basic_cols # Fallback to basic OHLCV
def _normalize_features(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
"""Normalize features for CNN training"""
try:
df_norm = df.copy()
# Handle different normalization strategies for different feature types
for col in df_norm.columns:
if col in ['open', 'high', 'low', 'close', 'sma_10', 'sma_20', 'sma_50',
'ema_12', 'ema_26', 'ema_50', 'bb_upper', 'bb_lower', 'bb_middle',
'keltner_upper', 'keltner_lower', 'keltner_middle', 'psar', 'vwap']:
# Price-based indicators: normalize by close price
if 'close' in df_norm.columns:
base_price = df_norm['close'].iloc[-1] # Use latest close as reference
if base_price > 0:
df_norm[col] = df_norm[col] / base_price
elif col == 'volume':
# Volume: normalize by its own rolling mean
volume_mean = df_norm[col].rolling(window=min(20, len(df_norm))).mean().iloc[-1]
if volume_mean > 0:
df_norm[col] = df_norm[col] / volume_mean
elif col in ['rsi_14', 'rsi_7', 'rsi_21']:
# RSI: already 0-100, normalize to 0-1
df_norm[col] = df_norm[col] / 100.0
elif col in ['stoch_k', 'stoch_d']:
# Stochastic: already 0-100, normalize to 0-1
df_norm[col] = df_norm[col] / 100.0
elif col == 'williams_r':
# Williams %R: -100 to 0, normalize to 0-1
df_norm[col] = (df_norm[col] + 100) / 100.0
elif col in ['macd', 'macd_signal', 'macd_histogram']:
# MACD: normalize by ATR or close price
if 'atr' in df_norm.columns and df_norm['atr'].iloc[-1] > 0:
df_norm[col] = df_norm[col] / df_norm['atr'].iloc[-1]
elif 'close' in df_norm.columns:
df_norm[col] = df_norm[col] / df_norm['close'].iloc[-1]
elif col in ['bb_width', 'bb_percent', 'price_position', 'trend_strength',
'momentum_composite', 'volatility_regime']:
# Already normalized indicators: ensure 0-1 range
df_norm[col] = np.clip(df_norm[col], 0, 1)
elif col in ['atr', 'true_range']:
# Volatility indicators: normalize by close price
if 'close' in df_norm.columns:
df_norm[col] = df_norm[col] / df_norm['close'].iloc[-1]
else:
# Other indicators: z-score normalization
col_mean = df_norm[col].rolling(window=min(20, len(df_norm))).mean().iloc[-1]
col_std = df_norm[col].rolling(window=min(20, len(df_norm))).std().iloc[-1]
if col_std > 0:
df_norm[col] = (df_norm[col] - col_mean) / col_std
else:
df_norm[col] = 0
# Replace inf/-inf with 0
df_norm = df_norm.replace([np.inf, -np.inf], 0)
# Fill any remaining NaN values
df_norm = df_norm.fillna(0)
return df_norm
except Exception as e:
logger.error(f"Error normalizing features: {e}")
return df
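    # Worked example of the rules above, assuming the latest close is 50000:
    # sma_20 = 49500 -> 49500 / 50000 = 0.99 (price-relative), rsi_14 = 65 ->
    # 0.65 (scaled from 0-100), and williams_r = -20 -> (-20 + 100) / 100 = 0.80.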
def get_multi_symbol_feature_matrix(self, symbols: List[str] = None,
timeframes: List[str] = None,
window_size: int = 20) -> Optional[np.ndarray]:
"""
Get feature matrix for multiple symbols and timeframes
Returns:
np.ndarray: Shape (n_symbols, n_timeframes, window_size, n_features)
"""
try:
if symbols is None:
symbols = self.symbols
if timeframes is None:
timeframes = self.timeframes
symbol_matrices = []
for symbol in symbols:
symbol_matrix = self.get_feature_matrix(symbol, timeframes, window_size)
if symbol_matrix is not None:
symbol_matrices.append(symbol_matrix)
else:
logger.warning(f"Could not create feature matrix for {symbol}")
if symbol_matrices:
# Stack all symbol matrices
multi_symbol_matrix = np.stack(symbol_matrices, axis=0)
logger.info(f"Created multi-symbol feature matrix: {multi_symbol_matrix.shape}")
return multi_symbol_matrix
return None
except Exception as e:
logger.error(f"Error creating multi-symbol feature matrix: {e}")
return None
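    # Shape example: two symbols each yielding the (3, 20, 40) per-symbol matrix
    # from above stack into (2, 3, 20, 40). Every symbol must share the same
    # channel and feature counts, otherwise np.stack raises a ValueError.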
def health_check(self) -> Dict[str, Any]:
"""Get health status of the data provider"""
status = {
'streaming': self.is_streaming,
'symbols': len(self.symbols),
'timeframes': len(self.timeframes),
'current_prices': len(self.current_prices),
'websocket_tasks': len(self.websocket_tasks),
'historical_data_loaded': {}
}
# Check historical data availability
for symbol in self.symbols:
status['historical_data_loaded'][symbol] = {}
for tf in self.timeframes:
has_data = (symbol in self.historical_data and
tf in self.historical_data[symbol] and
not self.historical_data[symbol][tf].empty)
status['historical_data_loaded'][symbol][tf] = has_data
return status
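
# --- Minimal usage sketch ----------------------------------------------------
# A hedged example of the intended workflow. It assumes core.config.get_config()
# resolves (so the provider can read cache settings) and that Binance's public
# REST API is reachable; the symbol and timeframes below are illustrative.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    provider = DataProvider(symbols=['ETH/USDT'], timeframes=['1m', '1h'])

    # Historical candles with indicators (cached to parquet when enabled)
    df = provider.get_historical_data('ETH/USDT', '1h', limit=500)
    if df is not None:
        print(df[['timestamp', 'close', 'rsi_14', 'macd']].tail())

    # CNN-ready tensor: (n_timeframes, window_size, n_features)
    matrix = provider.get_feature_matrix('ETH/USDT', window_size=20)
    if matrix is not None:
        print(f"Feature matrix shape: {matrix.shape}")

    print(provider.health_check())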