gogo2/train_improved_rl.py
#!/usr/bin/env python
"""
Improved RL Trading with Enhanced Training and Monitoring
This script provides an improved version of the RL training process,
implementing better normalization, reward structure, and model training.
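
Example usage (flags as defined by the argparse parser below; values are illustrative):
    python train_improved_rl.py --symbol ETH/USDT --episodes 20 --visualize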
"""
import os
import sys
import logging
import argparse
import time
from datetime import datetime
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
# Add project directory to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
    sys.path.append(project_root)
# Import our custom modules
from NN.models.dqn_agent import DQNAgent
from NN.utils.trading_env import TradingEnvironment
from NN.utils.data_interface import DataInterface
from dataprovider_realtime import BinanceHistoricalData, RealTimeChart
# Configure logging
log_filename = f'improved_rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_filename),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('improved_rl')
# Parse command line arguments
parser = argparse.ArgumentParser(description='Improved RL Trading with Enhanced Training')
parser.add_argument('--episodes', type=int, default=20, help='Number of episodes to train')
parser.add_argument('--visualize', action='store_true', help='Visualize trades during training')
parser.add_argument('--save-path', type=str, default='NN/models/saved/improved_dqn_agent', help='Path to save trained model')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
args = parser.parse_args()

def create_training_environment(symbol, window_size=20):
    """Create and prepare the training environment with data"""
    logger.info(f"Setting up training environment for {symbol}")
    # Fetch historical data from multiple timeframes
    data_interface = DataInterface(symbol)
    # Use Binance data provider for fetching data
    historical_data = BinanceHistoricalData()
    # Fetch data for each timeframe
    df_1m = historical_data.get_historical_candles(symbol, interval_seconds=60, limit=1000)
    df_5m = historical_data.get_historical_candles(symbol, interval_seconds=300, limit=1000)
    df_15m = historical_data.get_historical_candles(symbol, interval_seconds=900, limit=500)
    # Ensure all dataframes have index as timestamp type
    if df_1m is not None and not df_1m.empty:
        if 'timestamp' in df_1m.columns:
            df_1m = df_1m.set_index('timestamp')
    if df_5m is not None and not df_5m.empty:
        if 'timestamp' in df_5m.columns:
            df_5m = df_5m.set_index('timestamp')
    if df_15m is not None and not df_15m.empty:
        if 'timestamp' in df_15m.columns:
            df_15m = df_15m.set_index('timestamp')
    # Preprocess data (add technical indicators)
    df_1m = preprocess_dataframe(df_1m)
    df_5m = preprocess_dataframe(df_5m)
    df_15m = preprocess_dataframe(df_15m)
    # Create environment with all timeframes
    env = create_multi_timeframe_env(df_1m, df_5m, df_15m, window_size)
    return env, (df_1m, df_5m, df_15m)

def preprocess_dataframe(df):
    """Add technical indicators and preprocess dataframe"""
    if df is None or df.empty:
        return None
    # Drop any missing values
    df = df.dropna()
    # Ensure it has OHLCV columns
    required_columns = ['open', 'high', 'low', 'close', 'volume']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        logger.warning(f"Missing required columns: {missing_columns}")
        for col in missing_columns:
            # Fill with close price for OHLC if missing
            if col in ['open', 'high', 'low'] and 'close' in df.columns:
                df[col] = df['close']
            # Fill with zeros for volume if missing
            elif col == 'volume':
                df[col] = 0
    # Add simple technical indicators
    # 1. Simple Moving Averages
    df['sma_5'] = df['close'].rolling(window=5).mean()
    df['sma_10'] = df['close'].rolling(window=10).mean()
    # 2. Relative Strength Index (RSI)
    delta = df['close'].diff()
    gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    df['rsi'] = 100 - (100 / (1 + rs))
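    # Note: this is the simple-moving-average RSI variant (Cutler's RSI); the classic Wilder
    # RSI smooths gains and losses exponentially (e.g. ewm(alpha=1/14)), so values differ slightly.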
    # 3. Bollinger Bands
    df['bb_middle'] = df['close'].rolling(window=20).mean()
    df['bb_std'] = df['close'].rolling(window=20).std()
    df['bb_upper'] = df['bb_middle'] + 2 * df['bb_std']
    df['bb_lower'] = df['bb_middle'] - 2 * df['bb_std']
    # 4. MACD
    df['ema_12'] = df['close'].ewm(span=12, adjust=False).mean()
    df['ema_26'] = df['close'].ewm(span=26, adjust=False).mean()
    df['macd'] = df['ema_12'] - df['ema_26']
    df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
    # 5. Price rate of change
    df['roc'] = df['close'].pct_change(periods=10) * 100
    # Fill any remaining NaN values with 0
    df = df.fillna(0)
    return df

def create_multi_timeframe_env(df_1m, df_5m, df_15m, window_size=20):
    """Create a custom environment that handles multiple timeframes"""
    # Ensure we have complete data for all timeframes
    min_required_length = window_size + 100  # Add buffer for training
    if (df_1m is None or len(df_1m) < min_required_length or
            df_5m is None or len(df_5m) < min_required_length or
            df_15m is None or len(df_15m) < min_required_length):
        raise ValueError(f"Insufficient data for training. Need at least {min_required_length} candles per timeframe.")
    # Ensure we only use the last N valid data points
    df_1m = df_1m.iloc[-900:].copy() if len(df_1m) > 900 else df_1m.copy()
    df_5m = df_5m.iloc[-180:].copy() if len(df_5m) > 180 else df_5m.copy()
    df_15m = df_15m.iloc[-60:].copy() if len(df_15m) > 60 else df_15m.copy()
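    # The trimmed lengths cover roughly the same window: 900 x 1m = 180 x 5m = 60 x 15m = 900
    # minutes (~15 hours), so the three timeframes stay aligned over the same period.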
    # Reset index to make sure we have continuous integers
    df_1m = df_1m.reset_index(drop=True)
    df_5m = df_5m.reset_index(drop=True)
    df_15m = df_15m.reset_index(drop=True)
    # For simplicity, we'll use the 1m data as the base environment
    # The other timeframes will be incorporated through observation
    env = TradingEnvironment(
        data=df_1m,
        initial_balance=100.0,
        fee_rate=0.0005,  # 0.05% fee (typical for crypto exchanges)
        max_steps=len(df_1m) - window_size - 50,  # Leave some room at the end
        window_size=window_size,
        risk_aversion=0.2,  # Moderately risk-averse
        price_scaling='zscore',  # Use z-score normalization
        reward_scaling=10.0,  # Scale rewards for better learning
        episode_penalty=0.2  # Penalty for holding positions at end of episode
    )
    return env

def initialize_agent(env, window_size=20, num_features=0, timeframes=None):
    """Initialize the DQN agent with appropriate parameters"""
    if timeframes is None:
        timeframes = ['1m', '5m', '15m']
    # Calculate input dimensions
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    # If num_features wasn't provided, infer from environment
    if num_features == 0:
        # Calculate features per timeframe from state dimension and number of timeframes,
        # accounting for the 3 additional features (position, equity, unrealized_pnl)
        num_features = (state_dim - 3) // len(timeframes)
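        # Worked example (hypothetical numbers): with 3 timeframes and a 63-dimensional state,
        # (63 - 3) // 3 = 20 features per timeframe.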
    logger.info(f"Initializing DQN agent: state_dim={state_dim}, action_dim={action_dim}, features={num_features}")
    agent = DQNAgent(
        state_size=state_dim,
        action_size=action_dim,
        window_size=window_size,
        num_features=num_features,
        timeframes=timeframes,
        learning_rate=0.0005,  # Start with a moderate learning rate
        gamma=0.97,            # Slightly reduced discount factor for stable learning
        epsilon=1.0,           # Start with full exploration
        epsilon_min=0.05,      # Maintain some exploration even at the end
        epsilon_decay=0.9975,  # Slower decay for more exploration
        memory_size=20000,     # Larger replay buffer
        batch_size=128,        # Larger batch size for more stable gradients
        target_update=5        # More frequent target network updates
    )
    return agent

def train_agent(env, agent, num_episodes=20, visualize=False, chart=None, save_path=None, save_freq=5):
    """
    Train the DQN agent with an improved training loop

    Args:
        env: The trading environment
        agent: The DQN agent
        num_episodes: Number of episodes to train
        visualize: Whether to visualize trades during training
        chart: The visualization chart (if visualize=True)
        save_path: Path to save the model
        save_freq: How often to save checkpoints (in episodes)

    Returns:
        tuple: (rewards, win_rates, best_reward, best_model_path)
    """
    logger.info(f"Starting training for {num_episodes} episodes")
    # Initialize metrics tracking
    rewards = []
    win_rates = []
    total_train_time = 0
    best_reward = float('-inf')
    best_model_path = None
    # Create directory for checkpoints if needed
    if save_path:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        checkpoint_dir = os.path.join(os.path.dirname(save_path), 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)
    # For tracking improvement
    last_improved_episode = 0
    patience = 10  # Episodes to wait for improvement before early stopping
    for episode in range(num_episodes):
        start_time = time.time()
        # Reset environment and get initial state
        state = env.reset()
        done = False
        episode_reward = 0
        step = 0
        # Action metrics for this episode
        actions_taken = {0: 0, 1: 0, 2: 0}  # Track BUY, SELL, HOLD actions
        while not done:
            # Select action
            action = agent.act(state)
            # Execute action
            next_state, reward, done, info = env.step(action)
            # Store experience in replay buffer
            is_extrema = False  # In a real implementation, detect extrema points
            agent.remember(state, action, reward, next_state, done, is_extrema)
            # Learn from experience
            if len(agent.memory) >= agent.batch_size:
                use_prioritized = episode > 1  # Start using prioritized replay after first episode
                loss = agent.replay(use_prioritized=use_prioritized)
            # Update state and metrics
            state = next_state
            episode_reward += reward
            actions_taken[action] += 1
            # Every 100 steps, log progress
            if step % 100 == 0 or step < 10:
                action_str = "BUY" if action == 0 else "SELL" if action == 1 else "HOLD"
                current_price = info.get('current_price', 0)
                pnl = info.get('pnl', 0)
                balance = info.get('balance', 0)
                logger.info(f"Episode {episode}, Step {step}: Action={action_str}, "
                            f"Reward={reward:.4f}, Balance=${balance:.2f}, PnL={pnl:.4f}")
                # Add trade to visualization if enabled
                if visualize and chart and action in [0, 1]:  # BUY or SELL
                    chart.add_trade(
                        price=current_price,
                        timestamp=datetime.now(),
                        amount=0.1,
                        pnl=pnl,
                        action=action_str
                    )
            step += 1
        # Episode finished - calculate metrics
        episode_time = time.time() - start_time
        total_train_time += episode_time
        # Get environment info
        win_rate = env.winning_trades / max(1, env.total_trades)
        trades = env.total_trades
        balance = env.balance
        gain = (balance - env.initial_balance) / env.initial_balance
        max_drawdown = env.max_drawdown
        # Record metrics
        rewards.append(episode_reward)
        win_rates.append(win_rate)
        # Update agent's learning metrics
        improved = agent.update_learning_metrics(episode_reward)
        # If this is best performance, save the model
        if episode_reward > best_reward:
            best_reward = episode_reward
            if save_path:
                best_model_path = f"{save_path}_best"
                agent.save(best_model_path)
                logger.info(f"New best model saved to {best_model_path} (reward: {best_reward:.2f})")
            last_improved_episode = episode
        # Regular checkpoint saving
        if save_path and episode % save_freq == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_episode_{episode}")
            agent.save(checkpoint_path)
        # Print episode summary
        actions_summary = ", ".join([f"{k}:{v}" for k, v in actions_taken.items()])
        logger.info(f"Episode {episode} completed in {episode_time:.2f}s")
        logger.info(f" Total reward: {episode_reward:.4f}")
        logger.info(f" Actions taken: {actions_summary}")
        logger.info(f" Trades: {trades}, Win rate: {win_rate:.2%}")
        logger.info(f" Balance: ${balance:.2f}, Gain: {gain:.2%}")
        logger.info(f" Max Drawdown: {max_drawdown:.2%}")
        # Early stopping check
        if episode - last_improved_episode >= patience:
            logger.info(f"No improvement for {patience} episodes. Early stopping.")
            break
    # Training complete
    avg_time_per_episode = total_train_time / max(1, len(rewards))
    logger.info(f"Training completed in {total_train_time:.2f}s ({avg_time_per_episode:.2f}s per episode)")
    # Save final model
    if save_path:
        agent.save(f"{save_path}_final")
        logger.info(f"Final model saved to {save_path}_final")
    # Return training metrics
    return rewards, win_rates, best_reward, best_model_path

def plot_training_results(rewards, win_rates, save_dir=None):
    """Plot training metrics and save the figure"""
    plt.figure(figsize=(12, 8))
    # Plot rewards
    plt.subplot(2, 1, 1)
    plt.plot(rewards, 'b-')
    plt.title('Training Rewards per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.grid(True)
    # Plot win rates
    plt.subplot(2, 1, 2)
    plt.plot(win_rates, 'g-')
    plt.title('Win Rate per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Win Rate')
    plt.grid(True)
    plt.tight_layout()
    # Save figure if directory provided
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, f'training_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}.png'))
    plt.close()

def evaluate_agent(env, agent, num_episodes=5, visualize=False, chart=None):
    """
    Evaluate a trained agent on the environment

    Args:
        env: The trading environment
        agent: The trained DQN agent
        num_episodes: Number of evaluation episodes
        visualize: Whether to visualize trades
        chart: The visualization chart (if visualize=True)

    Returns:
        dict: Evaluation metrics
    """
    logger.info(f"Evaluating agent over {num_episodes} episodes")
    # Metrics to track
    total_rewards = []
    total_trades = []
    win_rates = []
    sharpe_ratios = []
    sortino_ratios = []
    max_drawdowns = []
    final_balances = []
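    # Reference: the Sharpe ratio is mean(returns) / std(returns); the Sortino ratio uses downside
    # deviation in the denominator. Both are read from the environment's info dict below, so the
    # exact computation depends on TradingEnvironment (assumed to expose these keys).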
    for episode in range(num_episodes):
        # Reset environment
        state = env.reset()
        done = False
        episode_reward = 0
        # Run episode without exploration
        while not done:
            action = agent.act(state, explore=False)  # No exploration during evaluation
            next_state, reward, done, info = env.step(action)
            episode_reward += reward
            state = next_state
            # Add trade to visualization if enabled
            if visualize and chart and action in [0, 1]:  # BUY or SELL
                action_str = "BUY" if action == 0 else "SELL"
                current_price = info.get('current_price', 0)
                pnl = info.get('pnl', 0)
                chart.add_trade(
                    price=current_price,
                    timestamp=datetime.now(),
                    amount=0.1,
                    pnl=pnl,
                    action=action_str
                )
        # Record metrics
        total_rewards.append(episode_reward)
        total_trades.append(env.total_trades)
        win_rates.append(env.winning_trades / max(1, env.total_trades))
        sharpe_ratios.append(info.get('sharpe_ratio', 0))
        sortino_ratios.append(info.get('sortino_ratio', 0))
        max_drawdowns.append(env.max_drawdown)
        final_balances.append(env.balance)
        logger.info(f"Evaluation episode {episode} - Reward: {episode_reward:.4f}, "
                    f"Balance: ${env.balance:.2f}, Win rate: {win_rates[-1]:.2%}")
    # Calculate average metrics
    avg_reward = np.mean(total_rewards)
    avg_trades = np.mean(total_trades)
    avg_win_rate = np.mean(win_rates)
    avg_sharpe = np.mean(sharpe_ratios)
    avg_sortino = np.mean(sortino_ratios)
    avg_max_drawdown = np.mean(max_drawdowns)
    avg_final_balance = np.mean(final_balances)
    # Log evaluation summary
    logger.info("Evaluation completed:")
    logger.info(f" Average reward: {avg_reward:.4f}")
    logger.info(f" Average trades per episode: {avg_trades:.2f}")
    logger.info(f" Average win rate: {avg_win_rate:.2%}")
    logger.info(f" Average Sharpe ratio: {avg_sharpe:.4f}")
    logger.info(f" Average Sortino ratio: {avg_sortino:.4f}")
    logger.info(f" Average max drawdown: {avg_max_drawdown:.2%}")
    logger.info(f" Average final balance: ${avg_final_balance:.2f}")
    # Return evaluation metrics
    return {
        'avg_reward': avg_reward,
        'avg_trades': avg_trades,
        'avg_win_rate': avg_win_rate,
        'avg_sharpe': avg_sharpe,
        'avg_sortino': avg_sortino,
        'avg_max_drawdown': avg_max_drawdown,
        'avg_final_balance': avg_final_balance
    }

def main():
    """Main function to run the improved RL training"""
    start_time = time.time()
    logger.info(f"Starting improved RL training for {args.symbol}")
    # Create environment
    env, data_frames = create_training_environment(args.symbol)
    # Initialize visualization if enabled
    chart = None
    if args.visualize:
        logger.info("Initializing visualization chart")
        chart = RealTimeChart(args.symbol)
        time.sleep(2)  # Give time for chart to initialize
    # Initialize agent
    agent = initialize_agent(env)
    # Train agent
    rewards, win_rates, best_reward, best_model_path = train_agent(
        env=env,
        agent=agent,
        num_episodes=args.episodes,
        visualize=args.visualize,
        chart=chart,
        save_path=args.save_path
    )
    # Plot training results
    plot_dir = os.path.join(os.path.dirname(args.save_path), 'plots')
    plot_training_results(rewards, win_rates, save_dir=plot_dir)
    # Evaluate best model
    logger.info("Evaluating best model")
    # Load best model for evaluation
    if best_model_path:
        best_agent = initialize_agent(env)
        best_agent.load(best_model_path)
        # Evaluate the best model
        eval_metrics = evaluate_agent(
            env=env,
            agent=best_agent,
            visualize=args.visualize,
            chart=chart
        )
        # Log evaluation results
        logger.info("Best model evaluation complete:")
        for metric, value in eval_metrics.items():
            logger.info(f" {metric}: {value}")
    # Total run time
    total_time = time.time() - start_time
    logger.info(f"Total run time: {total_time:.2f} seconds")

if __name__ == "__main__":
    main()