misc
This commit is contained in:
@ -18,6 +18,7 @@ from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import deque
|
||||
import random
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
# Add project imports
|
||||
import sys
|
||||
@ -75,8 +76,23 @@ class RLTrainer:
|
||||
self.win_rates = []
|
||||
self.avg_rewards = []
|
||||
|
||||
# TensorBoard setup
|
||||
self.setup_tensorboard()
|
||||
|
||||
logger.info(f"RLTrainer initialized for symbols: {self.symbols}")
|
||||
|
||||
def setup_tensorboard(self):
|
||||
"""Setup TensorBoard logging"""
|
||||
# Create tensorboard logs directory
|
||||
log_dir = Path("runs") / f"rl_training_{int(time.time())}"
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.writer = SummaryWriter(log_dir=str(log_dir))
|
||||
self.tensorboard_dir = log_dir
|
||||
|
||||
logger.info(f"TensorBoard logging to: {log_dir}")
|
||||
logger.info(f"Run: tensorboard --logdir=runs")
|
||||
|
||||
def setup_environment_and_agent(self) -> Tuple[ScalpingEnvironment, ScalpingRLAgent]:
|
||||
"""Setup trading environment and RL agent"""
|
||||
logger.info("Setting up environment and agent...")
|
||||
@ -443,6 +459,29 @@ class RLTrainer:
|
||||
|
||||
plt.show()
|
||||
|
||||
def log_episode_metrics(self, episode: int, metrics: Dict):
|
||||
"""Log episode metrics to TensorBoard"""
|
||||
# Main performance metrics
|
||||
self.writer.add_scalar('Episode/TotalReward', metrics['total_reward'], episode)
|
||||
self.writer.add_scalar('Episode/FinalBalance', metrics['final_balance'], episode)
|
||||
self.writer.add_scalar('Episode/TotalReturn', metrics['total_return'], episode)
|
||||
self.writer.add_scalar('Episode/Steps', metrics['steps'], episode)
|
||||
|
||||
# Trading metrics
|
||||
self.writer.add_scalar('Trading/TotalTrades', metrics['total_trades'], episode)
|
||||
self.writer.add_scalar('Trading/WinRate', metrics['win_rate'], episode)
|
||||
self.writer.add_scalar('Trading/ProfitFactor', metrics.get('profit_factor', 0), episode)
|
||||
self.writer.add_scalar('Trading/MaxDrawdown', metrics.get('max_drawdown', 0), episode)
|
||||
|
||||
# Agent metrics
|
||||
self.writer.add_scalar('Agent/Epsilon', metrics['epsilon'], episode)
|
||||
self.writer.add_scalar('Agent/LearningRate', metrics.get('learning_rate', self.learning_rate), episode)
|
||||
self.writer.add_scalar('Agent/MemorySize', metrics.get('memory_size', 0), episode)
|
||||
|
||||
# Loss metrics (if available)
|
||||
if 'loss' in metrics:
|
||||
self.writer.add_scalar('Agent/Loss', metrics['loss'], episode)
|
||||
|
||||
class HybridTrainer:
|
||||
"""
|
||||
Hybrid training pipeline combining CNN and RL
|
||||
|
Reference in New Issue
Block a user