#!/usr/bin/env python
"""
Improved RL Trading with Enhanced Training and Monitoring

This script provides an improved version of the RL training process,
implementing better normalization, reward structure, and model training.
"""

import os
import sys
import logging
import argparse
import time
from datetime import datetime
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# Add project directory to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
    sys.path.append(project_root)

# Import our custom modules
from NN.models.dqn_agent import DQNAgent
from NN.utils.trading_env import TradingEnvironment
from NN.utils.data_interface import DataInterface
from dataprovider_realtime import BinanceHistoricalData, RealTimeChart

# Configure logging
log_filename = f'improved_rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_filename),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('improved_rl')

# Parse command line arguments
parser = argparse.ArgumentParser(description='Improved RL Trading with Enhanced Training')
parser.add_argument('--episodes', type=int, default=20, help='Number of episodes to train')
parser.add_argument('--visualize', action='store_true', help='Visualize trades during training')
parser.add_argument('--save-path', type=str, default='NN/models/saved/improved_dqn_agent', help='Path to save trained model')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
args = parser.parse_args()
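
# Example invocation (script filename assumed to match the log prefix above):
#   python improved_rl_training.py --episodes 50 --visualize --symbol ETH/USDT \
#       --save-path NN/models/saved/improved_dqn_agent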

def create_training_environment(symbol, window_size=20):
    """Create and prepare the training environment with data"""
    logger.info(f"Setting up training environment for {symbol}")

    # Fetch historical data from multiple timeframes
    data_interface = DataInterface(symbol)

    # Use Binance data provider for fetching data
    historical_data = BinanceHistoricalData()

    # Fetch data for each timeframe
    df_1m = historical_data.get_historical_candles(symbol, interval_seconds=60, limit=1000)
    df_5m = historical_data.get_historical_candles(symbol, interval_seconds=300, limit=1000)
    df_15m = historical_data.get_historical_candles(symbol, interval_seconds=900, limit=500)

    # Ensure all dataframes are indexed by timestamp
    if df_1m is not None and not df_1m.empty:
        if 'timestamp' in df_1m.columns:
            df_1m = df_1m.set_index('timestamp')

    if df_5m is not None and not df_5m.empty:
        if 'timestamp' in df_5m.columns:
            df_5m = df_5m.set_index('timestamp')

    if df_15m is not None and not df_15m.empty:
        if 'timestamp' in df_15m.columns:
            df_15m = df_15m.set_index('timestamp')

    # Preprocess data (add technical indicators)
    df_1m = preprocess_dataframe(df_1m)
    df_5m = preprocess_dataframe(df_5m)
    df_15m = preprocess_dataframe(df_15m)

    # Create environment with all timeframes
    env = create_multi_timeframe_env(df_1m, df_5m, df_15m, window_size)

    return env, (df_1m, df_5m, df_15m)

def preprocess_dataframe(df):
    """Add technical indicators and preprocess dataframe"""
    if df is None or df.empty:
        return None

    # Drop any missing values
    df = df.dropna()

    # Ensure it has OHLCV columns
    required_columns = ['open', 'high', 'low', 'close', 'volume']
    missing_columns = [col for col in required_columns if col not in df.columns]

    if missing_columns:
        logger.warning(f"Missing required columns: {missing_columns}")
        for col in missing_columns:
            # Fill with close price for OHLC if missing
            if col in ['open', 'high', 'low'] and 'close' in df.columns:
                df[col] = df['close']
            # Fill with zeros for volume if missing
            elif col == 'volume':
                df[col] = 0

    # Add simple technical indicators
    # 1. Simple Moving Averages
    df['sma_5'] = df['close'].rolling(window=5).mean()
    df['sma_10'] = df['close'].rolling(window=10).mean()

    # 2. Relative Strength Index (RSI)
    delta = df['close'].diff()
    gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    df['rsi'] = 100 - (100 / (1 + rs))
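    # Note: this is the simple-moving-average RSI variant. Wilder's original RSI
    # smooths gains/losses exponentially instead, e.g. (alternative, not used here):
    #   gain = delta.clip(lower=0).ewm(alpha=1/14, adjust=False).mean()
    #   loss = (-delta.clip(upper=0)).ewm(alpha=1/14, adjust=False).mean()
    # Both produce similar shapes but slightly different values.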

    # 3. Bollinger Bands
    df['bb_middle'] = df['close'].rolling(window=20).mean()
    df['bb_std'] = df['close'].rolling(window=20).std()
    df['bb_upper'] = df['bb_middle'] + 2 * df['bb_std']
    df['bb_lower'] = df['bb_middle'] - 2 * df['bb_std']

    # 4. MACD
    df['ema_12'] = df['close'].ewm(span=12, adjust=False).mean()
    df['ema_26'] = df['close'].ewm(span=26, adjust=False).mean()
    df['macd'] = df['ema_12'] - df['ema_26']
    df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()

    # 5. Price rate of change
    df['roc'] = df['close'].pct_change(periods=10) * 100

    # Fill any remaining NaN values with 0
    df = df.fillna(0)

    return df

def create_multi_timeframe_env(df_1m, df_5m, df_15m, window_size=20):
    """Create a custom environment that handles multiple timeframes"""

    # Ensure we have complete data for all timeframes
    min_required_length = window_size + 100  # Add buffer for training

    if (df_1m is None or len(df_1m) < min_required_length or
            df_5m is None or len(df_5m) < min_required_length or
            df_15m is None or len(df_15m) < min_required_length):
        raise ValueError(f"Insufficient data for training. Need at least {min_required_length} candles per timeframe.")

    # Only use the last N valid data points
    df_1m = df_1m.iloc[-900:].copy() if len(df_1m) > 900 else df_1m.copy()
    df_5m = df_5m.iloc[-180:].copy() if len(df_5m) > 180 else df_5m.copy()
    df_15m = df_15m.iloc[-60:].copy() if len(df_15m) > 60 else df_15m.copy()
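    # The cutoffs above keep all timeframes covering roughly the same window:
    # 900 x 1m = 180 x 5m = 60 x 15m = 900 minutes (15 hours) of history.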

    # Reset index to make sure we have continuous integers
    df_1m = df_1m.reset_index(drop=True)
    df_5m = df_5m.reset_index(drop=True)
    df_15m = df_15m.reset_index(drop=True)

    # For simplicity, we use the 1m data as the base environment;
    # the other timeframes are incorporated through the observation.
    env = TradingEnvironment(
        data=df_1m,
        initial_balance=100.0,
        fee_rate=0.0005,            # 0.05% fee (typical for crypto exchanges)
        max_steps=len(df_1m) - window_size - 50,  # Leave some room at the end
        window_size=window_size,
        risk_aversion=0.2,          # Moderately risk-averse
        price_scaling='zscore',     # Use z-score normalization
        reward_scaling=10.0,        # Scale rewards for better learning
        episode_penalty=0.2         # Penalty for holding positions at end of episode
    )

    return env

def initialize_agent(env, window_size=20, num_features=0, timeframes=None):
    """Initialize the DQN agent with appropriate parameters"""
    if timeframes is None:
        timeframes = ['1m', '5m', '15m']

    # Calculate input dimensions
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # If num_features wasn't provided, infer it from the environment:
    # features per timeframe = (state_dim - 3) / number of timeframes,
    # accounting for the 3 extra features (position, equity, unrealized_pnl)
    if num_features == 0:
        num_features = (state_dim - 3) // len(timeframes)
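    # Illustrative example (hypothetical numbers): with 3 timeframes and a
    # state_dim of 33, this yields (33 - 3) // 3 = 10 features per timeframe.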

    logger.info(f"Initializing DQN agent: state_dim={state_dim}, action_dim={action_dim}, features={num_features}")

    agent = DQNAgent(
        state_size=state_dim,
        action_size=action_dim,
        window_size=window_size,
        num_features=num_features,
        timeframes=timeframes,
        learning_rate=0.0005,   # Start with a moderate learning rate
        gamma=0.97,             # Slightly reduced discount factor for stable learning
        epsilon=1.0,            # Start with full exploration
        epsilon_min=0.05,       # Maintain some exploration even at the end
        epsilon_decay=0.9975,   # Slower decay for more exploration
        memory_size=20000,      # Larger replay buffer
        batch_size=128,         # Larger batch size for more stable gradients
        target_update=5         # More frequent target network updates
    )
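    # Rough exploration schedule: if DQNAgent multiplies epsilon by epsilon_decay
    # once per replay/update step (an assumption about its implementation), epsilon
    # falls from 1.0 to epsilon_min after about ln(0.05)/ln(0.9975) ~= 1200 updates.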

    return agent

def train_agent(env, agent, num_episodes=20, visualize=False, chart=None, save_path=None, save_freq=5):
    """
    Train the DQN agent with improved training loop

    Args:
        env: The trading environment
        agent: The DQN agent
        num_episodes: Number of episodes to train
        visualize: Whether to visualize trades during training
        chart: The visualization chart (if visualize=True)
        save_path: Path to save the model
        save_freq: How often to save checkpoints (in episodes)

    Returns:
        tuple: (rewards, win_rates, best_reward, best_model_path)
    """
    logger.info(f"Starting training for {num_episodes} episodes")

    # Initialize metrics tracking
    rewards = []
    win_rates = []
    total_train_time = 0
    best_reward = float('-inf')
    best_model_path = None

    # Create directory for checkpoints if needed
    if save_path:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        checkpoint_dir = os.path.join(os.path.dirname(save_path), 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)

    # For tracking improvement
    last_improved_episode = 0
    patience = 10  # Episodes to wait for improvement before early stopping
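    # "Improvement" here means a new best episode reward (see the best_reward
    # update in the loop below); training stops early if that has not happened
    # for `patience` consecutive episodes.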

    for episode in range(num_episodes):
        start_time = time.time()

        # Reset environment and get initial state
        state = env.reset()
        done = False
        episode_reward = 0
        step = 0

        # Action counts for this episode (BUY, SELL, HOLD)
        actions_taken = {0: 0, 1: 0, 2: 0}

        while not done:
            # Select action
            action = agent.act(state)

            # Execute action
            next_state, reward, done, info = env.step(action)

            # Store experience in replay buffer
            is_extrema = False  # In a real implementation, detect extrema points
            agent.remember(state, action, reward, next_state, done, is_extrema)

            # Learn from experience
            if len(agent.memory) >= agent.batch_size:
                use_prioritized = episode > 1  # Start using prioritized replay after first episode
                loss = agent.replay(use_prioritized=use_prioritized)

            # Update state and metrics
            state = next_state
            episode_reward += reward
            actions_taken[action] += 1

            # Pull step details from the environment info
            action_str = "BUY" if action == 0 else "SELL" if action == 1 else "HOLD"
            current_price = info.get('current_price', 0)
            pnl = info.get('pnl', 0)
            balance = info.get('balance', 0)

            # Log progress for the first few steps and every 100 steps thereafter
            if step % 100 == 0 or step < 10:
                logger.info(f"Episode {episode}, Step {step}: Action={action_str}, "
                            f"Reward={reward:.4f}, Balance=${balance:.2f}, PnL={pnl:.4f}")

            # Add trade to visualization if enabled
            if visualize and chart and action in [0, 1]:  # BUY or SELL
                chart.add_trade(
                    price=current_price,
                    timestamp=datetime.now(),
                    amount=0.1,
                    pnl=pnl,
                    action=action_str
                )

            step += 1

        # Episode finished - calculate metrics
        episode_time = time.time() - start_time
        total_train_time += episode_time

        # Get environment info
        win_rate = env.winning_trades / max(1, env.total_trades)
        trades = env.total_trades
        balance = env.balance
        gain = (balance - env.initial_balance) / env.initial_balance
        max_drawdown = env.max_drawdown

        # Record metrics
        rewards.append(episode_reward)
        win_rates.append(win_rate)

        # Update agent's learning metrics
        improved = agent.update_learning_metrics(episode_reward)

        # If this is the best performance so far, save the model
        if episode_reward > best_reward:
            best_reward = episode_reward
            if save_path:
                best_model_path = f"{save_path}_best"
                agent.save(best_model_path)
                logger.info(f"New best model saved to {best_model_path} (reward: {best_reward:.2f})")
            last_improved_episode = episode

        # Regular checkpoint saving
        if save_path and episode % save_freq == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_episode_{episode}")
            agent.save(checkpoint_path)

        # Print episode summary
        actions_summary = ", ".join([f"{k}:{v}" for k, v in actions_taken.items()])
        logger.info(f"Episode {episode} completed in {episode_time:.2f}s")
        logger.info(f"  Total reward: {episode_reward:.4f}")
        logger.info(f"  Actions taken: {actions_summary}")
        logger.info(f"  Trades: {trades}, Win rate: {win_rate:.2%}")
        logger.info(f"  Balance: ${balance:.2f}, Gain: {gain:.2%}")
        logger.info(f"  Max Drawdown: {max_drawdown:.2%}")

        # Early stopping check
        if episode - last_improved_episode >= patience:
            logger.info(f"No improvement for {patience} episodes. Early stopping.")
            break

    # Training complete
    avg_time_per_episode = total_train_time / max(1, len(rewards))
    logger.info(f"Training completed in {total_train_time:.2f}s ({avg_time_per_episode:.2f}s per episode)")

    # Save final model
    if save_path:
        agent.save(f"{save_path}_final")
        logger.info(f"Final model saved to {save_path}_final")

    # Return training metrics
    return rewards, win_rates, best_reward, best_model_path

def plot_training_results(rewards, win_rates, save_dir=None):
    """Plot training metrics and save the figure"""
    plt.figure(figsize=(12, 8))

    # Plot rewards
    plt.subplot(2, 1, 1)
    plt.plot(rewards, 'b-')
    plt.title('Training Rewards per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.grid(True)

    # Plot win rates
    plt.subplot(2, 1, 2)
    plt.plot(win_rates, 'g-')
    plt.title('Win Rate per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Win Rate')
    plt.grid(True)

    plt.tight_layout()

    # Save figure if directory provided
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, f'training_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}.png'))

    plt.close()

def evaluate_agent(env, agent, num_episodes=5, visualize=False, chart=None):
    """
    Evaluate a trained agent on the environment

    Args:
        env: The trading environment
        agent: The trained DQN agent
        num_episodes: Number of evaluation episodes
        visualize: Whether to visualize trades
        chart: The visualization chart (if visualize=True)

    Returns:
        dict: Evaluation metrics
    """
    logger.info(f"Evaluating agent over {num_episodes} episodes")

    # Metrics to track
    total_rewards = []
    total_trades = []
    win_rates = []
    sharpe_ratios = []
    sortino_ratios = []
    max_drawdowns = []
    final_balances = []

    for episode in range(num_episodes):
        # Reset environment
        state = env.reset()
        done = False
        episode_reward = 0

        # Run episode without exploration
        while not done:
            action = agent.act(state, explore=False)  # No exploration during evaluation
            next_state, reward, done, info = env.step(action)

            episode_reward += reward
            state = next_state

            # Add trade to visualization if enabled
            if visualize and chart and action in [0, 1]:  # BUY or SELL
                action_str = "BUY" if action == 0 else "SELL"
                current_price = info.get('current_price', 0)
                pnl = info.get('pnl', 0)

                chart.add_trade(
                    price=current_price,
                    timestamp=datetime.now(),
                    amount=0.1,
                    pnl=pnl,
                    action=action_str
                )

        # Record metrics
        total_rewards.append(episode_reward)
        total_trades.append(env.total_trades)
        win_rates.append(env.winning_trades / max(1, env.total_trades))
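        # Note: `info` still holds the dict from the final env.step() of the episode,
        # so these Sharpe/Sortino values are the end-of-episode figures reported by
        # the environment (defaulting to 0 if it does not provide them).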
        sharpe_ratios.append(info.get('sharpe_ratio', 0))
        sortino_ratios.append(info.get('sortino_ratio', 0))
        max_drawdowns.append(env.max_drawdown)
        final_balances.append(env.balance)

        logger.info(f"Evaluation episode {episode} - Reward: {episode_reward:.4f}, "
                    f"Balance: ${env.balance:.2f}, Win rate: {win_rates[-1]:.2%}")

    # Calculate average metrics
    avg_reward = np.mean(total_rewards)
    avg_trades = np.mean(total_trades)
    avg_win_rate = np.mean(win_rates)
    avg_sharpe = np.mean(sharpe_ratios)
    avg_sortino = np.mean(sortino_ratios)
    avg_max_drawdown = np.mean(max_drawdowns)
    avg_final_balance = np.mean(final_balances)

    # Log evaluation summary
    logger.info("Evaluation completed:")
    logger.info(f"  Average reward: {avg_reward:.4f}")
    logger.info(f"  Average trades per episode: {avg_trades:.2f}")
    logger.info(f"  Average win rate: {avg_win_rate:.2%}")
    logger.info(f"  Average Sharpe ratio: {avg_sharpe:.4f}")
    logger.info(f"  Average Sortino ratio: {avg_sortino:.4f}")
    logger.info(f"  Average max drawdown: {avg_max_drawdown:.2%}")
    logger.info(f"  Average final balance: ${avg_final_balance:.2f}")

    # Return evaluation metrics
    return {
        'avg_reward': avg_reward,
        'avg_trades': avg_trades,
        'avg_win_rate': avg_win_rate,
        'avg_sharpe': avg_sharpe,
        'avg_sortino': avg_sortino,
        'avg_max_drawdown': avg_max_drawdown,
        'avg_final_balance': avg_final_balance
    }

def main():
    """Main function to run the improved RL training"""
    start_time = time.time()
    logger.info(f"Starting improved RL training for {args.symbol}")

    # Create environment
    env, data_frames = create_training_environment(args.symbol)

    # Initialize visualization if enabled
    chart = None
    if args.visualize:
        logger.info("Initializing visualization chart")
        chart = RealTimeChart(args.symbol)
        time.sleep(2)  # Give the chart time to initialize

    # Initialize agent
    agent = initialize_agent(env)

    # Train agent
    rewards, win_rates, best_reward, best_model_path = train_agent(
        env=env,
        agent=agent,
        num_episodes=args.episodes,
        visualize=args.visualize,
        chart=chart,
        save_path=args.save_path
    )

    # Plot training results
    plot_dir = os.path.join(os.path.dirname(args.save_path), 'plots')
    plot_training_results(rewards, win_rates, save_dir=plot_dir)

    # Evaluate the best model, if one was saved
    logger.info("Evaluating best model")
    if best_model_path:
        # Load best model for evaluation
        best_agent = initialize_agent(env)
        best_agent.load(best_model_path)

        # Evaluate the best model
        eval_metrics = evaluate_agent(
            env=env,
            agent=best_agent,
            visualize=args.visualize,
            chart=chart
        )

        # Log evaluation results
        logger.info("Best model evaluation complete:")
        for metric, value in eval_metrics.items():
            logger.info(f"  {metric}: {value}")

    # Total run time
    total_time = time.time() - start_time
    logger.info(f"Total run time: {total_time:.2f} seconds")


if __name__ == "__main__":
    main()