# Source: gogo2/NN/realtime_data_interface.py
# Last commit: 73c5ecb0d2 "enhancements" by Dobromir Popov, 2025-04-01 13:46:53 +03:00
# (Original listing: 241 lines, 9.7 KiB, Python)

import logging
import numpy as np
import pandas as pd
import time
from typing import Dict, Any, List, Optional, Tuple
# Module-level logger used by RealtimeDataInterface for all diagnostics.
logger = logging.getLogger(name=__name__)
class RealtimeDataInterface:
    """Interface for retrieving real-time market data for neural network models.

    This class serves as a bridge between the RealTimeChart data sources and
    the neural network models, providing properly formatted data for model
    inference.
    """

    # Seconds per timeframe unit suffix, e.g. '5m' -> 5 * 60. A timeframe
    # without a recognised suffix is parsed as a plain number of seconds.
    _UNIT_SECONDS = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400, 'w': 604800}

    # Feature columns, in the order they appear along the last axis of the
    # array produced by prepare_model_input().
    _FEATURE_COLUMNS = ('open', 'high', 'low', 'close', 'volume')

    def __init__(self, symbols: List[str], chart=None, max_cache_size: int = 5000):
        """Initialize the data interface.

        Args:
            symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT']).
                The first entry is used as the default symbol.
            chart: RealTimeChart instance (optional). Data is fetched through
                its ``_get_chart_data(interval_seconds)`` method.
            max_cache_size: Maximum number of cached candles.
                NOTE(review): stored but not read anywhere in this class.
        """
        self.symbols = symbols
        self.chart = chart
        self.max_cache_size = max_cache_size
        # Intended layout: timeframe -> symbol -> DataFrame.
        # NOTE(review): not populated or read anywhere in this class.
        self.ohlcv_cache = {}
        logger.info(f"Initialized RealtimeDataInterface with symbols: {', '.join(symbols)}")

    def get_historical_data(self, symbol: Optional[str] = None, timeframe: str = '1h',
                            n_candles: int = 500) -> Optional[pd.DataFrame]:
        """Get historical OHLCV data for a symbol and timeframe.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT'). Defaults to the first
                tracked symbol when omitted or empty.
            timeframe: Time interval (e.g., '1m', '5m', '1h').
            n_candles: Number of most recent candles to retrieve.

        Returns:
            The chart's DataFrame when it has rows. Otherwise an empty
            DataFrame with columns timestamp/open/high/low/close/volume, or
            None if no symbol is available, the symbol is not tracked, or an
            error occurs.
        """
        if not symbol:
            if not self.symbols:
                logger.error("No symbol specified and no default symbols available")
                return None
            symbol = self.symbols[0]

        if symbol not in self.symbols:
            logger.warning(f"Symbol {symbol} not in tracked symbols")
            return None

        try:
            if self.chart:
                candles = self._get_chart_data(symbol, timeframe, n_candles)
                if candles is not None and len(candles) > 0:
                    return candles

            # Fallback: an empty frame with the standard OHLCV schema.
            logger.warning(f"No historical data available for {symbol} at timeframe {timeframe}")
            return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
        except Exception as e:
            logger.error(f"Error getting historical data for {symbol}: {str(e)}")
            return None

    @classmethod
    def _timeframe_to_seconds(cls, timeframe) -> Optional[int]:
        """Convert a timeframe such as '5m' or '1h' into a number of seconds.

        Recognised (case-sensitive) suffixes are s, m, h, d and w. A value
        without one of these suffixes (e.g. '90', or the int 90) is taken as
        seconds.

        Returns:
            The interval in seconds, or None if the value cannot be parsed or
            is not strictly positive. Nothing is logged here; callers decide.
        """
        text = str(timeframe).strip()
        multiplier = cls._UNIT_SECONDS.get(text[-1:])
        try:
            if multiplier is None:
                seconds = int(text)
            else:
                seconds = int(text[:-1]) * multiplier
        except ValueError:
            return None
        # A zero or negative interval is meaningless for candle aggregation.
        return seconds if seconds > 0 else None

    def _get_chart_data(self, symbol: str, timeframe: str, n_candles: int) -> Optional[pd.DataFrame]:
        """Get data from the RealTimeChart for the specified symbol and timeframe.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT'). Only used in log
                messages, because the chart is queried by interval alone.
                NOTE(review): presumably the chart serves a single symbol; confirm.
            timeframe: Time interval (e.g., '1m', '5m', '1h').
            n_candles: Number of most recent candles to keep. Values <= 0
                yield an empty DataFrame.

        Returns:
            DataFrame with at most ``n_candles`` rows, or None if no chart is
            set, the timeframe is invalid, the chart returned no data, or an
            error occurred.
        """
        if not self.chart:
            return None

        interval_seconds = self._timeframe_to_seconds(timeframe)
        if interval_seconds is None:
            logger.error(f"Could not parse timeframe: {timeframe}")
            return None

        try:
            df = self.chart._get_chart_data(interval_seconds)
            if df is None or df.empty:
                logger.warning(f"No data retrieved from chart for {symbol} at timeframe {timeframe}")
                return None
            # tail() instead of iloc[-n:]: iloc[-0:] would return every row.
            return df.tail(max(n_candles, 0))
        except Exception as e:
            logger.error(f"Error getting chart data for {symbol} at {timeframe}: {str(e)}")
            return None

    @staticmethod
    def _to_epoch_seconds(value: Any) -> int:
        """Convert a candle timestamp value to an ``int``.

        Datetime-like values (``pd.Timestamp``, ``np.datetime64``,
        ``datetime.datetime``) become POSIX seconds. Pandas treats naive
        values as UTC. Any other value goes through ``int()`` unchanged, so
        numeric timestamps keep whatever unit the data source used.
        """
        from datetime import datetime

        if isinstance(value, (pd.Timestamp, np.datetime64, datetime)):
            return int(pd.Timestamp(value).timestamp())
        return int(value)

    def prepare_model_input(self, data: pd.DataFrame, window_size: int = 20,
                            symbol: Optional[str] = None) -> Tuple[Optional[np.ndarray], Optional[int]]:
        """Prepare model input from OHLCV data.

        The last ``window_size`` rows are turned into a float64 array of shape
        ``(1, window_size, 5)``. Features are ordered open, high, low, close,
        volume.

        Normalisation:
        - Prices are divided by the most recent close (1.0 is used if that
          close is 0).
        - Volume is divided by the window's maximum volume (1.0 if that
          maximum is 0).
        - Afterwards NaN becomes 0.0, +inf becomes 1.0 and -inf becomes 0.0.

        Args:
            data: DataFrame with open/high/low/close/volume columns and,
                optionally, a ``timestamp`` column.
            window_size: Number of most recent rows to use. Must be positive.
            symbol: Symbol for the data (only used for logging).

        Returns:
            tuple: ``(X, timestamp)``.
            - ``timestamp`` comes from the last row's ``timestamp`` column.
              Datetime-like values become POSIX seconds; numeric values go
              through ``int()``.
            - Without a ``timestamp`` column, ``int(time.time())`` is used.
            - ``(None, None)`` is returned when the data is insufficient or
              malformed.
        """
        label = symbol or 'unknown symbol'
        if data is None or window_size <= 0 or len(data) < window_size:
            logger.warning(f"Not enough data to prepare model input for {label}")
            return None, None

        try:
            recent_data = data.iloc[-window_size:]

            if not set(self._FEATURE_COLUMNS).issubset(recent_data.columns):
                logger.error(f"Data missing required OHLCV columns for {label}")
                return None, None

            if 'timestamp' in recent_data.columns:
                timestamp = self._to_epoch_seconds(recent_data['timestamp'].iloc[-1])
            else:
                timestamp = int(time.time())

            # Cast up front so int/object columns divide cleanly.
            features = recent_data[list(self._FEATURE_COLUMNS)].astype(np.float64)

            last_close = features['close'].iloc[-1]
            if last_close == 0:  # avoid division by zero
                last_close = 1.0
            # pandas max() skips NaN, so one missing volume does not blank the window.
            max_volume = features['volume'].max()
            if max_volume == 0:
                max_volume = 1.0

            prices = features[['open', 'high', 'low', 'close']].to_numpy() / last_close
            volumes = features['volume'].to_numpy() / max_volume

            # [batch_size=1, window_size, n_features=5]
            X = np.column_stack((prices, volumes)).reshape(1, window_size, 5)
            X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=0.0)
            return X, timestamp
        except Exception as e:
            logger.error(f"Error preparing model input for {label}: {str(e)}")
            return None, None

    def prepare_realtime_input(self, timeframe: str = '1h', n_candles: int = 30,
                               window_size: int = 20) -> Tuple[Optional[np.ndarray], Optional[int]]:
        """Prepare real-time input for the model from the first tracked symbol.

        Args:
            timeframe: Time interval (e.g., '1m', '5m', '1h').
            n_candles: Number of candles to retrieve. Must be at least
                ``window_size``, otherwise there is never enough data.
            window_size: Window size for model input.

        Returns:
            tuple: ``(X, timestamp)`` as produced by prepare_model_input(), or
            ``(None, None)`` when no symbols are tracked, too little data is
            available, or an error occurs.
        """
        if not self.symbols:
            logger.error("No symbols available for real-time input")
            return None, None

        symbol = self.symbols[0]
        try:
            data = self.get_historical_data(symbol, timeframe, n_candles)
            if data is None or len(data) < window_size:
                logger.warning(f"Not enough data for real-time input. Need at least {window_size} candles.")
                return None, None
            return self.prepare_model_input(data, window_size, symbol)
        except Exception as e:
            logger.error(f"Error preparing real-time input: {str(e)}")
            return None, None