# File listing metadata: 241 lines, 9.7 KiB, Python
import logging
import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class RealtimeDataInterface:
    """Interface for retrieving real-time market data for neural network models.

    This class serves as a bridge between the RealTimeChart data sources and
    the neural network models, providing properly formatted data for model
    inference.
    """

    # Seconds represented by each supported timeframe suffix.
    _UNIT_SECONDS = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400, 'w': 604800}

    # Price columns are normalized by the last close; volume by the window max.
    # The tuple order is the feature order of the model input.
    _PRICE_COLUMNS = ('open', 'high', 'low', 'close')
    _OHLCV_COLUMNS = _PRICE_COLUMNS + ('volume',)

    def __init__(self, symbols: List[str], chart: Optional[Any] = None,
                 max_cache_size: int = 5000):
        """Initialize the data interface.

        Args:
            symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
            chart: RealTimeChart instance (optional)
            max_cache_size: Maximum number of cached candles
        """
        self.symbols = symbols
        self.chart = chart
        self.max_cache_size = max_cache_size

        # timeframe -> symbol -> DataFrame.
        # NOTE(review): neither this cache nor max_cache_size is read or
        # written by the methods defined in this class.
        self.ohlcv_cache: Dict[str, Dict[str, pd.DataFrame]] = {}

        logger.info("Initialized RealtimeDataInterface with symbols: %s", ', '.join(symbols))

    @classmethod
    def _timeframe_to_seconds(cls, timeframe: str) -> Optional[int]:
        """Convert a timeframe string into a positive number of seconds.

        Accepts '<n>s', '<n>m', '<n>h', '<n>d', '<n>w', or a bare integer
        string (interpreted as seconds).

        Returns:
            The interval in seconds, or None if the text cannot be parsed or
            the resulting interval is not positive.
        """
        unit = cls._UNIT_SECONDS.get(timeframe[-1:])
        try:
            if unit is None:
                seconds = int(timeframe)
            else:
                seconds = int(timeframe[:-1]) * unit
        except ValueError:
            return None
        # A zero or negative candle interval is meaningless; treat as invalid.
        return seconds if seconds > 0 else None

    def get_historical_data(self, symbol: Optional[str] = None, timeframe: str = '1h',
                            n_candles: int = 500) -> Optional[pd.DataFrame]:
        """Get historical OHLCV data for a symbol and timeframe.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT'). Defaults to the first
                tracked symbol when omitted.
            timeframe: Time interval (e.g., '1m', '5m', '1h')
            n_candles: Number of candles to retrieve

        Returns:
            DataFrame with OHLCV data. An empty DataFrame with the OHLCV
            columns is returned when the chart yields nothing; None is
            returned when the symbol is unknown/unavailable or an error occurs.
        """
        if not symbol:
            if not self.symbols:
                logger.error("No symbol specified and no default symbols available")
                return None
            symbol = self.symbols[0]

        if symbol not in self.symbols:
            logger.warning("Symbol %s not in tracked symbols", symbol)
            return None

        try:
            # Get data from chart if available
            if self.chart is not None:
                candles = self._get_chart_data(symbol, timeframe, n_candles)
                if candles is not None and len(candles) > 0:
                    return candles

            # Fallback to default empty DataFrame
            logger.warning("No historical data available for %s at timeframe %s",
                           symbol, timeframe)
            return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
        except Exception as e:
            logger.error("Error getting historical data for %s: %s", symbol, e)
            return None

    def _get_chart_data(self, symbol: str, timeframe: str,
                        n_candles: int) -> Optional[pd.DataFrame]:
        """Get data from the RealTimeChart for the specified symbol and timeframe.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            timeframe: Time interval (e.g., '1m', '5m', '1h')
            n_candles: Number of candles to retrieve

        Returns:
            DataFrame with at most the last ``n_candles`` rows, or None if no
            chart is attached, the timeframe is invalid, ``n_candles`` is not
            positive, the chart returned no data, or an error occurred.
        """
        if self.chart is None:
            return None

        # The chart object is opaque here, so any failure is logged and
        # reported as "no data" rather than propagated.
        try:
            if n_candles <= 0:
                # iloc[-0:] would select the entire frame, so reject up front.
                logger.warning("Non-positive candle count requested for %s: %s",
                               symbol, n_candles)
                return None

            interval_seconds = self._timeframe_to_seconds(timeframe)
            if interval_seconds is None:
                logger.error("Could not parse timeframe: %s", timeframe)
                return None

            # NOTE(review): only the interval is passed to the chart; `symbol`
            # is used for logging only. Presumably the chart serves a single
            # symbol -- confirm against RealTimeChart.
            df = self.chart._get_chart_data(interval_seconds)

            if df is None or df.empty:
                logger.warning("No data retrieved from chart for %s at timeframe %s",
                               symbol, timeframe)
                return None

            # Limit to requested number of candles
            if len(df) > n_candles:
                df = df.iloc[-n_candles:]
            return df
        except Exception as e:
            logger.error("Error getting chart data for %s at %s: %s", symbol, timeframe, e)
            return None

    def prepare_model_input(self, data: pd.DataFrame, window_size: int = 20,
                            symbol: Optional[str] = None
                            ) -> Tuple[Optional[np.ndarray], Optional[int]]:
        """Prepare model input from OHLCV data.

        Prices of the last ``window_size`` rows are divided by the final close
        and volumes by the window's maximum volume (a divisor of 0 is replaced
        by 1.0). NaN/-inf become 0.0 and +inf becomes 1.0.

        Args:
            data: DataFrame with OHLCV data
            window_size: Window size for model input
            symbol: Symbol for the data (for logging)

        Returns:
            tuple: (X, timestamp) where X has shape (1, window_size, 5) with
            features ordered open, high, low, close, volume, and timestamp is
            taken from the last row's 'timestamp' column (pandas Timestamps
            are converted to epoch seconds) or the current time in seconds if
            that column is absent. (None, None) on insufficient or invalid data.
        """
        label = symbol or 'unknown symbol'

        if window_size <= 0:
            logger.warning("Invalid window size %s for %s", window_size, label)
            return None, None

        if data is None or len(data) < window_size:
            logger.warning("Not enough data to prepare model input for %s", label)
            return None, None

        try:
            if any(column not in data.columns for column in self._OHLCV_COLUMNS):
                logger.error("Data missing required OHLCV columns for %s", label)
                return None, None

            # Get last window_size candles (read-only, so no copy is needed)
            recent = data.iloc[-window_size:]

            # Get timestamp of the most recent candle
            if 'timestamp' in recent.columns:
                raw_timestamp = recent['timestamp'].iloc[-1]
                if isinstance(raw_timestamp, pd.Timestamp):
                    # Epoch seconds, matching the time.time() fallback below.
                    timestamp = int(raw_timestamp.timestamp())
                else:
                    timestamp = int(raw_timestamp)
            else:
                timestamp = int(time.time())

            # Normalize price data by the last close price (avoid division by zero)
            last_close = recent['close'].iloc[-1]
            if last_close == 0:
                last_close = 1.0

            # Normalize volume by the max volume in the window
            max_volume = recent['volume'].max()
            if max_volume == 0:
                max_volume = 1.0

            features = [(recent[column] / last_close).to_numpy()
                        for column in self._PRICE_COLUMNS]
            features.append((recent['volume'] / max_volume).to_numpy())

            # Stack features into a 3D array [batch_size=1, window_size, n_features=5]
            X = np.column_stack(features).reshape(1, window_size, len(self._OHLCV_COLUMNS))

            # Replace any NaN or infinite values
            X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=0.0)
            return X, timestamp
        except Exception as e:
            logger.error("Error preparing model input for %s: %s", label, e)
            return None, None

    def prepare_realtime_input(self, timeframe: str = '1h', n_candles: int = 30,
                               window_size: int = 20
                               ) -> Tuple[Optional[np.ndarray], Optional[int]]:
        """Prepare real-time input for the model.

        Uses the first tracked symbol as the data source.

        Args:
            timeframe: Time interval (e.g., '1m', '5m', '1h')
            n_candles: Number of candles to retrieve
            window_size: Window size for model input

        Returns:
            tuple: (X, timestamp) where X is the model input and timestamp is
            the latest timestamp, or (None, None) if no symbols are tracked,
            too few candles are available, or an error occurs.
        """
        if not self.symbols:
            logger.error("No symbols available for real-time input")
            return None, None

        # Get data for the main symbol
        symbol = self.symbols[0]

        try:
            data = self.get_historical_data(symbol, timeframe, n_candles)

            if data is None or len(data) < window_size:
                logger.warning("Not enough data for real-time input. Need at least %s candles.",
                               window_size)
                return None, None

            return self.prepare_model_input(data, window_size, symbol)
        except Exception as e:
            logger.error("Error preparing real-time input: %s", e)
            return None, None