RL training: add TradingEnvironment, package exports, and a consolidated training-data helper
NN/utils/__init__.py
@@ -6,6 +6,8 @@ This package contains utility functions and classes used in the neural network t
 - Data Interface: Connects to realtime trading data and processes it for the neural network models
 """

-from NN.utils.data_interface import DataInterface
+from .data_interface import DataInterface
+from .trading_env import TradingEnvironment
+from .signal_interpreter import SignalInterpreter

-__all__ = ['DataInterface']
+__all__ = ['DataInterface', 'TradingEnvironment', 'SignalInterpreter']
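With these exports in place, downstream code can pull all three public classes from the package root; a quick sketch:

# Package-level imports available after this change:
from NN.utils import DataInterface, TradingEnvironment, SignalInterpreter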
NN/utils/data_interface.py
@@ -13,6 +13,7 @@ import json
 import pickle
 from sklearn.preprocessing import MinMaxScaler
 import sys
+import pandas_ta as ta  # functional accessors (ta.rsi, ta.macd, ...) and the column names used below follow pandas_ta

 # Add project root to sys.path
 project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -534,3 +535,77 @@ class DataInterface:
        timestamp = df['timestamp'].iloc[-1]

        return X, timestamp

    def get_training_data(self, timeframe='1m', n_candles=5000):
        """
        Get a consolidated dataframe for RL training with OHLCV and technical indicators

        Args:
            timeframe (str): Timeframe to use
            n_candles (int): Number of candles to fetch

        Returns:
            DataFrame: Combined dataframe with price data and technical indicators
        """
        # Get historical data
        df = self.get_historical_data(timeframe=timeframe, n_candles=n_candles, use_cache=True)

        if df is None or len(df) < 100:  # Minimum required for indicators
            logger.error("Not enough data for RL training (need at least 100 candles)")
            return None

        # Calculate technical indicators
        try:
            # Add RSI (14)
            df['rsi'] = ta.rsi(df['close'], length=14)

            # Add MACD
            macd = ta.macd(df['close'])
            df['macd'] = macd['MACD_12_26_9']
            df['macd_signal'] = macd['MACDs_12_26_9']
            df['macd_hist'] = macd['MACDh_12_26_9']

            # Add Bollinger Bands
            bbands = ta.bbands(df['close'], length=20)
            df['bb_upper'] = bbands['BBU_20_2.0']
            df['bb_middle'] = bbands['BBM_20_2.0']
            df['bb_lower'] = bbands['BBL_20_2.0']

            # Add ATR (Average True Range)
            df['atr'] = ta.atr(df['high'], df['low'], df['close'], length=14)

            # Add moving averages
            df['sma_20'] = ta.sma(df['close'], length=20)
            df['sma_50'] = ta.sma(df['close'], length=50)
            df['ema_20'] = ta.ema(df['close'], length=20)

            # Add OBV (On-Balance Volume)
            df['obv'] = ta.obv(df['close'], df['volume'])

            # Add momentum indicator
            df['mom'] = ta.mom(df['close'], length=10)

            # Normalize price relative to the previous close
            df['close_norm'] = df['close'] / df['close'].shift(1) - 1
            df['high_norm'] = df['high'] / df['close'].shift(1) - 1
            df['low_norm'] = df['low'] / df['close'].shift(1) - 1

            # Volatility features
            df['volatility'] = df['high'] / df['low'] - 1

            # Volume features
            df['volume_norm'] = df['volume'] / df['volume'].rolling(20).mean()

            # Calculate returns
            df['returns_1'] = df['close'].pct_change(1)
            df['returns_5'] = df['close'].pct_change(5)
            df['returns_10'] = df['close'].pct_change(10)

        except Exception as e:
            logger.error(f"Error calculating technical indicators: {str(e)}")
            return None

        # Drop rows with NaN values introduced by indicator warm-up periods
        df = df.dropna()

        return df
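A minimal sketch of how this helper is expected to feed the new environment below; the `DataInterface` constructor argument shown is an assumption, not part of this commit:

# Hypothetical wiring: indicator-enriched data -> RL environment.
from NN.utils import DataInterface, TradingEnvironment

data_interface = DataInterface(symbol='BTC/USDT')  # assumed constructor signature
df = data_interface.get_training_data(timeframe='1m', n_candles=5000)
if df is not None:
    # Drop non-numeric columns (e.g. 'timestamp') first, since the environment
    # feeds raw row values straight to the agent as float features.
    env = TradingEnvironment(df.drop(columns=['timestamp']), initial_balance=100.0)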

NN/utils/trading_env.py (new file, 162 lines)
@@ -0,0 +1,162 @@
import numpy as np
import gym
from gym import spaces
from typing import Dict, Tuple
import pandas as pd


class TradingEnvironment(gym.Env):
    """
    Custom trading environment for reinforcement learning
    """
    def __init__(self,
                 data: pd.DataFrame,
                 initial_balance: float = 100.0,
                 fee_rate: float = 0.0002,
                 max_steps: int = 1000):
        super().__init__()

        self.data = data
        self.initial_balance = initial_balance
        self.fee_rate = fee_rate
        self.max_steps = max_steps

        # Action space: 0 (SELL), 1 (HOLD), 2 (BUY)
        self.action_space = spaces.Discrete(3)

        # Observation space: one row of the input dataframe (price data and technical indicators)
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(data.shape[1],),  # Number of features
            dtype=np.float32
        )

        # Initialize state
        self.reset()

    def reset(self) -> np.ndarray:
        """Reset the environment to its initial state"""
        self.current_step = 0
        self.balance = self.initial_balance
        self.position = 0  # 0: flat, 1: long, -1: short
        self.entry_price = 0
        self.total_trades = 0
        self.winning_trades = 0
        self.total_pnl = 0
        self.balance_history = [self.initial_balance]
        self.max_balance = self.initial_balance

        return self._get_observation()

    def _get_observation(self) -> np.ndarray:
        """Get the current observation state"""
        # Cast to float32 to match the declared observation space dtype
        return self.data.iloc[self.current_step].values.astype(np.float32)
    def _calculate_reward(self, action: int) -> float:
        """Calculate reward based on the action and its outcome"""
        current_price = self.data.iloc[self.current_step]['close']

        # If we have an open position
        if self.position != 0:
            # Calculate PnL of the open position
            pnl = self.position * (current_price - self.entry_price) / self.entry_price
            fees = self.fee_rate * 2  # Entry and exit fees

            # Close position: SELL closes a long, BUY closes a short
            if (action == 0 and self.position > 0) or (action == 2 and self.position < 0):
                net_pnl = pnl - fees
                self.total_pnl += net_pnl
                self.balance *= (1 + net_pnl)
                self.balance_history.append(self.balance)
                self.max_balance = max(self.max_balance, self.balance)

                self.total_trades += 1
                if net_pnl > 0:
                    self.winning_trades += 1

                # Reward based on PnL
                reward = net_pnl * 100  # Scale up for better learning

                # Additional reward for win rate
                win_rate = self.winning_trades / max(1, self.total_trades)
                reward += win_rate * 0.1

                self.position = 0
                return reward

            # Hold position: small reward (or penalty) proportional to unrealized PnL
            return pnl * 0.1

        # No position
        if action == 1:  # HOLD
            return 0

        # Open a new position
        if action in [0, 2]:  # SELL (short) or BUY (long)
            self.position = -1 if action == 0 else 1
            self.entry_price = current_price
            return -self.fee_rate  # Small penalty for trading

        return 0
    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
        """Execute one step in the environment"""
        # Calculate reward
        reward = self._calculate_reward(action)

        # Move to next step
        self.current_step += 1

        # Check if episode is done
        done = self.current_step >= min(self.max_steps - 1, len(self.data) - 1)

        # Get next observation
        observation = self._get_observation()

        # Additional info
        info = {
            'balance': self.balance,
            'position': self.position,
            'total_trades': self.total_trades,
            'win_rate': self.winning_trades / max(1, self.total_trades),
            'total_pnl': self.total_pnl,
            'max_drawdown': self._calculate_max_drawdown()
        }

        return observation, reward, done, info
    def render(self, mode='human'):
        """Render the environment"""
        if mode == 'human':
            print(f"Step: {self.current_step}")
            print(f"Balance: ${self.balance:.2f}")
            print(f"Position: {self.position}")
            print(f"Total Trades: {self.total_trades}")
            print(f"Win Rate: {self.winning_trades / max(1, self.total_trades):.2%}")
            print(f"Total PnL: ${self.total_pnl:.2f}")
            print(f"Max Drawdown: {self._calculate_max_drawdown():.2%}")
            print("-" * 50)
    def _calculate_max_drawdown(self):
        """Calculate maximum drawdown from the balance history"""
        if len(self.balance_history) <= 1:
            return 0.0

        peak = self.balance_history[0]
        max_drawdown = 0.0

        for balance in self.balance_history:
            peak = max(peak, balance)
            drawdown = (peak - balance) / peak
            max_drawdown = max(max_drawdown, drawdown)

        return max_drawdown
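As a smoke test, a short random-agent rollout against the legacy gym 4-tuple step API used above; the synthetic random-walk dataframe is purely illustrative:

# Hypothetical smoke test: drive the environment with random actions.
# Real usage would pass the output of DataInterface.get_training_data()
# instead of this synthetic dataframe.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
prices = 100 * np.exp(np.cumsum(rng.normal(0, 0.001, 500)))
df = pd.DataFrame({
    'close': prices,
    'high': prices * 1.001,
    'low': prices * 0.999,
    'volume': rng.uniform(1.0, 10.0, 500),
})

env = TradingEnvironment(df, initial_balance=100.0, max_steps=200)
obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # random policy stand-in
    obs, reward, done, info = env.step(action)
print(info['balance'], info['win_rate'], info['max_drawdown'])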