RL training
This commit is contained in:
@@ -13,6 +13,7 @@ import json
|
||||
import pickle
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
import sys
|
||||
import ta
|
||||
|
||||
# Add project root to sys.path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -534,3 +535,77 @@ class DataInterface:
|
||||
timestamp = df['timestamp'].iloc[-1]
|
||||
|
||||
return X, timestamp
|
||||
|
||||
def get_training_data(self, timeframe='1m', n_candles=5000):
|
||||
"""
|
||||
Get a consolidated dataframe for RL training with OHLCV and technical indicators
|
||||
|
||||
Args:
|
||||
timeframe (str): Timeframe to use
|
||||
n_candles (int): Number of candles to fetch
|
||||
|
||||
Returns:
|
||||
DataFrame: Combined dataframe with price data and technical indicators
|
||||
"""
|
||||
# Get historical data
|
||||
df = self.get_historical_data(timeframe=timeframe, n_candles=n_candles, use_cache=True)
|
||||
|
||||
if df is None or len(df) < 100: # Minimum required for indicators
|
||||
logger.error(f"Not enough data for RL training (need at least 100 candles)")
|
||||
return None
|
||||
|
||||
# Calculate technical indicators
|
||||
try:
|
||||
# Add RSI (14)
|
||||
df['rsi'] = ta.rsi(df['close'], length=14)
|
||||
|
||||
# Add MACD
|
||||
macd = ta.macd(df['close'])
|
||||
df['macd'] = macd['MACD_12_26_9']
|
||||
df['macd_signal'] = macd['MACDs_12_26_9']
|
||||
df['macd_hist'] = macd['MACDh_12_26_9']
|
||||
|
||||
# Add Bollinger Bands
|
||||
bbands = ta.bbands(df['close'], length=20)
|
||||
df['bb_upper'] = bbands['BBU_20_2.0']
|
||||
df['bb_middle'] = bbands['BBM_20_2.0']
|
||||
df['bb_lower'] = bbands['BBL_20_2.0']
|
||||
|
||||
# Add ATR (Average True Range)
|
||||
df['atr'] = ta.atr(df['high'], df['low'], df['close'], length=14)
|
||||
|
||||
# Add moving averages
|
||||
df['sma_20'] = ta.sma(df['close'], length=20)
|
||||
df['sma_50'] = ta.sma(df['close'], length=50)
|
||||
df['ema_20'] = ta.ema(df['close'], length=20)
|
||||
|
||||
# Add OBV (On-Balance Volume)
|
||||
df['obv'] = ta.obv(df['close'], df['volume'])
|
||||
|
||||
# Add momentum indicators
|
||||
df['mom'] = ta.mom(df['close'], length=10)
|
||||
|
||||
# Normalize price to previous close
|
||||
df['close_norm'] = df['close'] / df['close'].shift(1) - 1
|
||||
df['high_norm'] = df['high'] / df['close'].shift(1) - 1
|
||||
df['low_norm'] = df['low'] / df['close'].shift(1) - 1
|
||||
|
||||
# Volatility features
|
||||
df['volatility'] = df['high'] / df['low'] - 1
|
||||
|
||||
# Volume features
|
||||
df['volume_norm'] = df['volume'] / df['volume'].rolling(20).mean()
|
||||
|
||||
# Calculate returns
|
||||
df['returns_1'] = df['close'].pct_change(1)
|
||||
df['returns_5'] = df['close'].pct_change(5)
|
||||
df['returns_10'] = df['close'].pct_change(10)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating technical indicators: {str(e)}")
|
||||
return None
|
||||
|
||||
# Drop NaN values
|
||||
df = df.dropna()
|
||||
|
||||
return df
|
||||
|
||||
Reference in New Issue
Block a user