#!/usr/bin/env python
"""
DQN Training Session with Monitoring

This script sets up and runs a DQN agent training session with progress monitoring.
It tracks key metrics like rewards, losses, and prediction accuracy, and
visualizes the agent's learning progress.
"""

import os
import sys
import logging
import time
import argparse
import signal
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

# Add project root to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
    sys.path.append(project_root)

# Import configurations
import train_config

# Import key components
from NN.models.dqn_agent import DQNAgent
from dataprovider_realtime import MultiTimeframeDataInterface

# Configure logging
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = log_dir / f"dqn_training_{timestamp}.log"

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('dqn_training')

# Global variable for graceful shutdown
running = True

# Configure signal handler for graceful shutdown
def signal_handler(sig, frame):
    global running
    logger.info("Received interrupt signal. Finishing current episode and saving model...")
    running = False

# Register signal handler
signal.signal(signal.SIGINT, signal_handler)


class DQNTrainingMonitor:
    """
    Class to monitor DQN training progress and visualize results
    """

    def __init__(self, config):
        self.config = config
        self.device = torch.device(config['hardware']['device'])
        self.agent = None
        self.data_interface = None

        # Training stats
        self.episode_rewards = []
        self.avg_rewards = []
        self.losses = []
        self.epsilons = []
        self.best_reward = -float('inf')
        self.tensorboard_writer = None

        # Paths
        self.models_dir = Path(config['paths']['models_dir'])
        self.models_dir.mkdir(exist_ok=True, parents=True)

        # Metrics display intervals
        self.plot_interval = config.get('visualization', {}).get('plot_interval', 5)
        self.save_interval = config.get('training', {}).get('save_interval', 10)

    def initialize(self):
        """Initialize the DQN agent and data interface"""
        # Set up TensorBoard (use a local name so the module-level log_dir isn't shadowed)
        tb_dir = Path(self.config['paths']['tensorboard_dir'])
        tb_dir.mkdir(exist_ok=True, parents=True)
        tb_log_dir = tb_dir / f"dqn_{timestamp}"
        self.tensorboard_writer = SummaryWriter(log_dir=str(tb_log_dir))
        logger.info(f"TensorBoard initialized at {tb_log_dir}")

        # Initialize data interface
        symbol = self.config['market_data']['symbol']
        timeframes = self.config['market_data']['timeframes']
        window_size = self.config['market_data']['window_size']

        logger.info(f"Initializing data interface for {symbol} with timeframes {timeframes}")
        self.data_interface = MultiTimeframeDataInterface(
            symbol=symbol,
            timeframes=timeframes
        )

        # Get data for training
        X_train_dict, _, _, _, _, _ = self.data_interface.prepare_training_data(
            window_size=window_size,
            refresh=True
        )

        if X_train_dict is None:
            raise ValueError("Failed to load training data for DQN agent")

        # Get feature count from the reference timeframe
        reference_tf = min(
            timeframes,
            key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
        )
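        # The reference timeframe is the highest-resolution one (fewest seconds per bar),
        # e.g. '1m' out of ['1m', '5m', '15m'].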

        num_features = X_train_dict[reference_tf].shape[2]
        logger.info(f"Using {num_features} features from timeframe {reference_tf}")

        # Initialize DQN agent
        state_size = num_features * window_size * len(timeframes)
        action_size = 3  # Buy, Hold, Sell

        logger.info(f"Initializing DQN agent with state size {state_size} and action size {action_size}")
        self.agent = DQNAgent(
            state_shape=(len(timeframes), window_size, num_features),  # Multi-dimensional state shape
            n_actions=action_size,
            learning_rate=self.config['training']['learning_rate'],
            batch_size=self.config['training']['batch_size'],
            gamma=self.config.get('model', {}).get('gamma', 0.95),
            epsilon=self.config.get('model', {}).get('epsilon_start', 1.0),
            epsilon_min=self.config.get('model', {}).get('epsilon_min', 0.01),
            epsilon_decay=self.config.get('model', {}).get('epsilon_decay', 0.995),
            buffer_size=self.config.get('model', {}).get('memory_size', 10000),
            device=self.device
        )

        # Load existing model if available
        model_path = self.models_dir / "dqn_agent_best"
        if os.path.exists(f"{model_path}_policy.pt") and not self.config.get('model', {}).get('new_model', False):
            logger.info(f"Loading existing DQN model from {model_path}")
            try:
                self.agent.load(str(model_path))
                logger.info("DQN model loaded successfully")
            except Exception as e:
                logger.error(f"Error loading model: {str(e)}")
                logger.info("Starting with a new model instead")
        else:
            logger.info("Starting with a new model")

        return True

    def train(self, num_episodes=100):
        """Train the DQN agent for a specified number of episodes"""
        if self.agent is None:
            raise ValueError("Agent not initialized. Call initialize() first.")

        logger.info(f"Starting DQN training for {num_episodes} episodes")

        # Get training data
        window_size = self.config['market_data']['window_size']
        X_train_dict, y_train, _, _, _, _ = self.data_interface.prepare_training_data(
            window_size=window_size,
            refresh=True
        )

        # Identify the reference timeframe for price lookups (highest resolution)
        reference_tf = min(
            self.config['market_data']['timeframes'],
            key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
        )

        # Convert data to flat states for RL
        states = []
        actions = []

        # Find the minimum length across all timeframes to ensure consistent indexing
        min_length = min(len(X_train_dict[tf]) for tf in self.config['market_data']['timeframes'])
        logger.info(f"Using {min_length} samples from each timeframe for training")

        # Only use indices that exist in all timeframes
        for i in range(min_length):
            state = []
            for tf in self.config['market_data']['timeframes']:
                state.extend(X_train_dict[tf][i].flatten())
            states.append(np.array(state))
            actions.append(np.argmax(y_train[i]))

        logger.info(f"Prepared {len(states)} state-action pairs for training")

        # Training loop
        global running
        for episode in range(1, num_episodes + 1):
            if not running:
                logger.info("Training interrupted. Saving final model.")
                self._save_model(final=True)
                break

            episode_reward = 0
            total_loss = 0
            replay_updates = 0
            correct_predictions = 0

            # Randomly sample a start position (to prevent overfitting on sequence order)
            start_idx = np.random.randint(0, len(states) - 1000) if len(states) > 1000 else 0
            end_idx = min(start_idx + 1000, len(states))
            num_steps = max(end_idx - 1 - start_idx, 1)  # Transitions taken this episode

            logger.info(f"Episode {episode}/{num_episodes} - Training on sequence from {start_idx} to {end_idx}")

            # Training on sequence
            for i in range(start_idx, end_idx - 1):
                state = states[i]
                action = actions[i]
                next_state = states[i + 1]

                # Get reward based on price movement; assumes the last feature in each
                # window row is the closing price
                try:
                    price_current = X_train_dict[reference_tf][i][-1, -1]    # Last row, last column of current window
                    price_next = X_train_dict[reference_tf][i + 1][-1, -1]   # Last row, last column of next window
                    price_diff = price_next - price_current
                except IndexError:
                    # Fallback if we're at the edge of our data
                    price_diff = 0

                if action == 0:  # Buy
                    reward = price_diff * 100  # Scale reward for better learning
                elif action == 2:  # Sell
                    reward = -price_diff * 100
                else:  # Hold
                    reward = abs(price_diff) * 10 if abs(price_diff) < 0.0001 else -abs(price_diff) * 50
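                # Worked example: price_diff = +0.5 gives Buy +50, Sell -50, and Hold -25;
                # only near-flat moves (|diff| < 1e-4) give Hold a small positive reward.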

                # Query the agent's action (used here only to track prediction accuracy;
                # the stored transition below uses the labelled action)
                predicted_action = self.agent.act(state)

                # Store experience in memory
                done = (i == end_idx - 2)  # Mark as done if it's the last step
                self.agent.remember(state, action, reward, next_state, done)

                # Periodically replay from memory
                if i % 10 == 0:  # Replay every 10 steps
                    loss = self.agent.replay()
                else:
                    loss = None

                if predicted_action == action:
                    correct_predictions += 1

                episode_reward += reward
                if loss is not None:
                    total_loss += loss
                    replay_updates += 1

            # Calculate metrics (average loss over actual replay updates, not all steps)
            accuracy = correct_predictions / num_steps * 100
            avg_loss = total_loss / replay_updates if replay_updates > 0 else 0
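            # Note: accuracy compares the agent's (presumably epsilon-greedy) action
            # choices against the supervised labels, so it reads low while epsilon is high.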

            # Update training history
            self.episode_rewards.append(episode_reward)
            self.avg_rewards.append(self.agent.avg_reward)
            self.losses.append(avg_loss)
            self.epsilons.append(self.agent.epsilon)

            # Log metrics
            logger.info(f"Episode {episode} - Reward: {episode_reward:.2f}, Avg Reward: {self.agent.avg_reward:.2f}, "
                        f"Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%, Epsilon: {self.agent.epsilon:.4f}")

            # Log to TensorBoard
            self._log_to_tensorboard(episode, episode_reward, avg_loss, accuracy)

            # Save model if improved
            improved = episode_reward > self.best_reward
            if improved:
                self.best_reward = episode_reward
                logger.info(f"New best reward: {self.best_reward:.2f}")

            # Periodically save model
            if episode % self.save_interval == 0 or improved:
                self._save_model(final=False)

            # Plot progress
            if episode % self.plot_interval == 0:
                self._plot_training_progress()

        # Save final model
        logger.info("Training completed.")
        self._save_model(final=True)

    def _log_to_tensorboard(self, episode, reward, loss, accuracy):
        """Log training metrics to TensorBoard"""
        if self.tensorboard_writer:
            self.tensorboard_writer.add_scalar('Train/Reward', reward, episode)
            self.tensorboard_writer.add_scalar('Train/AvgReward', self.agent.avg_reward, episode)
            self.tensorboard_writer.add_scalar('Train/Loss', loss, episode)
            self.tensorboard_writer.add_scalar('Train/Accuracy', accuracy, episode)
            self.tensorboard_writer.add_scalar('Train/Epsilon', self.agent.epsilon, episode)

    def _save_model(self, final=False):
        """Save the DQN model"""
        if final:
            # Final snapshots are timestamped; the rolling best model is overwritten in place
            save_path = self.models_dir / f"dqn_agent_final_{timestamp}"
        else:
            save_path = self.models_dir / "dqn_agent_best"

        self.agent.save(str(save_path))
        logger.info(f"Model saved to {save_path}")

    def _plot_training_progress(self):
        """Plot training progress metrics"""
        if not self.episode_rewards:
            logger.warning("No training data available for plotting yet")
            return

        plt.figure(figsize=(15, 10))

        # Plot rewards
        plt.subplot(2, 2, 1)
        plt.plot(self.episode_rewards, label='Episode Reward')
        plt.plot(self.avg_rewards, label='Avg Reward', linestyle='--')
        plt.title('Rewards')
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.legend()

        # Plot losses
        plt.subplot(2, 2, 2)
        plt.plot(self.losses)
        plt.title('Loss')
        plt.xlabel('Episode')
        plt.ylabel('Loss')

        # Plot epsilon
        plt.subplot(2, 2, 3)
        plt.plot(self.epsilons)
        plt.title('Exploration Rate (Epsilon)')
        plt.xlabel('Episode')
        plt.ylabel('Epsilon')

        # Save plot
        plots_dir = Path("plots")
        plots_dir.mkdir(exist_ok=True)
        plt.tight_layout()
        plt.savefig(plots_dir / f"dqn_training_progress_{timestamp}.png")
        plt.close()


def parse_args():
    parser = argparse.ArgumentParser(description='DQN Training Session with Monitoring')
    parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train')
    parser.add_argument('--symbol', type=str, default='BTC/USDT', help='Trading symbol')
    parser.add_argument('--timeframes', type=str, default='1m,5m,15m', help='Comma-separated timeframes')
    parser.add_argument('--window', type=int, default=24, help='Window size for state construction')
    parser.add_argument('--batch-size', type=int, default=64, help='Batch size for training')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--plot-interval', type=int, default=5, help='Interval for plotting progress')
    parser.add_argument('--save-interval', type=int, default=10, help='Interval for saving model')
    parser.add_argument('--new-model', action='store_true', help='Start with a new model instead of loading existing')

    return parser.parse_args()


def main():
    args = parse_args()

    # Force CPU training to avoid device mismatch errors
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
    os.environ['DISABLE_MIXED_PRECISION'] = '1'

    # Create custom config based on arguments
    custom_config = {
        'market_data': {
            'symbol': args.symbol,
            'timeframes': args.timeframes.split(','),
            'window_size': args.window
        },
        'training': {
            'batch_size': args.batch_size,
            'learning_rate': args.lr,
            'save_interval': args.save_interval
        },
        'visualization': {
            'plot_interval': args.plot_interval
        },
        'model': {
            'new_model': args.new_model
        },
        'hardware': {
            'device': 'cpu',
            'mixed_precision': False
        }
    }

    # Get configuration
    config = train_config.get_config('reinforcement', custom_config)

    # Save configuration for reference
    config_dir = Path("configs")
    config_dir.mkdir(exist_ok=True)
    config_path = config_dir / f"dqn_training_config_{timestamp}.json"
    train_config.save_config(config, str(config_path))

    # Initialize and train
    monitor = DQNTrainingMonitor(config)
    monitor.initialize()
    monitor.train(num_episodes=args.episodes)

    logger.info("Training completed. Results saved to logs and plots directories.")
    logger.info(f"To visualize training in TensorBoard, run: tensorboard --logdir={config['paths']['tensorboard_dir']}")


if __name__ == "__main__":
    main()