""" Retrospective Training System This module implements a retrospective training system that: 1. Triggers training when trades close with known P&L outcomes 2. Uses captured model inputs from trade entry to train models 3. Optimizes for profit by learning from profitable vs unprofitable patterns 4. Supports simultaneous inference and training without weight reloading 5. Implements reinforcement learning with immediate reward feedback """ import logging import threading import time import queue from datetime import datetime from typing import Dict, List, Optional, Any, Callable from dataclasses import dataclass import numpy as np from collections import deque logger = logging.getLogger(__name__) @dataclass class TrainingCase: """Represents a completed trade case for retrospective training""" case_id: str symbol: str action: str # 'BUY' or 'SELL' entry_price: float exit_price: float entry_time: datetime exit_time: datetime pnl: float fees: float confidence: float model_inputs: Dict[str, Any] market_state: Dict[str, Any] outcome_label: int # 1 for profit, 0 for loss, 2 for breakeven reward_signal: float # Scaled reward for RL training leverage: float = 1.0 class RetrospectiveTrainer: """Retrospective training system for real-time model optimization""" def __init__(self, orchestrator=None, config: Optional[Dict[str, Any]] = None): """Initialize the retrospective trainer""" self.orchestrator = orchestrator self.config = config or {} # Training configuration self.batch_size = self.config.get('batch_size', 32) self.min_cases_for_training = self.config.get('min_cases_for_training', 5) self.profit_threshold = self.config.get('profit_threshold', 0.0) self.training_frequency = self.config.get('training_frequency_seconds', 120) # 2 minutes self.max_training_cases = self.config.get('max_training_cases', 1000) # Training state self.training_queue = queue.Queue() self.completed_cases = deque(maxlen=self.max_training_cases) self.training_stats = { 'total_cases': 0, 'profitable_cases': 0, 'loss_cases': 0, 'breakeven_cases': 0, 'avg_profit': 0.0, 'last_training_time': datetime.now(), 'training_sessions': 0, 'model_updates': 0 } # Threading self.training_thread = None self.is_training_active = False self.training_lock = threading.Lock() logger.info("RetrospectiveTrainer initialized") logger.info(f"Configuration: batch_size={self.batch_size}, " f"min_cases={self.min_cases_for_training}, " f"training_freq={self.training_frequency}s") def add_completed_trade(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> bool: """Add a completed trade for retrospective training""" try: # Create training case from trade record case = self._create_training_case(trade_record, model_inputs) if case is None: return False # Add to completed cases self.completed_cases.append(case) self.training_queue.put(case) # Update statistics self.training_stats['total_cases'] += 1 if case.outcome_label == 1: # Profit self.training_stats['profitable_cases'] += 1 elif case.outcome_label == 0: # Loss self.training_stats['loss_cases'] += 1 else: # Breakeven self.training_stats['breakeven_cases'] += 1 # Calculate running average profit total_pnl = sum(c.pnl for c in self.completed_cases) self.training_stats['avg_profit'] = total_pnl / len(self.completed_cases) logger.info(f"RETROSPECTIVE: Added training case {case.case_id} " f"(P&L: ${case.pnl:.3f}, Label: {case.outcome_label})") # Trigger training if we have enough cases self._maybe_trigger_training() return True except Exception as e: logger.error(f"Error adding completed 
            return False

    def _create_training_case(self, trade_record: Dict[str, Any],
                              model_inputs: Dict[str, Any]) -> Optional[TrainingCase]:
        """Create a training case from trade record and model inputs"""
        try:
            # Extract trade information
            symbol = trade_record.get('symbol', 'UNKNOWN')
            side = trade_record.get('side', 'UNKNOWN')
            pnl = trade_record.get('pnl', 0.0)
            fees = trade_record.get('fees', 0.0)
            confidence = trade_record.get('confidence', 0.0)

            # Calculate net P&L after fees
            net_pnl = pnl - fees

            # Determine outcome label and reward signal
            if net_pnl > self.profit_threshold:
                outcome_label = 1  # Profitable
                # Scale reward by profit magnitude and confidence
                reward_signal = min(10.0, net_pnl * confidence * 10)  # Amplify for training
            elif net_pnl < -self.profit_threshold:
                outcome_label = 0  # Loss
                # Negative reward scaled by loss magnitude
                reward_signal = max(-10.0, net_pnl * confidence * 10)  # Negative reward
            else:
                outcome_label = 2  # Breakeven
                reward_signal = 0.0

            # Create case ID
            timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S')
            case_id = f"retro_{timestamp_str}_{symbol.replace('/', '')}_{side}_pnl_{abs(net_pnl):.3f}".replace('.', 'p')

            # Create training case
            case = TrainingCase(
                case_id=case_id,
                symbol=symbol,
                action=side,
                entry_price=trade_record.get('entry_price', 0.0),
                exit_price=trade_record.get('exit_price', 0.0),
                entry_time=trade_record.get('entry_time', datetime.now()),
                exit_time=trade_record.get('exit_time', datetime.now()),
                pnl=net_pnl,
                fees=fees,
                confidence=confidence,
                model_inputs=model_inputs,
                market_state=model_inputs.get('market_state', {}),
                outcome_label=outcome_label,
                reward_signal=reward_signal,
                leverage=trade_record.get('leverage', 1.0)
            )

            return case

        except Exception as e:
            logger.error(f"Error creating training case: {e}")
            return None

    def _maybe_trigger_training(self):
        """Check if we should trigger a training session"""
        try:
            # Check if we have enough cases
            if len(self.completed_cases) < self.min_cases_for_training:
                return

            # Check if enough time has passed since last training
            time_since_last = (datetime.now() - self.training_stats['last_training_time']).total_seconds()
            if time_since_last < self.training_frequency:
                return

            # Check if training thread is not already running
            if self.is_training_active:
                logger.debug("Training already in progress, skipping trigger")
                return

            # Start training in background thread
            self._start_training_session()

        except Exception as e:
            logger.error(f"Error checking training trigger: {e}")

    def _start_training_session(self):
        """Start a training session in a background thread"""
        try:
            if self.training_thread and self.training_thread.is_alive():
                logger.debug("Training thread already running")
                return

            self.training_thread = threading.Thread(
                target=self._run_training_session,
                daemon=True,
                name="RetrospectiveTrainer"
            )
            self.training_thread.start()
            logger.info("RETROSPECTIVE: Started training session")

        except Exception as e:
            logger.error(f"Error starting training session: {e}")

    def _run_training_session(self):
        """Run a complete training session"""
        try:
            with self.training_lock:
                self.is_training_active = True
                start_time = time.time()

                logger.info(f"RETROSPECTIVE: Training with {len(self.completed_cases)} cases")

                # Train models if orchestrator available
                training_results = {}
                if self.orchestrator:
                    training_results = self._train_models()

                # Update statistics
                self.training_stats['last_training_time'] = datetime.now()
                self.training_stats['training_sessions'] += 1
                self.training_stats['model_updates'] += len(training_results)

                elapsed_time = time.time() - start_time
                logger.info(f"RETROSPECTIVE: Training completed in {elapsed_time:.2f}s - {training_results}")

        except Exception as e:
            logger.error(f"Error in retrospective training session: {e}")
            import traceback
            logger.error(traceback.format_exc())
        finally:
            self.is_training_active = False

    def _train_models(self) -> Dict[str, Any]:
        """Train available models using retrospective data"""
        results = {}
        try:
            # Prepare training data
            profitable_cases = [c for c in self.completed_cases if c.outcome_label == 1]
            loss_cases = [c for c in self.completed_cases if c.outcome_label == 0]

            if len(profitable_cases) == 0 and len(loss_cases) == 0:
                return {'error': 'No labeled cases for training'}

            logger.info(f"RETROSPECTIVE: Training data - Profitable: {len(profitable_cases)}, Loss: {len(loss_cases)}")

            # Train DQN agent if available
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                try:
                    dqn_result = self._train_dqn_retrospective()
                    results['dqn'] = dqn_result
                    logger.info(f"RETROSPECTIVE: DQN training result: {dqn_result}")
                except Exception as e:
                    logger.warning(f"DQN retrospective training failed: {e}")
                    results['dqn'] = {'error': str(e)}

            # Train other models
            if self.orchestrator and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
                try:
                    # Update extrema trainer with retrospective feedback
                    extrema_feedback = self._create_extrema_feedback()
                    if extrema_feedback:
                        results['extrema'] = {'feedback_cases': len(extrema_feedback)}
                        logger.info(f"RETROSPECTIVE: Extrema feedback provided for {len(extrema_feedback)} cases")
                except Exception as e:
                    logger.warning(f"Extrema retrospective training failed: {e}")

            return results

        except Exception as e:
            logger.error(f"Error training models retrospectively: {e}")
            return {'error': str(e)}

    def _train_dqn_retrospective(self) -> Dict[str, Any]:
        """Train DQN agent using retrospective experience replay"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
                return {'error': 'DQN agent not available'}

            dqn_agent = self.orchestrator.rl_agent
            experiences_added = 0

            # Add retrospective experiences to DQN replay buffer
            for case in self.completed_cases:
                try:
                    # Extract state from model inputs
                    state = self._extract_state_vector(case.model_inputs)
                    if state is None:
                        continue

                    # Action mapping: BUY=0, SELL=1
                    action = 0 if case.action == 'BUY' else 1

                    # Use reward signal as immediate reward
                    reward = case.reward_signal

                    # The trade is closed, so store a terminal transition;
                    # a zero vector stands in for the (unused) next state
                    next_state = np.zeros_like(state)
                    done = True

                    # Add experience to DQN replay buffer
                    if hasattr(dqn_agent, 'add_experience'):
                        dqn_agent.add_experience(state, action, reward, next_state, done)
                        experiences_added += 1

                except Exception as e:
                    logger.debug(f"Error adding DQN experience: {e}")
                    continue

            # Train DQN if we have enough experiences
            if experiences_added > 0 and hasattr(dqn_agent, 'train'):
                try:
                    # Perform multiple training steps on retrospective data
                    training_steps = min(10, experiences_added // 4)  # Conservative training
                    for _ in range(training_steps):
                        loss = dqn_agent.train()
                        if loss is None:
                            break

                    return {
                        'experiences_added': experiences_added,
                        'training_steps': training_steps,
                        'method': 'retrospective_experience_replay'
                    }

                except Exception as e:
                    logger.warning(f"DQN training step failed: {e}")
                    return {'experiences_added': experiences_added, 'training_error': str(e)}

            return {'experiences_added': experiences_added, 'training_steps': 0}

        except Exception as e:
logger.error(f"Error in DQN retrospective training: {e}") return {'error': str(e)} def _extract_state_vector(self, model_inputs: Dict[str, Any]) -> Optional[np.ndarray]: """Extract state vector for DQN training from model inputs""" try: # Try to get pre-built RL state if 'dqn_state' in model_inputs: state = model_inputs['dqn_state'] if isinstance(state, dict) and 'state_vector' in state: return np.array(state['state_vector']) # Build state from market features market_state = model_inputs.get('market_state', {}) features = [] # Price features for key in ['current_price', 'price_sma_5', 'price_sma_20', 'price_std_20', 'price_rsi']: features.append(market_state.get(key, 0.0)) # Volume features for key in ['volume_current', 'volume_sma_20', 'volume_ratio']: features.append(market_state.get(key, 0.0)) # Technical indicators indicators = model_inputs.get('technical_indicators', {}) for key in ['sma_10', 'sma_20', 'bb_upper', 'bb_lower', 'bb_position', 'macd', 'volatility']: features.append(indicators.get(key, 0.0)) if len(features) < 5: # Minimum required features return None return np.array(features, dtype=np.float32) except Exception as e: logger.debug(f"Error extracting state vector: {e}") return None def _create_extrema_feedback(self) -> List[Dict[str, Any]]: """Create feedback data for extrema trainer""" feedback = [] try: for case in self.completed_cases: if case.outcome_label in [0, 1]: # Only profit/loss cases feedback_item = { 'symbol': case.symbol, 'action': case.action, 'entry_price': case.entry_price, 'exit_price': case.exit_price, 'was_profitable': case.outcome_label == 1, 'reward_signal': case.reward_signal, 'market_state': case.market_state } feedback.append(feedback_item) return feedback except Exception as e: logger.error(f"Error creating extrema feedback: {e}") return [] def get_training_stats(self) -> Dict[str, Any]: """Get current training statistics""" stats = self.training_stats.copy() stats['total_cases_in_memory'] = len(self.completed_cases) stats['training_queue_size'] = self.training_queue.qsize() stats['is_training_active'] = self.is_training_active # Calculate profit metrics if len(self.completed_cases) > 0: profitable_count = sum(1 for c in self.completed_cases if c.pnl > 0) stats['profit_rate'] = profitable_count / len(self.completed_cases) stats['total_pnl'] = sum(c.pnl for c in self.completed_cases) stats['avg_reward'] = sum(c.reward_signal for c in self.completed_cases) / len(self.completed_cases) return stats def force_training_session(self) -> bool: """Force a training session regardless of timing constraints""" try: if self.is_training_active: logger.warning("Training already in progress") return False if len(self.completed_cases) < 1: logger.warning("No completed cases available for training") return False logger.info("RETROSPECTIVE: Forcing training session") self._start_training_session() return True except Exception as e: logger.error(f"Error forcing training session: {e}") return False def stop(self): """Stop the retrospective trainer""" try: self.is_training_active = False if self.training_thread and self.training_thread.is_alive(): self.training_thread.join(timeout=10) logger.info("RetrospectiveTrainer stopped") except Exception as e: logger.error(f"Error stopping RetrospectiveTrainer: {e}") def create_retrospective_trainer(orchestrator=None, config: Optional[Dict[str, Any]] = None) -> RetrospectiveTrainer: """Factory function to create a RetrospectiveTrainer instance""" return RetrospectiveTrainer(orchestrator=orchestrator, config=config)