""" RL Training Pipeline with Comprehensive Experience Storage and Replay This module implements a robust RL training pipeline that: 1. Stores all training experiences with profitability metrics 2. Implements profit-weighted experience replay 3. Tracks gradient information for each training step 4. Enables retraining on most profitable trading sequences 5. Maintains comprehensive trading episode analysis """ import logging import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from datetime import datetime, timedelta from pathlib import Path from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass, field import json import pickle from collections import deque import threading import random from .training_data_collector import get_training_data_collector logger = logging.getLogger(__name__) @dataclass class RLExperience: """Single RL experience with complete state-action-reward information""" experience_id: str timestamp: datetime episode_id: str # Core RL components state: np.ndarray action: int # 0=SELL, 1=HOLD, 2=BUY reward: float next_state: np.ndarray done: bool # Extended state information market_context: Dict[str, Any] cnn_predictions: Optional[Dict[str, Any]] = None confidence_score: float = 0.0 # Actual trading outcome actual_profit: Optional[float] = None actual_holding_time: Optional[timedelta] = None optimal_action: Optional[int] = None # Experience value for replay experience_value: float = 0.0 profitability_score: float = 0.0 learning_priority: float = 0.0 # Training metadata times_trained: int = 0 last_trained: Optional[datetime] = None class ProfitWeightedExperienceBuffer: """Experience buffer with profit-weighted sampling for replay""" def __init__(self, max_size: int = 100000): self.max_size = max_size self.experiences: Dict[str, RLExperience] = {} self.experience_order: deque = deque(maxlen=max_size) self.profitable_experiences: List[str] = [] self.total_experiences = 0 self.total_profitable = 0 def add_experience(self, experience: RLExperience): """Add experience to buffer""" try: self.experiences[experience.experience_id] = experience self.experience_order.append(experience.experience_id) if experience.actual_profit is not None and experience.actual_profit > 0: self.profitable_experiences.append(experience.experience_id) self.total_profitable += 1 # Remove oldest if buffer is full if len(self.experiences) > self.max_size: oldest_id = self.experience_order[0] if oldest_id in self.experiences: del self.experiences[oldest_id] if oldest_id in self.profitable_experiences: self.profitable_experiences.remove(oldest_id) self.total_experiences += 1 except Exception as e: logger.error(f"Error adding experience to buffer: {e}") def sample_batch(self, batch_size: int, prioritize_profitable: bool = True) -> List[RLExperience]: """Sample batch with profit-weighted prioritization""" try: if len(self.experiences) < batch_size: return list(self.experiences.values()) if prioritize_profitable and len(self.profitable_experiences) > batch_size // 2: # Sample mix of profitable and all experiences profitable_sample_size = min(batch_size // 2, len(self.profitable_experiences)) remaining_sample_size = batch_size - profitable_sample_size profitable_ids = random.sample(self.profitable_experiences, profitable_sample_size) all_ids = list(self.experiences.keys()) remaining_ids = random.sample(all_ids, remaining_sample_size) sampled_ids = profitable_ids + remaining_ids else: # Random sampling from all experiences all_ids = 
class RLTradingAgent(nn.Module):
    """RL Trading Agent with comprehensive state processing"""

    def __init__(self, state_dim: int = 2000, action_dim: int = 3, hidden_dim: int = 512):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # State processing network
        self.state_processor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU()
        )

        # Q-value head
        self.q_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, action_dim)
        )

        # Policy head
        self.policy_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, action_dim),
            nn.Softmax(dim=-1)
        )

        # State-value head
        self.value_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, 1)
        )

    def forward(self, state):
        """Forward pass through the agent."""
        processed_state = self.state_processor(state)
        q_values = self.q_network(processed_state)
        policy_probs = self.policy_network(processed_state)
        state_value = self.value_network(processed_state)
        return {
            'q_values': q_values,
            'policy_probs': policy_probs,
            'state_value': state_value,
            'processed_state': processed_state
        }

    def select_action(self, state, epsilon: float = 0.1) -> Tuple[int, float]:
        """Select an action with an epsilon-greedy policy over the Q-values."""
        self.eval()
        with torch.no_grad():
            if isinstance(state, np.ndarray):
                state = torch.from_numpy(state).float().unsqueeze(0)
            # Keep the input on the same device as the model parameters
            state = state.to(next(self.parameters()).device)
            outputs = self.forward(state)

            if random.random() < epsilon:
                # Explore: uniform random action
                action = random.randint(0, self.action_dim - 1)
                confidence = 1.0 / self.action_dim
            else:
                # Exploit: greedy action, with the softmax over Q-values
                # serving as a rough confidence proxy
                q_values = outputs['q_values']
                action = torch.argmax(q_values, dim=1).item()
                confidence = torch.max(F.softmax(q_values, dim=1)).item()

        return action, confidence
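
# --- Illustrative sketch (not part of the pipeline) --------------------------
# Shows the agent's epsilon-greedy interface on a random state vector. The
# 2000-dimensional state and the epsilon value are placeholders, not real
# market features; this helper is hypothetical and never called.
def _demo_action_selection() -> None:
    agent = RLTradingAgent(state_dim=2000, action_dim=3)
    state = np.random.randn(2000).astype(np.float32)
    action, confidence = agent.select_action(state, epsilon=0.1)
    action_name = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}[action]
    logger.info(f"Selected action {action_name} with confidence {confidence:.2f}")
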
@dataclass
class RLTrainingStep:
    """Single RL training step with backpropagation data"""
    step_id: str
    timestamp: datetime
    batch_experiences: List[str]  # Training data: experience ids in this batch

    # Losses
    total_loss: float
    q_loss: float
    policy_loss: float

    # Gradients
    gradients: Dict[str, torch.Tensor]
    gradient_norms: Dict[str, float]

    # Metadata
    learning_rate: float = 0.001
    batch_size: int = 32

    # Performance
    batch_profitability: float = 0.0
    correct_actions: int = 0
    total_actions: int = 0
    step_value: float = 0.0


@dataclass
class RLTrainingSession:
    """Complete RL training session"""
    session_id: str
    start_timestamp: datetime
    end_timestamp: Optional[datetime] = None
    training_mode: str = 'experience_replay'
    symbol: str = ''

    training_steps: List[RLTrainingStep] = field(default_factory=list)
    total_steps: int = 0

    average_loss: float = 0.0
    best_loss: float = float('inf')

    profitable_actions: int = 0
    total_actions: int = 0
    profitability_rate: float = 0.0
    session_value: float = 0.0


class RLTrainer:
    """RL trainer with comprehensive experience storage and replay"""

    def __init__(self, agent: RLTradingAgent, device: Optional[str] = None,
                 storage_dir: str = "rl_training_storage"):
        # Fall back to CPU when CUDA is not available
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.agent = agent.to(self.device)

        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        self.optimizer = torch.optim.AdamW(self.agent.parameters(), lr=0.001)
        self.experience_buffer = ProfitWeightedExperienceBuffer()
        self.data_collector = get_training_data_collector()

        self.training_sessions: List[RLTrainingSession] = []
        self.current_session: Optional[RLTrainingSession] = None

        # Discount factor for the Q-learning target
        self.gamma = 0.99

        self.training_stats = {
            'total_sessions': 0,
            'total_steps': 0,
            'total_experiences': 0,
            'profitable_actions': 0,
            'total_actions': 0,
            'average_reward': 0.0
        }

        logger.info(f"RL Trainer initialized with "
                    f"{sum(p.numel() for p in self.agent.parameters()):,} parameters")

    def add_experience(self, state: np.ndarray, action: int, reward: float,
                       next_state: np.ndarray, done: bool,
                       market_context: Dict[str, Any],
                       cnn_predictions: Optional[Dict[str, Any]] = None,
                       confidence_score: float = 0.0) -> Optional[str]:
        """Add an experience to the replay buffer and return its id."""
        try:
            experience_id = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
            experience = RLExperience(
                experience_id=experience_id,
                timestamp=datetime.now(),
                episode_id=market_context.get('episode_id', 'unknown'),
                state=state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
                market_context=market_context,
                cnn_predictions=cnn_predictions,
                confidence_score=confidence_score
            )
            self.experience_buffer.add_experience(experience)
            self.training_stats['total_experiences'] += 1
            return experience_id
        except Exception as e:
            logger.error(f"Error adding experience: {e}")
            return None
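
    # Example usage (illustrative only - the variables and values below are
    # hypothetical, not defined in this module):
    #   exp_id = trainer.add_experience(
    #       state=state_vec, action=2, reward=realized_pnl,
    #       next_state=next_state_vec, done=True,
    #       market_context={'episode_id': 'episode_001', 'symbol': 'BTC/USDT'},
    #       confidence_score=0.8)
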
    def train_on_experiences(self, batch_size: int = 32, num_batches: int = 10) -> Dict[str, Any]:
        """Train on replayed experiences with comprehensive data storage."""
        try:
            session = RLTrainingSession(
                session_id=f"rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                start_timestamp=datetime.now(),
                training_mode='experience_replay'
            )
            self.current_session = session
            self.agent.train()

            total_loss = 0.0

            for batch_idx in range(num_batches):
                experiences = self.experience_buffer.sample_batch(batch_size, True)
                if len(experiences) < batch_size:
                    continue

                # Prepare batch tensors
                states = torch.from_numpy(np.stack([exp.state for exp in experiences])).float().to(self.device)
                actions = torch.tensor([exp.action for exp in experiences], dtype=torch.long, device=self.device)
                rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float32, device=self.device)
                next_states = torch.from_numpy(np.stack([exp.next_state for exp in experiences])).float().to(self.device)
                dones = torch.tensor([exp.done for exp in experiences], dtype=torch.bool, device=self.device)

                # Forward pass
                self.optimizer.zero_grad()
                current_outputs = self.agent(states)
                current_q_values = current_outputs['q_values']

                # Bellman target: y = r + gamma * max_a' Q(s', a') * (1 - done)
                with torch.no_grad():
                    next_outputs = self.agent(next_states)
                    next_q_values = next_outputs['q_values']
                    max_next_q_values = torch.max(next_q_values, dim=1)[0]
                    target_q_values = rewards + self.gamma * max_next_q_values * (~dones).float()

                # TD loss on the Q-values of the actions actually taken
                current_q_values_for_actions = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
                q_loss = F.mse_loss(current_q_values_for_actions, target_q_values)

                # Backward pass
                q_loss.backward()

                # Store gradients (moved to CPU so pickled sessions stay device-independent)
                gradients = {}
                gradient_norms = {}
                for name, param in self.agent.named_parameters():
                    if param.grad is not None:
                        gradients[name] = param.grad.detach().cpu().clone()
                        gradient_norms[name] = param.grad.norm().item()

                torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=1.0)
                self.optimizer.step()

                # Record this training step
                step = RLTrainingStep(
                    step_id=f"{session.session_id}_step_{batch_idx}",
                    timestamp=datetime.now(),
                    batch_experiences=[exp.experience_id for exp in experiences],
                    total_loss=q_loss.item(),
                    q_loss=q_loss.item(),
                    policy_loss=0.0,
                    gradients=gradients,
                    gradient_norms=gradient_norms,
                    batch_size=len(experiences)
                )
                session.training_steps.append(step)
                session.best_loss = min(session.best_loss, q_loss.item())
                total_loss += q_loss.item()

            # Finalize the session; average over the steps that actually ran
            # (skipped batches are not counted)
            session.end_timestamp = datetime.now()
            session.total_steps = len(session.training_steps)
            session.average_loss = (total_loss / session.total_steps) if session.total_steps > 0 else 0.0

            self._save_training_session(session)
            self.training_sessions.append(session)

            self.training_stats['total_sessions'] += 1
            self.training_stats['total_steps'] += session.total_steps

            logger.info(f"RL training session completed: {session.session_id}")
            logger.info(f"Average loss: {session.average_loss:.4f}")

            return {
                'status': 'success',
                'session_id': session.session_id,
                'average_loss': session.average_loss,
                'total_steps': session.total_steps
            }
        except Exception as e:
            logger.error(f"Error in RL training session: {e}")
            return {'status': 'error', 'error': str(e)}
        finally:
            self.current_session = None

    def train_on_profitable_experiences(self, min_profitability: float = 0.1,
                                        max_experiences: int = 1000,
                                        batch_size: int = 32) -> Dict[str, Any]:
        """Train specifically on the most profitable experiences."""
        try:
            profitable_experiences = self.experience_buffer.get_most_profitable_experiences(max_experiences)
            filtered_experiences = [
                exp for exp in profitable_experiences
                if exp.actual_profit is not None and exp.actual_profit >= min_profitability
            ]

            if len(filtered_experiences) < batch_size:
                return {'status': 'insufficient_data',
                        'experiences_found': len(filtered_experiences)}

            logger.info(f"Training on {len(filtered_experiences)} profitable experiences")
            num_batches = len(filtered_experiences) // batch_size

            # Temporarily override buffer sampling so the replay loop only
            # sees the filtered, profitable experiences
            original_sample_method = self.experience_buffer.sample_batch

            def profitable_sample_batch(batch_size, prioritize_profitable=True):
                return random.sample(filtered_experiences,
                                     min(batch_size, len(filtered_experiences)))

            self.experience_buffer.sample_batch = profitable_sample_batch
            try:
                results = self.train_on_experiences(batch_size=batch_size, num_batches=num_batches)
                results['training_mode'] = 'profitable_replay'
                results['experiences_used'] = len(filtered_experiences)
                return results
            finally:
                # Always restore the original sampling method
                self.experience_buffer.sample_batch = original_sample_method
        except Exception as e:
            logger.error(f"Error training on profitable experiences: {e}")
            return {'status': 'error', 'error': str(e)}
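
    # Example usage (illustrative only - the threshold values are hypothetical):
    #   trainer.train_on_profitable_experiences(min_profitability=0.5,
    #                                           max_experiences=500,
    #                                           batch_size=32)
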
    def _save_training_session(self, session: RLTrainingSession):
        """Persist a training session to disk (full pickle plus a JSON summary)."""
        try:
            session_dir = self.storage_dir / 'sessions'
            session_dir.mkdir(parents=True, exist_ok=True)

            session_file = session_dir / f"{session.session_id}.pkl"
            with open(session_file, 'wb') as f:
                pickle.dump(session, f)

            metadata = {
                'session_id': session.session_id,
                'start_timestamp': session.start_timestamp.isoformat(),
                'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
                'training_mode': session.training_mode,
                'total_steps': session.total_steps,
                'average_loss': session.average_loss
            }
            metadata_file = session_dir / f"{session.session_id}_metadata.json"
            with open(metadata_file, 'w') as f:
                json.dump(metadata, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving training session: {e}")

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get comprehensive training statistics."""
        stats = self.training_stats.copy()
        if self.training_sessions:
            recent_sessions = sorted(self.training_sessions,
                                     key=lambda x: x.start_timestamp, reverse=True)[:10]
            stats['recent_sessions'] = [
                {
                    'session_id': s.session_id,
                    'timestamp': s.start_timestamp.isoformat(),
                    'mode': s.training_mode,
                    'average_loss': s.average_loss
                }
                for s in recent_sessions
            ]
        return stats


# Global instance
rl_trainer = None


def get_rl_trainer(agent: RLTradingAgent = None) -> RLTrainer:
    """Get the global RL trainer instance, creating it on first use."""
    global rl_trainer
    if rl_trainer is None:
        if agent is None:
            agent = RLTradingAgent()
        rl_trainer = RLTrainer(agent)
    return rl_trainer
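
# --- Illustrative end-to-end sketch (not part of the pipeline) ---------------
# Wires the pieces together on synthetic data: build an agent and trainer on
# CPU, record a few random "trades", then run one experience-replay pass.
# The state vectors, rewards, market context and storage directory below are
# placeholders; it also assumes the training_data_collector dependency used by
# RLTrainer is importable. This helper is hypothetical and never called.
def _demo_training_pipeline() -> None:
    demo_agent = RLTradingAgent(state_dim=2000, action_dim=3)
    trainer = RLTrainer(demo_agent, device='cpu',
                        storage_dir='rl_training_storage_demo')

    for i in range(64):
        pnl = random.uniform(-1.0, 1.0)  # synthetic profit/loss used as reward
        trainer.add_experience(
            state=np.random.randn(2000).astype(np.float32),
            action=random.randint(0, 2),
            reward=pnl,
            next_state=np.random.randn(2000).astype(np.float32),
            done=(i % 16 == 15),
            market_context={'episode_id': f'demo_episode_{i // 16}'},
            confidence_score=0.5,
        )

    results = trainer.train_on_experiences(batch_size=16, num_batches=4)
    logger.info(f"Demo training results: {results}")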