# gogo2/training/enhanced_rl_trainer.py
"""
Enhanced RL Trainer with Continuous Learning
This module implements sophisticated RL training with:
- Prioritized experience replay
- Market regime adaptation
- Continuous learning from trading outcomes
- Performance tracking and visualization
"""
import asyncio
import logging
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Union
import matplotlib.pyplot as plt
from pathlib import Path
from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator, MarketState, TradingAction
from models import RLAgentInterface
import models
from training.enhanced_rl_state_builder import EnhancedRLStateBuilder
from training.williams_market_structure import WilliamsMarketStructure
from training.cnn_rl_bridge import CNNRLBridge
logger = logging.getLogger(__name__)
# Experience tuple for replay buffer
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done', 'priority'])
class PrioritizedReplayBuffer:
"""Prioritized experience replay buffer for RL training"""
def __init__(self, capacity: int = 10000, alpha: float = 0.6):
"""
Initialize prioritized replay buffer
Args:
capacity: Maximum number of experiences to store
alpha: Priority exponent (0 = uniform, 1 = fully prioritized)
"""
self.capacity = capacity
self.alpha = alpha
self.buffer = []
self.priorities = np.zeros(capacity, dtype=np.float32)
self.position = 0
self.size = 0
def add(self, experience: Experience):
"""Add experience to buffer with priority"""
max_priority = self.priorities[:self.size].max() if self.size > 0 else 1.0
if self.size < self.capacity:
self.buffer.append(experience)
self.size += 1
else:
self.buffer[self.position] = experience
self.priorities[self.position] = max_priority
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size: int, beta: float = 0.4) -> Tuple[List[Experience], np.ndarray, np.ndarray]:
"""Sample batch with prioritized sampling"""
if self.size == 0:
return [], np.array([]), np.array([])
# Calculate sampling probabilities
priorities = self.priorities[:self.size] ** self.alpha
probabilities = priorities / priorities.sum()
# Sample indices
indices = np.random.choice(self.size, batch_size, p=probabilities)
experiences = [self.buffer[i] for i in indices]
# Calculate importance sampling weights
weights = (self.size * probabilities[indices]) ** (-beta)
weights = weights / weights.max() # Normalize
return experiences, indices, weights
def update_priorities(self, indices: np.ndarray, priorities: np.ndarray):
"""Update priorities for sampled experiences"""
for idx, priority in zip(indices, priorities):
self.priorities[idx] = priority + 1e-6 # Small epsilon to avoid zero priority
def __len__(self):
return self.size
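# Reference for the prioritized-replay math implemented by PrioritizedReplayBuffer above
# (Schaul et al., 2015):
#   sampling probability  P(i) = p_i^alpha / sum_k p_k^alpha
#   importance weight     w_i = (N * P(i))^(-beta), normalized by max_j w_j
# Minimal usage sketch (illustrative only; dummy 4-feature states assumed):
#   buffer = PrioritizedReplayBuffer(capacity=1000, alpha=0.6)
#   buffer.add(Experience(np.zeros(4), 1, 0.0, np.zeros(4), False, 1.0))
#   batch, indices, weights = buffer.sample(batch_size=1, beta=0.4)
#   buffer.update_priorities(indices, np.array([0.5]))
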
class EnhancedDQNAgent(nn.Module, RLAgentInterface):
"""Enhanced DQN agent with market environment adaptation"""
def __init__(self, config: Dict[str, Any]):
nn.Module.__init__(self)
RLAgentInterface.__init__(self, config)
# Network architecture
self.state_size = config.get('state_size', 100)
self.action_space = config.get('action_space', 3)
self.hidden_size = config.get('hidden_size', 256)
# Build networks
self._build_networks()
# Training parameters
self.learning_rate = config.get('learning_rate', 0.0001)
self.gamma = config.get('gamma', 0.99)
self.epsilon = config.get('epsilon', 1.0)
self.epsilon_decay = config.get('epsilon_decay', 0.995)
self.epsilon_min = config.get('epsilon_min', 0.01)
self.target_update_freq = config.get('target_update_freq', 1000)
# Initialize device and optimizer
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.to(self.device)
self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
# Experience replay
self.replay_buffer = PrioritizedReplayBuffer(config.get('buffer_size', 10000))
self.batch_size = config.get('batch_size', 64)
# Market adaptation
self.market_regime_weights = {
'trending': 1.2, # Higher confidence in trending markets
'ranging': 0.8, # Lower confidence in ranging markets
'volatile': 0.6 # Much lower confidence in volatile markets
}
# Training statistics
self.training_steps = 0
self.losses = []
self.rewards = []
self.epsilon_history = []
logger.info(f"Enhanced DQN agent initialized with state size: {self.state_size}")
def _build_networks(self):
"""Build main and target networks"""
# Main network
self.main_network = nn.Sequential(
nn.Linear(self.state_size, self.hidden_size),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(self.hidden_size, self.hidden_size),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(self.hidden_size, 128),
nn.ReLU(),
nn.Dropout(0.2)
)
# Dueling network heads
self.value_head = nn.Linear(128, 1)
self.advantage_head = nn.Linear(128, self.action_space)
# Target network (copy of main network)
self.target_network = nn.Sequential(
nn.Linear(self.state_size, self.hidden_size),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(self.hidden_size, self.hidden_size),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(self.hidden_size, 128),
nn.ReLU(),
nn.Dropout(0.2)
)
self.target_value_head = nn.Linear(128, 1)
self.target_advantage_head = nn.Linear(128, self.action_space)
# Initialize target network with same weights
self._update_target_network()
def forward(self, state, target: bool = False):
"""Forward pass through the network"""
if target:
features = self.target_network(state)
value = self.target_value_head(features)
advantage = self.target_advantage_head(features)
else:
features = self.main_network(state)
value = self.value_head(features)
advantage = self.advantage_head(features)
# Dueling architecture: Q(s,a) = V(s) + A(s,a) - mean(A(s,a))
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
return q_values
def act(self, state: np.ndarray) -> int:
"""Choose action using epsilon-greedy policy"""
if random.random() < self.epsilon:
return random.randint(0, self.action_space - 1)
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
q_values = self.forward(state_tensor)
return q_values.argmax().item()
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
"""Choose action with confidence score adapted to market regime"""
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
q_values = self.forward(state_tensor)
# Convert Q-values to probabilities
action_probs = torch.softmax(q_values, dim=1)
action = q_values.argmax().item()
base_confidence = action_probs[0, action].item()
# Adapt confidence based on market regime
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
adapted_confidence = min(base_confidence * regime_weight, 1.0)
return action, adapted_confidence
def remember(self, state: np.ndarray, action: int, reward: float,
next_state: np.ndarray, done: bool):
"""Store experience in replay buffer"""
# Calculate TD error for priority
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)
current_q = self.forward(state_tensor)[0, action]
next_q = self.forward(next_state_tensor, target=True).max(1)[0]
target_q = reward + (self.gamma * next_q * (1 - done))
td_error = abs(current_q.item() - target_q.item())
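        # Note: new experiences enter the buffer at the current max priority (standard PER);
        # the TD error computed above is kept on the tuple for inspection/debugging.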
experience = Experience(state, action, reward, next_state, done, td_error)
self.replay_buffer.add(experience)
def replay(self) -> Optional[float]:
"""Train the network on a batch of experiences"""
if len(self.replay_buffer) < self.batch_size:
return None
# Sample batch
experiences, indices, weights = self.replay_buffer.sample(self.batch_size)
if not experiences:
return None
        # Convert to tensors (stack into contiguous numpy arrays first to avoid the
        # slow list-of-ndarray -> tensor conversion path)
        states = torch.FloatTensor(np.array([e.state for e in experiences])).to(self.device)
        actions = torch.LongTensor([e.action for e in experiences]).to(self.device)
        rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device)
        next_states = torch.FloatTensor(np.array([e.next_state for e in experiences])).to(self.device)
        dones = torch.BoolTensor([e.done for e in experiences]).to(self.device)
        weights_tensor = torch.FloatTensor(weights).to(self.device)
# Current Q-values
current_q_values = self.forward(states).gather(1, actions.unsqueeze(1))
# Target Q-values (Double DQN)
with torch.no_grad():
# Use main network to select actions
next_actions = self.forward(next_states).argmax(1)
# Use target network to evaluate actions
next_q_values = self.forward(next_states, target=True).gather(1, next_actions.unsqueeze(1))
target_q_values = rewards.unsqueeze(1) + (self.gamma * next_q_values * ~dones.unsqueeze(1))
        # Calculate importance-sampling weighted loss
        # (unsqueeze keeps shapes at [batch, 1] so the weights don't broadcast to [batch, batch])
        td_errors = target_q_values - current_q_values
        loss = (weights_tensor.unsqueeze(1) * (td_errors ** 2)).mean()
# Optimize
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
self.optimizer.step()
# Update priorities
new_priorities = torch.abs(td_errors).detach().cpu().numpy().flatten()
self.replay_buffer.update_priorities(indices, new_priorities)
# Update target network
self.training_steps += 1
if self.training_steps % self.target_update_freq == 0:
self._update_target_network()
# Decay epsilon
if self.epsilon > self.epsilon_min:
self.epsilon *= self.epsilon_decay
# Track statistics
self.losses.append(loss.item())
self.epsilon_history.append(self.epsilon)
return loss.item()
def _update_target_network(self):
"""Update target network with main network weights"""
self.target_network.load_state_dict(self.main_network.state_dict())
self.target_value_head.load_state_dict(self.value_head.state_dict())
self.target_advantage_head.load_state_dict(self.advantage_head.state_dict())
def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
"""Predict action probabilities and confidence (required by ModelInterface)"""
action, confidence = self.act_with_confidence(features)
# Convert action to probabilities
action_probs = np.zeros(self.action_space)
action_probs[action] = 1.0
return action_probs, confidence
def get_memory_usage(self) -> int:
"""Get memory usage in MB"""
if torch.cuda.is_available():
return torch.cuda.memory_allocated(self.device) // (1024 * 1024)
else:
param_count = sum(p.numel() for p in self.parameters())
buffer_size = len(self.replay_buffer) * self.state_size * 4 # Rough estimate
return (param_count * 4 + buffer_size) // (1024 * 1024)
class EnhancedRLTrainer:
"""Enhanced RL trainer with comprehensive state representation and real data integration"""
    def __init__(self, config: Optional[Dict] = None, orchestrator: Optional[EnhancedTradingOrchestrator] = None):
"""Initialize enhanced RL trainer with comprehensive state building"""
self.config = config or get_config()
self.orchestrator = orchestrator
# Initialize comprehensive state builder (replaces mock code)
self.state_builder = EnhancedRLStateBuilder(self.config)
self.williams_structure = WilliamsMarketStructure()
self.cnn_rl_bridge = CNNRLBridge(self.config) if hasattr(self.config, 'cnn_models') else None
# Enhanced RL agents with much larger state space
self.agents = {}
self.initialize_agents()
# Training configuration
self.symbols = self.config.symbols
self.save_dir = Path(self.config.rl.get('save_dir', 'models/rl/saved'))
self.save_dir.mkdir(parents=True, exist_ok=True)
# Performance tracking
self.training_metrics = {
'total_episodes': 0,
'total_rewards': {symbol: [] for symbol in self.symbols},
'losses': {symbol: [] for symbol in self.symbols},
'epsilon_values': {symbol: [] for symbol in self.symbols}
}
self.performance_history = {symbol: [] for symbol in self.symbols}
# Real-time learning parameters
self.learning_active = False
self.experience_buffer_size = 1000
self.min_experiences_for_training = 100
logger.info("Enhanced RL Trainer initialized with comprehensive state representation")
logger.info(f"State builder total size: {self.state_builder.total_state_size} features")
logger.info(f"Symbols: {self.symbols}")
def initialize_agents(self):
"""Initialize RL agents with enhanced state size"""
for symbol in self.symbols:
agent_config = {
'state_size': self.state_builder.total_state_size, # ~13,400 features
'action_space': 3, # BUY, SELL, HOLD
'hidden_size': 1024, # Larger hidden layers for complex state
'learning_rate': 0.0001,
'gamma': 0.99,
'epsilon': 1.0,
'epsilon_decay': 0.995,
'epsilon_min': 0.01,
'buffer_size': 50000, # Larger replay buffer
'batch_size': 128,
'target_update_freq': 1000
}
self.agents[symbol] = EnhancedDQNAgent(agent_config)
logger.info(f"Initialized {symbol} RL agent with state size: {agent_config['state_size']}")
async def continuous_learning_loop(self):
"""Main continuous learning loop"""
logger.info("Starting continuous RL learning loop")
while True:
try:
# Train agents with recent experiences
await self._train_all_agents()
# Evaluate recent actions
if self.orchestrator:
await self.orchestrator.evaluate_actions_with_rl()
# Adapt to market regime changes
await self._adapt_to_market_changes()
# Update performance metrics
self._update_performance_metrics()
# Save models periodically
if self.training_metrics['total_episodes'] % 100 == 0:
self._save_all_models()
# Wait before next training cycle
await asyncio.sleep(3600) # Train every hour
except Exception as e:
logger.error(f"Error in continuous learning loop: {e}")
await asyncio.sleep(60) # Wait 1 minute on error
async def _train_all_agents(self):
"""Train all RL agents with their experiences"""
for symbol, agent in self.agents.items():
try:
if len(agent.replay_buffer) >= self.min_experiences_for_training:
# Train for multiple steps
losses = []
for _ in range(10): # Train 10 steps per cycle
loss = agent.replay()
if loss is not None:
losses.append(loss)
if losses:
avg_loss = np.mean(losses)
self.training_metrics['losses'][symbol].append(avg_loss)
self.training_metrics['epsilon_values'][symbol].append(agent.epsilon)
logger.info(f"Trained {symbol} RL agent: Loss={avg_loss:.4f}, Epsilon={agent.epsilon:.4f}")
except Exception as e:
logger.error(f"Error training {symbol} agent: {e}")
async def _adapt_to_market_changes(self):
"""Adapt agents to market regime changes"""
if not self.orchestrator:
return
for symbol in self.symbols:
try:
# Get recent market states
recent_states = list(self.orchestrator.market_states[symbol])[-10:] # Last 10 states
if len(recent_states) < 5:
continue
                # Analyze regime diversity: fraction of distinct regimes in the window
                # (lower = more stable market regime)
                regimes = [state.market_regime for state in recent_states]
                regime_diversity = len(set(regimes)) / len(regimes)
                # Adjust learning parameters based on regime stability
                agent = self.agents[symbol]
                if regime_diversity < 0.3:  # Stable regime: decay exploration faster
                    agent.epsilon *= 0.99
                elif regime_diversity > 0.7:  # Unstable regime: increase exploration
                    agent.epsilon = min(agent.epsilon * 1.01, 0.5)
                logger.debug(f"{symbol} regime diversity: {regime_diversity:.3f}, epsilon: {agent.epsilon:.3f}")
except Exception as e:
logger.error(f"Error adapting {symbol} to market changes: {e}")
def add_trading_experience(self, symbol: str, action: TradingAction,
initial_state: MarketState, final_state: MarketState,
reward: float):
"""Add trading experience to the appropriate agent"""
if symbol not in self.agents:
logger.warning(f"No agent for symbol {symbol}")
return
try:
# Convert market states to RL state vectors
initial_rl_state = self._market_state_to_rl_state(initial_state)
final_rl_state = self._market_state_to_rl_state(final_state)
# Convert action to RL action index
action_mapping = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
action_idx = action_mapping.get(action.action, 1)
# Store experience
agent = self.agents[symbol]
agent.remember(
state=initial_rl_state,
action=action_idx,
reward=reward,
next_state=final_rl_state,
done=False
)
# Track reward
self.training_metrics['total_rewards'][symbol].append(reward)
logger.debug(f"Added experience for {symbol}: action={action.action}, reward={reward:.4f}")
except Exception as e:
logger.error(f"Error adding experience for {symbol}: {e}")
def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
"""Convert market state to comprehensive RL state vector using real data"""
try:
# Extract data from market state and orchestrator
if not self.orchestrator:
logger.warning("No orchestrator available for comprehensive state building")
return self._fallback_state_conversion(market_state)
# Get real tick data from orchestrator's data provider
symbol = market_state.symbol
eth_ticks = self._get_recent_tick_data(symbol, seconds=300)
# Get multi-timeframe OHLCV data
eth_ohlcv = self._get_multiframe_ohlcv_data(symbol)
btc_ohlcv = self._get_multiframe_ohlcv_data('BTC/USDT')
# Get CNN features if available
cnn_hidden_features = None
cnn_predictions = None
if self.cnn_rl_bridge:
cnn_data = self.cnn_rl_bridge.get_latest_features_for_symbol(symbol)
if cnn_data:
cnn_hidden_features = cnn_data.get('hidden_features', {})
cnn_predictions = cnn_data.get('predictions', {})
# Get pivot point data
pivot_data = self._calculate_pivot_points(eth_ohlcv)
# Build comprehensive state using enhanced state builder
comprehensive_state = self.state_builder.build_rl_state(
eth_ticks=eth_ticks,
eth_ohlcv=eth_ohlcv,
btc_ohlcv=btc_ohlcv,
cnn_hidden_features=cnn_hidden_features,
cnn_predictions=cnn_predictions,
pivot_data=pivot_data
)
logger.debug(f"Built comprehensive RL state: {len(comprehensive_state)} features")
return comprehensive_state
except Exception as e:
logger.error(f"Error building comprehensive RL state: {e}")
return self._fallback_state_conversion(market_state)
def _get_recent_tick_data(self, symbol: str, seconds: int = 300) -> List:
"""Get recent tick data from orchestrator's data provider"""
try:
if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
# Get recent ticks from data provider
recent_ticks = self.orchestrator.data_provider.get_recent_ticks(symbol, count=seconds*10)
# Convert to required format
tick_data = []
for tick in recent_ticks[-300:]: # Last 300 ticks max
tick_data.append({
'timestamp': tick.timestamp,
'price': tick.price,
'volume': tick.volume,
'quantity': getattr(tick, 'quantity', tick.volume),
'side': getattr(tick, 'side', 'unknown'),
'trade_id': getattr(tick, 'trade_id', 'unknown')
})
return tick_data
return []
except Exception as e:
logger.warning(f"Error getting tick data for {symbol}: {e}")
return []
def _get_multiframe_ohlcv_data(self, symbol: str) -> Dict[str, List]:
"""Get multi-timeframe OHLCV data"""
try:
if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
ohlcv_data = {}
timeframes = ['1s', '1m', '1h', '1d']
for tf in timeframes:
try:
# Get historical data for timeframe
df = self.orchestrator.data_provider.get_historical_data(
symbol=symbol,
timeframe=tf,
limit=300,
refresh=True
)
if df is not None and not df.empty:
# Convert to list of dictionaries
bars = []
for _, row in df.tail(300).iterrows():
bar = {
'timestamp': row.name if hasattr(row, 'name') else datetime.now(),
'open': float(row.get('open', 0)),
'high': float(row.get('high', 0)),
'low': float(row.get('low', 0)),
'close': float(row.get('close', 0)),
'volume': float(row.get('volume', 0))
}
bars.append(bar)
ohlcv_data[tf] = bars
else:
ohlcv_data[tf] = []
except Exception as e:
logger.warning(f"Error getting {tf} data for {symbol}: {e}")
ohlcv_data[tf] = []
return ohlcv_data
return {}
except Exception as e:
logger.warning(f"Error getting OHLCV data for {symbol}: {e}")
return {}
def _calculate_pivot_points(self, eth_ohlcv: Dict[str, List]) -> Dict[str, Any]:
"""Calculate Williams pivot points from OHLCV data"""
try:
if '1m' in eth_ohlcv and eth_ohlcv['1m']:
# Convert to numpy array for Williams calculation
bars = eth_ohlcv['1m']
if len(bars) >= 50: # Need minimum data for pivot calculation
ohlc_array = np.array([
[bar['timestamp'].timestamp() if hasattr(bar['timestamp'], 'timestamp') else time.time(),
bar['open'], bar['high'], bar['low'], bar['close'], bar['volume']]
for bar in bars[-200:] # Last 200 bars
])
pivot_data = self.williams_structure.calculate_recursive_pivot_points(ohlc_array)
return pivot_data
return {}
except Exception as e:
logger.warning(f"Error calculating pivot points: {e}")
return {}
def _fallback_state_conversion(self, market_state: MarketState) -> np.ndarray:
"""Fallback to basic state conversion if comprehensive state building fails"""
logger.warning("Using fallback state conversion - limited features")
state_components = [
market_state.volatility,
market_state.volume,
market_state.trend_strength
]
# Add price features
for timeframe in sorted(market_state.prices.keys()):
state_components.append(market_state.prices[timeframe])
# Pad to match expected state size
expected_size = self.state_builder.total_state_size
if len(state_components) < expected_size:
state_components.extend([0.0] * (expected_size - len(state_components)))
else:
state_components = state_components[:expected_size]
return np.array(state_components, dtype=np.float32)
def _update_performance_metrics(self):
"""Update performance tracking metrics"""
self.training_metrics['total_episodes'] += 1
# Calculate recent performance for each agent
for symbol, agent in self.agents.items():
recent_rewards = self.training_metrics['total_rewards'][symbol][-100:] # Last 100 rewards
if recent_rewards:
avg_reward = np.mean(recent_rewards)
self.performance_history[symbol].append({
'timestamp': datetime.now(),
'avg_reward': avg_reward,
'epsilon': agent.epsilon,
'experiences': len(agent.replay_buffer)
})
def _save_all_models(self):
"""Save all RL models"""
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
for symbol, agent in self.agents.items():
filename = f"rl_agent_{symbol}_{timestamp}.pt"
filepath = self.save_dir / filename
torch.save({
'model_state_dict': agent.state_dict(),
'optimizer_state_dict': agent.optimizer.state_dict(),
'config': self.config.rl,
'training_metrics': self.training_metrics,
'symbol': symbol,
'epsilon': agent.epsilon,
'training_steps': agent.training_steps
}, filepath)
logger.info(f"Saved {symbol} RL agent to {filepath}")
    def load_models(self, timestamp: Optional[str] = None) -> bool:
"""Load RL models from files"""
if timestamp is None:
# Find most recent models
model_files = list(self.save_dir.glob("rl_agent_*.pt"))
if not model_files:
logger.warning("No saved RL models found")
return False
# Group by timestamp and get most recent
timestamps = set(f.stem.split('_')[-2] + '_' + f.stem.split('_')[-1] for f in model_files)
timestamp = max(timestamps)
loaded_count = 0
for symbol in self.symbols:
filename = f"rl_agent_{symbol}_{timestamp}.pt"
filepath = self.save_dir / filename
if filepath.exists():
try:
checkpoint = torch.load(filepath, map_location=self.agents[symbol].device)
self.agents[symbol].load_state_dict(checkpoint['model_state_dict'])
self.agents[symbol].optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.agents[symbol].epsilon = checkpoint.get('epsilon', 0.1)
self.agents[symbol].training_steps = checkpoint.get('training_steps', 0)
logger.info(f"Loaded {symbol} RL agent from {filepath}")
loaded_count += 1
except Exception as e:
logger.error(f"Error loading {symbol} RL agent: {e}")
return loaded_count > 0
def get_performance_report(self) -> Dict[str, Any]:
"""Generate performance report for all agents"""
report = {
'total_episodes': self.training_metrics['total_episodes'],
'agents': {}
}
for symbol, agent in self.agents.items():
recent_rewards = self.training_metrics['total_rewards'][symbol][-100:]
recent_losses = self.training_metrics['losses'][symbol][-10:]
agent_report = {
'symbol': symbol,
'epsilon': agent.epsilon,
'training_steps': agent.training_steps,
'experiences_stored': len(agent.replay_buffer),
'memory_usage_mb': agent.get_memory_usage(),
'avg_recent_reward': np.mean(recent_rewards) if recent_rewards else 0.0,
'avg_recent_loss': np.mean(recent_losses) if recent_losses else 0.0,
'total_rewards': len(self.training_metrics['total_rewards'][symbol])
}
report['agents'][symbol] = agent_report
return report
def plot_training_metrics(self):
"""Plot training metrics for all agents"""
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Enhanced RL Training Metrics')
symbols = list(self.agents.keys())
colors = ['blue', 'red', 'green', 'orange'][:len(symbols)]
# Rewards plot
for i, symbol in enumerate(symbols):
rewards = self.training_metrics['total_rewards'][symbol]
if rewards:
# Moving average of rewards
window = min(100, len(rewards))
if len(rewards) >= window:
moving_avg = np.convolve(rewards, np.ones(window)/window, mode='valid')
axes[0, 0].plot(moving_avg, label=f'{symbol}', color=colors[i])
axes[0, 0].set_title('Average Rewards (Moving Average)')
axes[0, 0].set_xlabel('Episodes')
axes[0, 0].set_ylabel('Reward')
axes[0, 0].legend()
# Losses plot
for i, symbol in enumerate(symbols):
losses = self.training_metrics['losses'][symbol]
if losses:
axes[0, 1].plot(losses, label=f'{symbol}', color=colors[i])
axes[0, 1].set_title('Training Losses')
axes[0, 1].set_xlabel('Training Steps')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].legend()
# Epsilon values
for i, symbol in enumerate(symbols):
epsilon_values = self.training_metrics['epsilon_values'][symbol]
if epsilon_values:
axes[1, 0].plot(epsilon_values, label=f'{symbol}', color=colors[i])
axes[1, 0].set_title('Exploration Rate (Epsilon)')
axes[1, 0].set_xlabel('Training Steps')
axes[1, 0].set_ylabel('Epsilon')
axes[1, 0].legend()
# Experience buffer sizes
buffer_sizes = [len(agent.replay_buffer) for agent in self.agents.values()]
axes[1, 1].bar(symbols, buffer_sizes, color=colors[:len(symbols)])
axes[1, 1].set_title('Experience Buffer Sizes')
axes[1, 1].set_ylabel('Number of Experiences')
plt.tight_layout()
plt.savefig(self.save_dir / 'rl_training_metrics.png', dpi=300, bbox_inches='tight')
plt.close()
logger.info(f"RL training plots saved to {self.save_dir / 'rl_training_metrics.png'}")
def get_agents(self) -> Dict[str, EnhancedDQNAgent]:
"""Get all RL agents"""
return self.agents
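

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the production training flow.
    # Assumptions: EnhancedDQNAgent can be constructed from the small config dict
    # below, and random vectors stand in for real market-state features.
    logging.basicConfig(level=logging.INFO)
    demo_config = {
        'state_size': 64,
        'action_space': 3,
        'hidden_size': 128,
        'buffer_size': 1000,
        'batch_size': 8,
    }
    agent = EnhancedDQNAgent(demo_config)
    # Fill the replay buffer with random transitions, then run one training step
    for _ in range(32):
        state = np.random.randn(demo_config['state_size']).astype(np.float32)
        next_state = np.random.randn(demo_config['state_size']).astype(np.float32)
        action = agent.act(state)
        agent.remember(state, action, reward=float(np.random.randn()),
                       next_state=next_state, done=False)
    loss = agent.replay()
    logger.info(f"Demo replay step complete, loss={loss}")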