gogo2/train_dqn.py
#!/usr/bin/env python
"""
DQN Training Session with Monitoring
This script sets up and runs a DQN agent training session with progress monitoring.
It tracks key metrics like rewards, losses, and prediction accuracy, and
visualizes the agent's learning progress.
"""
import os
import sys
import logging
import time
import argparse
import numpy as np
import torch
import matplotlib.pyplot as plt
from datetime import datetime
from pathlib import Path
import signal
from torch.utils.tensorboard import SummaryWriter
# Add project root to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
    sys.path.append(project_root)
# Import configurations
import train_config
# Import key components
from NN.models.dqn_agent import DQNAgent
from realtime import MultiTimeframeDataInterface
# Configure logging
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = log_dir / f"dqn_training_{timestamp}.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('dqn_training')
# Global variables for graceful shutdown
running = True
# Configure signal handler for graceful shutdown
def signal_handler(sig, frame):
    global running
    logger.info("Received interrupt signal. Finishing current episode and saving model...")
    running = False
# Register signal handler
signal.signal(signal.SIGINT, signal_handler)
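# Note: a SIGINT (Ctrl+C) does not abort immediately. DQNTrainingMonitor.train()
# only checks the `running` flag at the start of each episode, so the episode in
# progress completes, the model is saved, and then training stops.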
class DQNTrainingMonitor:
    """
    Class to monitor DQN training progress and visualize results
    """
    def __init__(self, config):
        self.config = config
        self.device = torch.device(config['hardware']['device'])
        self.agent = None
        self.data_interface = None
        # Training stats
        self.episode_rewards = []
        self.avg_rewards = []
        self.losses = []
        self.epsilons = []
        self.best_reward = -float('inf')
        self.tensorboard_writer = None
        # Paths
        self.models_dir = Path(config['paths']['models_dir'])
        self.models_dir.mkdir(exist_ok=True, parents=True)
        # Metrics display intervals
        self.plot_interval = config.get('visualization', {}).get('plot_interval', 5)
        self.save_interval = config.get('training', {}).get('save_interval', 10)
    def initialize(self):
        """Initialize the DQN agent and data interface"""
        # Set up TensorBoard
        tb_dir = Path(self.config['paths']['tensorboard_dir'])
        tb_dir.mkdir(exist_ok=True, parents=True)
        tb_log_dir = tb_dir / f"dqn_{timestamp}"
        self.tensorboard_writer = SummaryWriter(log_dir=str(tb_log_dir))
        logger.info(f"TensorBoard initialized at {tb_log_dir}")
        # Initialize data interface
        symbol = self.config['market_data']['symbol']
        timeframes = self.config['market_data']['timeframes']
        window_size = self.config['market_data']['window_size']
        logger.info(f"Initializing data interface for {symbol} with timeframes {timeframes}")
        self.data_interface = MultiTimeframeDataInterface(
            symbol=symbol,
            timeframes=timeframes
        )
        # Get data for training
        X_train_dict, _, _, _, _, _ = self.data_interface.prepare_training_data(
            window_size=window_size,
            refresh=True
        )
        if X_train_dict is None:
            raise ValueError("Failed to load training data for DQN agent")
        # Get feature count from the reference timeframe
        reference_tf = min(
            timeframes,
            key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
        )
        num_features = X_train_dict[reference_tf].shape[2]
        logger.info(f"Using {num_features} features from timeframe {reference_tf}")
        # Initialize DQN agent
        state_size = num_features * window_size * len(timeframes)
        action_size = 3  # Buy, Hold, Sell
        logger.info(f"Initializing DQN agent with state size {state_size} and action size {action_size}")
        self.agent = DQNAgent(
            state_shape=(len(timeframes), window_size, num_features),  # Multi-dimensional state shape
            n_actions=action_size,
            learning_rate=self.config['training']['learning_rate'],
            batch_size=self.config['training']['batch_size'],
            gamma=self.config.get('model', {}).get('gamma', 0.95),
            epsilon=self.config.get('model', {}).get('epsilon_start', 1.0),
            epsilon_min=self.config.get('model', {}).get('epsilon_min', 0.01),
            epsilon_decay=self.config.get('model', {}).get('epsilon_decay', 0.995),
            buffer_size=self.config.get('model', {}).get('memory_size', 10000),
            device=self.device
        )
        # Load existing model if available
        model_path = self.models_dir / "dqn_agent_best"
        if os.path.exists(f"{model_path}_policy.pt") and not self.config.get('model', {}).get('new_model', False):
            logger.info(f"Loading existing DQN model from {model_path}")
            try:
                self.agent.load(str(model_path))
                logger.info("DQN model loaded successfully")
            except Exception as e:
                logger.error(f"Error loading model: {str(e)}")
                logger.info("Starting with a new model instead")
        else:
            logger.info("Starting with a new model")
        return True
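    # Note on state sizing (illustrative arithmetic, assuming the CLI defaults below):
    # with timeframes '1m,5m,15m' (3 frames), a window size of 24, and N features per
    # candle, each flattened state built in train() has length 3 * 24 * N, which is
    # the state_size value logged in initialize().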
    def train(self, num_episodes=100):
        """Train the DQN agent for a specified number of episodes"""
        if self.agent is None:
            raise ValueError("Agent not initialized. Call initialize() first.")
        logger.info(f"Starting DQN training for {num_episodes} episodes")
        # Get training data
        window_size = self.config['market_data']['window_size']
        X_train_dict, y_train, _, _, _, _ = self.data_interface.prepare_training_data(
            window_size=window_size,
            refresh=True
        )
        # Prepare data for training
        reference_tf = min(
            self.config['market_data']['timeframes'],
            key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
        )
        # Convert data to flat states for RL
        states = []
        actions = []
        # Find the minimum length across all timeframes to ensure consistent indexing
        min_length = min(len(X_train_dict[tf]) for tf in self.config['market_data']['timeframes'])
        logger.info(f"Using {min_length} samples from each timeframe for training")
        # Only use indices that exist in all timeframes
        for i in range(min_length):
            state = []
            for tf in self.config['market_data']['timeframes']:
                state.extend(X_train_dict[tf][i].flatten())
            states.append(np.array(state))
            actions.append(np.argmax(y_train[i]))
        logger.info(f"Prepared {len(states)} state-action pairs for training")
        # Training loop
        global running
        for episode in range(1, num_episodes + 1):
            if not running:
                logger.info("Training interrupted. Saving final model.")
                self._save_model(final=True)
                break
            episode_reward = 0
            total_loss = 0
            replay_count = 0
            correct_predictions = 0
            # Randomly sample start position (to prevent overfitting on sequence)
            start_idx = np.random.randint(0, len(states) - 1000) if len(states) > 1000 else 0
            end_idx = min(start_idx + 1000, len(states))
            logger.info(f"Episode {episode}/{num_episodes} - Training on sequence from {start_idx} to {end_idx}")
            # Training on sequence
            for i in range(start_idx, end_idx - 1):
                state = states[i]
                action = actions[i]
                next_state = states[i + 1]
                # Get reward based on price movement
                # Price is typically the close price (4th column in OHLCV data)
                try:
                    # Assuming the last feature in each timeframe is the closing price
                    price_current = X_train_dict[reference_tf][i][-1, -1]  # Last row, last column of current state
                    price_next = X_train_dict[reference_tf][i + 1][-1, -1]  # Last row, last column of next state
                    price_diff = price_next - price_current
                except IndexError:
                    # Fallback if we're at the edge of our data
                    price_diff = 0
                if action == 0:  # Buy
                    reward = price_diff * 100  # Scale reward for better learning
                elif action == 2:  # Sell
                    reward = -price_diff * 100
                else:  # Hold
                    reward = abs(price_diff) * 10 if abs(price_diff) < 0.0001 else -abs(price_diff) * 50
                # Train the agent with this experience
                predicted_action = self.agent.act(state)
                # Store experience in memory
                done = (i == end_idx - 2)  # Mark as done if it's the last step
                self.agent.remember(state, action, reward, next_state, done)
                # Periodically replay from memory
                if i % 10 == 0:  # Replay every 10 steps
                    loss = self.agent.replay()
                else:
                    loss = None
                if predicted_action == action:
                    correct_predictions += 1
                episode_reward += reward
                if loss is not None:
                    total_loss += loss
                    replay_count += 1
            # Calculate metrics over the steps actually taken this episode
            num_steps = max(end_idx - start_idx - 1, 1)
            accuracy = correct_predictions / num_steps * 100
            avg_loss = total_loss / replay_count if replay_count > 0 else 0  # average over replay updates that returned a loss
            # Update training history
            self.episode_rewards.append(episode_reward)
            self.avg_rewards.append(self.agent.avg_reward)
            self.losses.append(avg_loss)
            self.epsilons.append(self.agent.epsilon)
            # Log metrics
            logger.info(f"Episode {episode} - Reward: {episode_reward:.2f}, Avg Reward: {self.agent.avg_reward:.2f}, "
                        f"Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%, Epsilon: {self.agent.epsilon:.4f}")
            # Log to TensorBoard
            self._log_to_tensorboard(episode, episode_reward, avg_loss, accuracy)
            # Save model if improved
            improved = episode_reward > self.best_reward
            if improved:
                self.best_reward = episode_reward
                logger.info(f"New best reward: {self.best_reward:.2f}")
            # Periodically save model
            if episode % self.save_interval == 0 or improved:
                self._save_model(final=False)
            # Plot progress
            if episode % self.plot_interval == 0:
                self._plot_training_progress()
        # Save final model
        logger.info("Training completed.")
        self._save_model(final=True)
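    # Worked example of the reward scheme above (illustrative numbers only): if the
    # close price rises by 0.5 between steps (price_diff = +0.5), a Buy earns +50,
    # a Sell earns -50, and a Hold earns -25 because |price_diff| exceeds the
    # 0.0001 flat-market threshold.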
    def _log_to_tensorboard(self, episode, reward, loss, accuracy):
        """Log training metrics to TensorBoard"""
        if self.tensorboard_writer:
            self.tensorboard_writer.add_scalar('Train/Reward', reward, episode)
            self.tensorboard_writer.add_scalar('Train/AvgReward', self.agent.avg_reward, episode)
            self.tensorboard_writer.add_scalar('Train/Loss', loss, episode)
            self.tensorboard_writer.add_scalar('Train/Accuracy', accuracy, episode)
            self.tensorboard_writer.add_scalar('Train/Epsilon', self.agent.epsilon, episode)
    def _save_model(self, final=False):
        """Save the DQN model"""
        if final:
            save_path = self.models_dir / f"dqn_agent_final_{timestamp}"
        else:
            save_path = self.models_dir / "dqn_agent_best"
        self.agent.save(str(save_path))
        logger.info(f"Model saved to {save_path}")
    def _plot_training_progress(self):
        """Plot training progress metrics"""
        if not self.episode_rewards:
            logger.warning("No training data available for plotting yet")
            return
        plt.figure(figsize=(15, 10))
        # Plot rewards
        plt.subplot(2, 2, 1)
        plt.plot(self.episode_rewards, label='Episode Reward')
        plt.plot(self.avg_rewards, label='Avg Reward', linestyle='--')
        plt.title('Rewards')
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.legend()
        # Plot losses
        plt.subplot(2, 2, 2)
        plt.plot(self.losses)
        plt.title('Loss')
        plt.xlabel('Episode')
        plt.ylabel('Loss')
        # Plot epsilon
        plt.subplot(2, 2, 3)
        plt.plot(self.epsilons)
        plt.title('Exploration Rate (Epsilon)')
        plt.xlabel('Episode')
        plt.ylabel('Epsilon')
        # Save plot
        plots_dir = Path("plots")
        plots_dir.mkdir(exist_ok=True)
        plt.tight_layout()
        plt.savefig(plots_dir / f"dqn_training_progress_{timestamp}.png")
        plt.close()
def parse_args():
    parser = argparse.ArgumentParser(description='DQN Training Session with Monitoring')
    parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train')
    parser.add_argument('--symbol', type=str, default='BTC/USDT', help='Trading symbol')
    parser.add_argument('--timeframes', type=str, default='1m,5m,15m', help='Comma-separated timeframes')
    parser.add_argument('--window', type=int, default=24, help='Window size for state construction')
    parser.add_argument('--batch-size', type=int, default=64, help='Batch size for training')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--plot-interval', type=int, default=5, help='Interval for plotting progress')
    parser.add_argument('--save-interval', type=int, default=10, help='Interval for saving model')
    parser.add_argument('--new-model', action='store_true', help='Start with a new model instead of loading existing')
    return parser.parse_args()
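# Example invocations (flags as defined above; example values only, and data
# availability depends on your environment):
#   python train_dqn.py --episodes 50
#   python train_dqn.py --symbol ETH/USDT --timeframes 1m,5m --window 24 --new-model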
def main():
    args = parse_args()
    # Force CPU training to avoid device mismatch errors
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
    os.environ['DISABLE_MIXED_PRECISION'] = '1'
    # Create custom config based on arguments
    custom_config = {
        'market_data': {
            'symbol': args.symbol,
            'timeframes': args.timeframes.split(','),
            'window_size': args.window
        },
        'training': {
            'batch_size': args.batch_size,
            'learning_rate': args.lr,
            'save_interval': args.save_interval
        },
        'visualization': {
            'plot_interval': args.plot_interval
        },
        'model': {
            'new_model': args.new_model
        },
        'hardware': {
            'device': 'cpu',
            'mixed_precision': False
        }
    }
    # Get configuration
    config = train_config.get_config('reinforcement', custom_config)
    # Save configuration for reference
    config_dir = Path("configs")
    config_dir.mkdir(exist_ok=True)
    config_path = config_dir / f"dqn_training_config_{timestamp}.json"
    train_config.save_config(config, str(config_path))
    # Initialize and train
    monitor = DQNTrainingMonitor(config)
    monitor.initialize()
    monitor.train(num_episodes=args.episodes)
    logger.info("Training completed. Results saved to logs and plots directories.")
    logger.info(f"To visualize training in TensorBoard, run: tensorboard --logdir={config['paths']['tensorboard_dir']}")
if __name__ == "__main__":
    main()