# gogo2/NN/train_rl.py
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import logging
import time
from datetime import datetime
import os
import sys
import pandas as pd
import gym
import json
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from NN.utils.data_interface import DataInterface
from NN.utils.trading_env import TradingEnvironment
from NN.models.dqn_agent import DQNAgent
from NN.utils.signal_interpreter import SignalInterpreter
# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('rl_training.log'),
        logging.StreamHandler()
    ]
)


class RLTradingEnvironment(gym.Env):
    """
    Reinforcement Learning environment for trading with technical indicators
    from multiple timeframes (1m, 5m, 15m).
    """

    def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.0025, min_trade_interval=15):
        super().__init__()

        # Feature dimensions (the close price in the last column is excluded
        # from observations; it is only used for trade accounting)
        self.window_size = window_size
        self.num_features = features_1m.shape[1] - 1  # Exclude close price
        self.num_timeframes = 3  # 1m, 5m, 15m
        self.feature_dim = self.num_features * self.num_timeframes

        # Store features from different timeframes
        self.features_1m = features_1m
        self.features_5m = features_5m
        self.features_15m = features_15m

        # Trading parameters
        self.initial_balance = 1.0
        self.trading_fee = trading_fee  # Increased from 0.001 to 0.0025 (0.25%)
        self.min_trade_interval = min_trade_interval  # Minimum steps between trades

        # Define action and observation spaces
        self.action_space = gym.spaces.Discrete(3)  # 0: Buy, 1: Sell, 2: Hold
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.window_size, self.feature_dim),
            dtype=np.float32
        )

        # State variables
        self.reset()

        # Callback for visualization or external monitoring
        self.action_callback = None

    def reset(self):
        """Reset the environment to its initial state and return the first observation."""
        self.balance = self.initial_balance
        self.position = 0.0  # Notional value of the asset currently held
        self.current_step = self.window_size
        self.trades = 0
        self.wins = 0
        self.losses = 0
        self.trade_history = []
        self.last_trade_step = -self.min_trade_interval  # Allow an immediate first trade

        # Get initial observation
        observation = self._get_observation()
        return observation

    def _get_observation(self):
        """
        Get the current state observation.

        Combines indicator features from the 1m, 5m and 15m timeframes into a
        single (window_size, feature_dim) array for the CNN. The close price
        column (last column) is dropped so the shape matches observation_space.
        """
        # Calculate indices for each timeframe
        idx_1m = self.current_step
        idx_5m = idx_1m // 5
        idx_15m = idx_1m // 15

        # Extract feature windows from each timeframe
        window_1m = self.features_1m[idx_1m - self.window_size:idx_1m]

        # Handle 5m timeframe
        start_5m = max(0, idx_5m - self.window_size)
        window_5m = self.features_5m[start_5m:idx_5m]

        # Handle 15m timeframe
        start_15m = max(0, idx_15m - self.window_size)
        window_15m = self.features_15m[start_15m:idx_15m]

        # Pad the slower timeframes with zeros if their window is not yet full
        if len(window_5m) < self.window_size:
            padding = np.zeros((self.window_size - len(window_5m), window_5m.shape[1]))
            window_5m = np.vstack([padding, window_5m])
        if len(window_15m) < self.window_size:
            padding = np.zeros((self.window_size - len(window_15m), window_15m.shape[1]))
            window_15m = np.vstack([padding, window_15m])

        # Combine features from all timeframes, dropping the close price column
        # so the width equals self.feature_dim
        combined_features = np.hstack([
            window_1m[:, :self.num_features].reshape(self.window_size, -1),
            window_5m[:, :self.num_features].reshape(self.window_size, -1),
            window_15m[:, :self.num_features].reshape(self.window_size, -1)
        ])

        # Convert to float32 and replace any NaN values
        combined_features = np.nan_to_num(combined_features, nan=0.0).astype(np.float32)
        return combined_features
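
    # Observation-shape example (assumption: each timeframe has 6 columns,
    # i.e. 5 indicators plus the close price): with window_size=20, each
    # timeframe contributes a (20, 5) block after the close column is dropped,
    # so _get_observation() returns a (20, 15) float32 array, matching
    # self.observation_space.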

    def step(self, action):
        """
        Take an action in the environment and return the next state, reward, done flag, and info.

        Args:
            action (int): 0 = Buy, 1 = Sell, 2 = Hold

        Returns:
            tuple: (observation, reward, done, info)
        """
        # Get current and next close price (close is the last column)
        current_price = self.features_1m[self.current_step, -1]
        next_price = self.features_1m[self.current_step + 1, -1]

        # Guard against zero or negative prices
        if current_price <= 0:
            current_price = 1e-8  # Small positive number
        if next_price <= 0:
            next_price = current_price  # Fall back to current price if next price is invalid

        price_change = (next_price - current_price) / current_price

        # Default reward is slightly negative to discourage inaction
        reward = -0.0001
        done = False
        profit_pct = None  # Realized profit of a trade closed this step, if any

        # Check how much time has passed since the last trade
        trade_interval = self.current_step - self.last_trade_step
        trade_interval_penalty = 0

        # Execute action
        if action == 0:  # BUY
            if self.position == 0:  # Only buy if not already in a position
                # Penalize trading too frequently, but still allow the trade
                if trade_interval < self.min_trade_interval:
                    trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)

                self.position = self.balance * (1 - self.trading_fee)
                self.balance = 0
                self.trades += 1
                reward = -0.001 + trade_interval_penalty  # Transaction cost plus optional penalty
                self.trade_entry_price = current_price
                self.last_trade_step = self.current_step

        elif action == 1:  # SELL
            if self.position > 0:  # Only sell if in a position
                # Penalize trading too frequently, but still allow the trade
                if trade_interval < self.min_trade_interval:
                    trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)

                # Value the position relative to its entry price and realize the P&L
                profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price
                position_value = self.position * (1 + profit_pct)
                self.balance = position_value * (1 - self.trading_fee)

                # Scale reward by profit percentage and apply trade interval penalty
                reward = (profit_pct * 10) + trade_interval_penalty

                # Update win/loss count
                if profit_pct > 0:
                    self.wins += 1
                else:
                    self.losses += 1

                # Record trade
                self.trade_history.append({
                    'entry_price': self.trade_entry_price,
                    'exit_price': next_price,
                    'profit_pct': profit_pct,
                    'trade_interval': trade_interval
                })

                # Reset position and update last trade step
                self.position = 0
                self.last_trade_step = self.current_step

        # else: action == 2 (HOLD) - no position change

        # Move to next step
        self.current_step += 1

        # Check if done
        if self.current_step >= len(self.features_1m) - 1:
            done = True
            # Force close any open position at the end of the episode
            if self.position > 0:
                profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price
                position_value = self.position * (1 + profit_pct)
                self.balance = position_value * (1 - self.trading_fee)
                self.position = 0
                reward += profit_pct * 10

                # Update win/loss count
                if profit_pct > 0:
                    self.wins += 1
                else:
                    self.losses += 1

        # Get the next observation
        observation = self._get_observation()

        # Calculate metrics for info; an open position is marked to the latest
        # price relative to its entry price
        if self.position > 0:
            position_value = self.position * (next_price / self.trade_entry_price)
        else:
            position_value = 0.0
        total_value = self.balance + position_value
        gain = (total_value - self.initial_balance) / self.initial_balance
        self.win_rate = self.wins / max(1, self.trades)

        info = {
            'balance': self.balance,
            'position': self.position,
            'total_value': total_value,
            'gain': gain,
            'trades': self.trades,
            'win_rate': self.win_rate,
            'profit_pct': profit_pct,  # None unless a trade was closed this step
            'current_price': current_price,
            'next_price': next_price
        }

        # Call the callback if it exists
        if self.action_callback:
            self.action_callback(action, current_price, reward, info)

        return observation, reward, done, info

    def set_action_callback(self, callback):
        """
        Set a callback function to be called after each action.

        Args:
            callback: Function with signature (action, price, reward, info)
        """
        self.action_callback = callback
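

# Illustrative usage sketch (not part of the training pipeline): exercises
# RLTradingEnvironment standalone with synthetic random feature arrays and a
# random policy. The column count, price model, and seed below are assumptions
# for demonstration only; real training uses DataInterface data via train_rl().
def _demo_random_rollout(num_candles=500, num_columns=6, seed=0):
    """Run a short random-policy rollout on synthetic data as a sanity check."""
    rng = np.random.default_rng(seed)

    def make_features(n):
        # Indicator columns plus a positive close price in the last column,
        # matching the layout the environment expects
        indicators = rng.normal(size=(n, num_columns - 1))
        close = np.cumsum(rng.normal(scale=0.1, size=(n, 1)), axis=0) + 100.0
        return np.hstack([indicators, close]).astype(np.float32)

    env = RLTradingEnvironment(
        make_features(num_candles),
        make_features(num_candles // 5),
        make_features(num_candles // 15),
    )
    obs = env.reset()
    total_reward, done, info = 0.0, False, {}
    while not done:
        action = env.action_space.sample()  # random policy
        obs, reward, done, info = env.step(action)
        total_reward += reward
    logger.info(f"Random rollout finished: total_reward={total_reward:.4f}, "
                f"trades={info['trades']}, win_rate={info['win_rate']:.2f}")
    return info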


def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
             action_callback=None, episode_callback=None, symbol="BTC/USDT"):
    """
    Train a DQN agent for RL-based trading with extended training and monitoring.

    Args:
        env_class: Optional environment class to use; defaults to RLTradingEnvironment
        num_episodes: Number of episodes to train
        max_steps: Maximum steps per episode
        save_path: Path to save the model
        action_callback: Optional callback for each action (step, action, price, reward, info)
        episode_callback: Optional callback after each episode (episode, reward, info)
        symbol: Trading pair symbol (e.g., "BTC/USDT")

    Returns:
        DQNAgent: The trained agent
    """
    logger.info("Starting DQN training for RL trading")

    # Create data interface with the specified symbol
    data_interface = DataInterface(symbol=symbol)

    # Load and preprocess data
    logger.info(f"Loading data from multiple timeframes for {symbol}")
    features_1m = data_interface.get_training_data("1m", n_candles=2000)
    features_5m = data_interface.get_training_data("5m", n_candles=1000)
    features_15m = data_interface.get_training_data("15m", n_candles=500)

    # Check that all timeframes loaded successfully
    if features_1m is None or features_5m is None or features_15m is None:
        logger.error("Failed to load training data from one or more timeframes")
        return None

    # If the data is a DataFrame, convert to a numpy array excluding the timestamp column
    if isinstance(features_1m, pd.DataFrame):
        features_1m = features_1m.drop('timestamp', axis=1, errors='ignore').values
    if isinstance(features_5m, pd.DataFrame):
        features_5m = features_5m.drop('timestamp', axis=1, errors='ignore').values
    if isinstance(features_15m, pd.DataFrame):
        features_15m = features_15m.drop('timestamp', axis=1, errors='ignore').values

    # Initialize the environment (or use the provided class)
    if env_class is None:
        env = RLTradingEnvironment(features_1m, features_5m, features_15m)
    else:
        env = env_class(features_1m, features_5m, features_15m)

    # Set action callback if provided
    if action_callback:
        def step_callback(action, price, reward, info):
            action_callback(env.current_step, action, price, reward, info)
        env.set_action_callback(step_callback)

    # Initialize the agent
    window_size = env.window_size
    num_features = env.num_features * env.num_timeframes
    action_size = env.action_space.n
    timeframes = ['1m', '5m', '15m']  # Match the timeframes from the environment

    agent = DQNAgent(
        state_size=window_size * num_features,
        action_size=action_size,
        window_size=window_size,
        num_features=env.num_features,
        timeframes=timeframes,
        memory_size=100000,
        batch_size=64,
        learning_rate=0.0001,
        gamma=0.99,
        epsilon=1.0,
        epsilon_min=0.01,
        epsilon_decay=0.995
    )

    # Training variables
    best_reward = -float('inf')
    episode_rewards = []

    # TensorBoard writer for logging
    writer = SummaryWriter(log_dir=f'runs/rl_trading_{int(time.time())}')

    # Main training loop
    logger.info(f"Starting training for {num_episodes} episodes...")
    logger.info(f"Training on device: {agent.device}")
    try:
        for episode in range(num_episodes):
            state = env.reset()
            total_reward = 0

            for step in range(max_steps):
                # Select action
                action = agent.act(state)

                # Take the action and observe the next state and reward
                next_state, reward, done, info = env.step(action)

                # Store the experience in memory
                agent.remember(state, action, reward, next_state, done)

                # Update state and accumulated reward
                state = next_state
                total_reward += reward

                # Train the agent by sampling from memory
                if len(agent.memory) >= agent.batch_size:
                    loss = agent.replay()

                if done or step == max_steps - 1:
                    break

            # Track rewards
            episode_rewards.append(total_reward)

            # Log progress
            avg_reward = np.mean(episode_rewards[-100:])
            logger.info(f"Episode {episode}/{num_episodes} - Reward: {total_reward:.4f}, " +
                        f"Avg (100): {avg_reward:.4f}, Epsilon: {agent.epsilon:.4f}")

            # Calculate trading metrics
            win_rate = env.win_rate if hasattr(env, 'win_rate') else 0
            trades = env.trades if hasattr(env, 'trades') else 0

            # Log to TensorBoard
            writer.add_scalar('Reward/Episode', total_reward, episode)
            writer.add_scalar('Reward/Average100', avg_reward, episode)
            writer.add_scalar('Trade/WinRate', win_rate, episode)
            writer.add_scalar('Trade/Count', trades, episode)

            # Save the best model
            if avg_reward > best_reward and episode > 10:
                logger.info(f"New best average reward: {avg_reward:.4f}, saving model")
                agent.save(save_path)
                best_reward = avg_reward

            # Periodic save every 100 episodes
            if episode % 100 == 0 and episode > 0:
                agent.save(f"{save_path}_episode_{episode}")

            # Call the episode callback if provided
            if episode_callback:
                # Add the environment to the info dict for use in extrema training
                info_with_env = info.copy()
                info_with_env['env'] = env
                episode_callback(episode, total_reward, info_with_env)

        # Final save
        logger.info("Training completed, saving final model")
        agent.save(f"{save_path}_final")

    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())

    # Close the TensorBoard writer
    writer.close()

    return agent
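

# Illustrative callback sketch: a minimal episode_callback that relies only on
# keys train_rl() already places in the info dict. It shows how external
# monitoring can hook into training; it is not wired up by default, and the
# function name and the usage line below are assumptions for demonstration.
def log_episode_summary(episode, total_reward, info):
    """Example episode callback: log the end-of-episode account state."""
    logger.info(
        f"[callback] episode={episode} reward={total_reward:.4f} "
        f"balance={info.get('balance', 0.0):.4f} trades={info.get('trades', 0)}"
    )

# Example usage:
#     train_rl(num_episodes=100, episode_callback=log_episode_summary)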


if __name__ == "__main__":
    train_rl()