# gogo2/NN/train_rl.py
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import logging
import time
from datetime import datetime
import os
import sys
import pandas as pd
import gym
import json
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from NN.utils.data_interface import DataInterface
from NN.utils.trading_env import TradingEnvironment
from NN.models.dqn_agent import DQNAgent
from NN.utils.signal_interpreter import SignalInterpreter
# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('rl_training.log'),
        logging.StreamHandler()
    ]
)

# Set up device for PyTorch (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Log GPU status
if torch.cuda.is_available():
    gpu_count = torch.cuda.device_count()
    gpu_names = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
    logger.info(f"Using GPU: {gpu_names}")

    # Check for BFloat16 support (Ampere and newer GPUs) for faster training
    if hasattr(torch.cuda, 'amp') and torch.cuda.is_bf16_supported():
        logger.info("BFloat16 precision is supported - will use for faster training")
else:
    logger.warning("GPU not available. Using CPU for training (slower).")

class RLTradingEnvironment(gym.Env):
    """
    Reinforcement Learning environment for trading with technical indicators
    from multiple timeframes
    """
    def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.0025, min_trade_interval=15):
        super().__init__()

        # Window and feature dimensions
        self.window_size = window_size
        self.num_features = features_1m.shape[1] - 1  # Exclude close price
        self.num_timeframes = 3  # 1m, 5m, 15m
        self.feature_dim = self.num_features * self.num_timeframes

        # Store features from different timeframes
        self.features_1m = features_1m
        self.features_5m = features_5m
        self.features_15m = features_15m

        # Trading parameters
        self.initial_balance = 1.0
        self.trading_fee = trading_fee  # 0.25% per trade (increased from 0.1%)
        self.min_trade_interval = min_trade_interval  # Minimum steps between trades

        # Define action and observation spaces
        self.action_space = gym.spaces.Discrete(3)  # 0: Buy, 1: Sell, 2: Hold
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.window_size, self.feature_dim),
            dtype=np.float32
        )

        # State variables
        self.reset()

        # Callback for visualization or external monitoring
        self.action_callback = None

    def reset(self):
        """Reset the environment to initial state"""
        self.balance = self.initial_balance
        self.position = 0.0  # Amount of asset held
        self.current_step = self.window_size
        self.trades = 0
        self.wins = 0
        self.losses = 0
        self.trade_history = []
        self.last_trade_step = -self.min_trade_interval  # Allow an immediate first trade

        # Get initial observation
        observation = self._get_observation()
        return observation

    def _get_observation(self):
        """
        Get the current state observation.
        Combine features from multiple timeframes, reshaped for the CNN.
        """
        # Calculate indices for each timeframe
        idx_1m = self.current_step
        idx_5m = idx_1m // 5
        idx_15m = idx_1m // 15

        # Extract feature windows from each timeframe
        window_1m = self.features_1m[idx_1m - self.window_size:idx_1m]

        # Handle 5m timeframe
        start_5m = max(0, idx_5m - self.window_size)
        window_5m = self.features_5m[start_5m:idx_5m]

        # Handle 15m timeframe
        start_15m = max(0, idx_15m - self.window_size)
        window_15m = self.features_15m[start_15m:idx_15m]

        # Pad if needed (for 5m and 15m)
        if len(window_5m) < self.window_size:
            padding = np.zeros((self.window_size - len(window_5m), window_5m.shape[1]))
            window_5m = np.vstack([padding, window_5m])
        if len(window_15m) < self.window_size:
            padding = np.zeros((self.window_size - len(window_15m), window_15m.shape[1]))
            window_15m = np.vstack([padding, window_15m])

        # Combine features from all timeframes, dropping the close-price column
        # (last column) so the result matches the declared observation_space shape
        combined_features = np.hstack([
            window_1m[:, :-1].reshape(self.window_size, -1),
            window_5m[:, :-1].reshape(self.window_size, -1),
            window_15m[:, :-1].reshape(self.window_size, -1)
        ])

        # Convert to float32 and handle any NaN values
        combined_features = np.nan_to_num(combined_features, nan=0.0).astype(np.float32)
        return combined_features

    def step(self, action):
        """
        Take an action in the environment and return the next state, reward, done flag, and info

        Args:
            action (int): 0 = Buy, 1 = Sell, 2 = Hold

        Returns:
            tuple: (observation, reward, done, info)
        """
        # Get current and next price
        current_price = self.features_1m[self.current_step, -1]  # Close price is last column
        next_price = self.features_1m[self.current_step + 1, -1]

        # Handle zero or negative prices
        if current_price <= 0:
            current_price = 1e-8  # Small positive number
        if next_price <= 0:
            next_price = current_price  # Use current price if next price is invalid
        price_change = (next_price - current_price) / current_price

        # Default reward is slightly negative to discourage inaction
        reward = -0.0001
        done = False
        profit_pct = None  # Initialize profit_pct variable

        # Check if enough time has passed since the last trade
        trade_interval = self.current_step - self.last_trade_step
        trade_interval_penalty = 0

        # Execute action
        if action == 0:  # BUY
            if self.position == 0:  # Only buy if not already in position
                # Apply extra penalty for trading too frequently
                if trade_interval < self.min_trade_interval:
                    trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)
                    # Still allow the trade but with penalty

                self.position = self.balance * (1 - self.trading_fee)
                self.balance = 0
                self.trades += 1
                reward = -0.001 + trade_interval_penalty  # Small transaction cost plus any penalty
                self.trade_entry_price = current_price
                self.last_trade_step = self.current_step

        elif action == 1:  # SELL
            if self.position > 0:  # Only sell if in position
                # Apply extra penalty for trading too frequently
                if trade_interval < self.min_trade_interval:
                    trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)
                    # Still allow the trade but with penalty

                # Calculate position value at current price
                position_value = self.position * (1 + price_change)
                self.balance = position_value * (1 - self.trading_fee)

                # Calculate profit/loss from trade
                profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price

                # Scale reward by profit percentage and apply trade interval penalty
                reward = (profit_pct * 10) + trade_interval_penalty

                # Update win/loss count
                if profit_pct > 0:
                    self.wins += 1
                else:
                    self.losses += 1

                # Record trade
                self.trade_history.append({
                    'entry_price': self.trade_entry_price,
                    'exit_price': next_price,
                    'profit_pct': profit_pct,
                    'trade_interval': trade_interval
                })

                # Reset position and update last trade step
                self.position = 0
                self.last_trade_step = self.current_step

        # else: (action == 2 - HOLD) - no position change

        # Move to the next step
        self.current_step += 1

        # Check if done
        if self.current_step >= len(self.features_1m) - 1:
            done = True

            # Force close any open position at the end of the data
            if self.position > 0:
                position_value = self.position * (1 + price_change)
                self.balance = position_value * (1 - self.trading_fee)
                profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price
                reward += profit_pct * 10

                # Update win/loss count
                if profit_pct > 0:
                    self.wins += 1
                else:
                    self.losses += 1

                # Clear the position so the final total_value is not double-counted
                self.position = 0

        # Get the next observation
        observation = self._get_observation()

        # Calculate metrics for info
        total_value = self.balance + self.position * next_price
        gain = (total_value - self.initial_balance) / self.initial_balance
        self.win_rate = self.wins / max(1, self.trades)

        info = {
            'balance': self.balance,
            'position': self.position,
            'total_value': total_value,
            'gain': gain,
            'trades': self.trades,
            'win_rate': self.win_rate,
            'profit_pct': profit_pct if action == 1 and self.position == 0 else None,
            'current_price': current_price,
            'next_price': next_price
        }

        # Call the callback if it exists
        if self.action_callback:
            self.action_callback(action, current_price, reward, info)

        return observation, reward, done, info

    def set_action_callback(self, callback):
        """
        Set a callback function to be called after each action

        Args:
            callback: Function with signature (action, price, reward, info)
        """
        self.action_callback = callback

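
# Illustrative sketch (an addition, not part of the original training flow): a quick
# way to exercise the environment with random placeholder features and random actions
# to check observation shapes and the reward/info plumbing. Real usage builds the
# feature arrays from DataInterface, as train_rl() does below; the sizes and the
# helper name here are assumptions chosen only for the example.
def _example_random_rollout(num_steps=50):
    rng = np.random.default_rng(0)
    # 10 columns per timeframe; the last column plays the role of the close price
    features_1m = rng.uniform(90.0, 110.0, size=(600, 10)).astype(np.float32)
    features_5m = rng.uniform(90.0, 110.0, size=(120, 10)).astype(np.float32)
    features_15m = rng.uniform(90.0, 110.0, size=(40, 10)).astype(np.float32)

    env = RLTradingEnvironment(features_1m, features_5m, features_15m)
    obs = env.reset()
    assert obs.shape == env.observation_space.shape

    info = {}
    for _ in range(num_steps):
        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)
        if done:
            break
    return info
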
def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
             action_callback=None, episode_callback=None, symbol="BTC/USDT"):
    """
    Train a reinforcement learning agent for trading

    Args:
        env_class: Optional environment class override
        num_episodes: Number of episodes to train for
        max_steps: Maximum steps per episode
        save_path: Path to save the trained model
        action_callback: Callback function for monitoring actions
        episode_callback: Callback function for monitoring episodes
        symbol: Trading symbol to use

    Returns:
        tuple: (trained agent, environment)
    """
    # Load data for the selected symbol
    data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m'])

    try:
        # Try to load data for the requested symbol using get_historical_data
        data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
        data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
        data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)

        if data_1m is None or data_5m is None or data_15m is None:
            raise FileNotFoundError("Could not retrieve data for specified symbol")
    except Exception as e:
        logger.warning(f"Data for {symbol} not available: {str(e)}. Falling back to BTC/USDT.")

        # Fall back to cached BTC/USDT data if available
        symbol = "BTC/USDT"
        data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m'])
        data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
        data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
        data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)

        if data_1m is None or data_5m is None or data_15m is None:
            logger.error("Failed to retrieve any data. Cannot continue training.")
            raise ValueError("No data available for training")

    # Create features from the data by adding technical indicators and converting to numpy format
    if data_1m is not None:
        data_1m = data_interface.add_technical_indicators(data_1m)
        # Convert to numpy array with close price as the last column
        features_1m = np.hstack([
            data_1m.drop(['timestamp', 'close'], axis=1).values,
            data_1m['close'].values.reshape(-1, 1)
        ])
    else:
        features_1m = None

    if data_5m is not None:
        data_5m = data_interface.add_technical_indicators(data_5m)
        # Convert to numpy array with close price as the last column
        features_5m = np.hstack([
            data_5m.drop(['timestamp', 'close'], axis=1).values,
            data_5m['close'].values.reshape(-1, 1)
        ])
    else:
        features_5m = None

    if data_15m is not None:
        data_15m = data_interface.add_technical_indicators(data_15m)
        # Convert to numpy array with close price as the last column
        features_15m = np.hstack([
            data_15m.drop(['timestamp', 'close'], axis=1).values,
            data_15m['close'].values.reshape(-1, 1)
        ])
    else:
        features_15m = None

    # Check if we have all the required features
    if features_1m is None or features_5m is None or features_15m is None:
        logger.error("Failed to create features for all timeframes.")
        raise ValueError("Could not create features for training")

    # Create the environment
    if env_class:
        # Use provided environment class
        env = env_class(features_1m, features_5m, features_15m)
    else:
        # Use the default environment
        env = RLTradingEnvironment(features_1m, features_5m, features_15m)

    # Set action callback if provided
    if action_callback:
        env.set_action_callback(action_callback)

    # Get environment properties for agent creation
    input_shape = env.observation_space.shape
    n_actions = env.action_space.n

    # Create the agent
    agent = DQNAgent(
        state_shape=input_shape,
        n_actions=n_actions,
        epsilon=1.0,
        epsilon_decay=0.995,
        epsilon_min=0.01,
        learning_rate=0.0001,
        gamma=0.99,
        buffer_size=10000,
        batch_size=64,
        device=device  # Pass device to agent for GPU usage
    )

    # Check if a saved model exists and load it
    model_file = f"{save_path}_model.pth"
    if os.path.exists(model_file):
        try:
            agent.load(model_file)
            logger.info(f"Loaded existing model from {model_file}")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
    else:
        logger.info("No existing model found. Starting with a new model.")

    # Create TensorBoard writer
    writer = SummaryWriter(log_dir=f'runs/dqn_{int(time.time())}')

    # Log GPU status to TensorBoard
    writer.add_text("hardware/device", str(device), 0)
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            writer.add_text(f"hardware/gpu_{i}", torch.cuda.get_device_name(i), 0)

    # Training loop
    total_rewards = []
    trade_win_rates = []
    best_reward = -np.inf

    # Move models to the appropriate device if not already there
    agent.move_models_to_device(device)

    # Enable mixed precision if a GPU with AMP support is available
    use_mixed_precision = False
    if torch.cuda.is_available() and hasattr(torch.cuda, 'amp'):
        logger.info("Enabling mixed precision training")
        use_mixed_precision = True
        scaler = torch.cuda.amp.GradScaler()

    # Define step callback for TensorBoard logging and model tracking
    def step_callback(action, price, reward, info):
        # Pass to external callback if provided
        if action_callback:
            action_callback(env.current_step, action, price, reward, info)

try:
for episode in range(num_episodes):
state = env.reset()
total_reward = 0
for step in range(max_steps):
# Select action
action = agent.act(state)
# Take action and observe next state and reward
next_state, reward, done, info = env.step(action)
# Store the experience in memory
agent.remember(state, action, reward, next_state, done)
# Update state and reward
state = next_state
total_reward += reward
# Train the agent by sampling from memory
if len(agent.memory) >= agent.batch_size:
loss = agent.replay()
if done or step == max_steps - 1:
break
# Track rewards
total_rewards.append(total_reward)
# Calculate trading metrics
win_rate = env.win_rate if hasattr(env, 'win_rate') else 0
trades = env.trades if hasattr(env, 'trades') else 0
# Log to TensorBoard
writer.add_scalar('Reward/Episode', total_reward, episode)
writer.add_scalar('Trade/WinRate', win_rate, episode)
writer.add_scalar('Trade/Count', trades, episode)
# Save best model
if total_reward > best_reward and episode > 10:
logger.info(f"New best average reward: {total_reward:.4f}, saving model")
agent.save(save_path)
best_reward = total_reward
# Periodic save every 100 episodes
if episode % 100 == 0 and episode > 0:
agent.save(f"{save_path}_episode_{episode}")
# Call episode callback if provided
if episode_callback:
# Add environment to info dict to use for extrema training
info_with_env = info.copy()
info_with_env['env'] = env
episode_callback(episode, total_reward, info_with_env)
# Final save
logger.info("Training completed, saving final model")
agent.save(f"{save_path}_final")
except Exception as e:
logger.error(f"Training failed: {str(e)}")
import traceback
logger.error(traceback.format_exc())
# Close TensorBoard writer
writer.close()
return agent, env
if __name__ == "__main__":
    train_rl()
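
# Example (illustrative; parameter values are placeholders, not recommendations):
# a shorter training run on another symbol can be started by overriding the defaults:
#   agent, env = train_rl(num_episodes=100, max_steps=500, symbol="ETH/USDT")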