#!/usr/bin/env python3
"""
Training Integration - Handles cold start training and model learning integration

Manages:
- Cold start training triggers from trade outcomes
- Reward calculation based on P&L
- Integration with DQN, CNN, and COB RL models
- Training session management
"""

import logging
import threading
import time
from datetime import datetime
from typing import Dict, List, Any, Optional

import numpy as np

from utils.reward_calculator import RewardCalculator

logger = logging.getLogger(__name__)


class TrainingIntegration:
    """Manages training integration for cold start learning"""

    def __init__(self, orchestrator=None):
        self.orchestrator = orchestrator
        self.reward_calculator = RewardCalculator()
        self.training_sessions = {}
        self.min_confidence_threshold = 0.15  # Lowered from 0.3 for more aggressive training
        self.training_active = False
        self.trainer_thread = None
        self.stop_event = threading.Event()
        self.training_lock = threading.Lock()
        self.last_training_time = 0.0 if orchestrator is None else time.time()
        self.training_interval = 300  # Seconds between training sessions (5 minutes)
        self.min_data_points = 100  # Minimum data points required to trigger training

        logger.info("TrainingIntegration initialized")

    def trigger_cold_start_training(self, trade_record: Dict[str, Any], case_id: str = None) -> bool:
        """Trigger cold start training when trades close with known outcomes"""
        try:
            if not trade_record.get('model_inputs_at_entry'):
                logger.warning("No model inputs captured for training - skipping")
                return False

            pnl = trade_record.get('pnl', 0)
            confidence = trade_record.get('confidence', 0)
            logger.info(f"Triggering cold start training for trade with P&L: ${pnl:.4f}")

            # Calculate training reward based on P&L and confidence
            reward = self._calculate_training_reward(pnl, confidence)

            # Train each model that is available; the CNN and COB RL paths are
            # best-effort and return False when the model or its features are missing
            dqn_success = self._train_dqn_on_trade_outcome(trade_record, reward)
            cnn_success = self._train_cnn_on_trade_outcome(trade_record, reward)
            cob_success = self._train_cob_rl_on_trade_outcome(trade_record, reward)

            # Log training results
            training_success = any([dqn_success, cnn_success, cob_success])
            if training_success:
                logger.info(f"Cold start training completed - DQN: {dqn_success}, CNN: {cnn_success}, COB: {cob_success}")
            else:
                logger.warning("Cold start training failed for all models")

            return training_success

        except Exception as e:
            logger.error(f"Error in cold start training: {e}")
            return False

    def _calculate_training_reward(self, pnl: float, confidence: float) -> float:
        """Calculate training reward based on P&L and confidence"""
        try:
            # Base reward is proportional to P&L
            base_reward = pnl

            # Adjust for confidence - penalize high-confidence wrong predictions more
            if pnl < 0 and confidence > 0.7:
                # High confidence loss - significant negative reward
                confidence_adjustment = -confidence * 2
            elif pnl > 0 and confidence > 0.7:
                # High confidence gain - boost reward
                confidence_adjustment = confidence * 1.5
            else:
                # Low confidence - minimal adjustment
                confidence_adjustment = 0

            final_reward = base_reward + confidence_adjustment

            # Normalize to [-1, 1] range for training stability
            normalized_reward = np.tanh(final_reward / 10.0)

            logger.debug(f"Training reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")
            return float(normalized_reward)

        except Exception as e:
            logger.error(f"Error calculating training reward: {e}")
            return 0.0
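    # Worked example of the reward shaping above (values are illustrative only):
    # a losing trade with pnl = -5.0 entered at confidence = 0.8 gets
    # base_reward = -5.0, confidence_adjustment = -0.8 * 2 = -1.6, so
    # normalized_reward = tanh(-6.6 / 10.0) ≈ -0.578. High-confidence losses are
    # therefore pushed further toward -1 than the raw P&L alone would be.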
    def _train_dqn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train DQN agent on trade outcome"""
        try:
            if not self.orchestrator:
                logger.warning("No orchestrator available for DQN training")
                return False

            # Get DQN agent
            if not hasattr(self.orchestrator, 'dqn_agent') or not self.orchestrator.dqn_agent:
                logger.warning("DQN agent not available for training")
                return False

            # Extract DQN state from model inputs captured at trade entry
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            dqn_state = model_inputs.get('dqn_state', {}).get('state_vector')
            if dqn_state is None or len(dqn_state) == 0:
                logger.warning("No DQN state available for training")
                return False

            # Convert trade side to DQN action index
            action = trade_record.get('side', 'HOLD').upper()
            action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
            action_idx = action_map.get(action, 2)

            # Create next state (simplified - could be the current market state)
            next_state = dqn_state  # Placeholder - should be the state after the trade

            # Store experience in DQN memory
            dqn_agent = self.orchestrator.dqn_agent
            if hasattr(dqn_agent, 'store_experience'):
                dqn_agent.store_experience(
                    state=np.array(dqn_state),
                    action=action_idx,
                    reward=reward,
                    next_state=np.array(next_state),
                    done=True  # Trade is complete
                )

                # Trigger a replay step once enough experiences have accumulated
                if hasattr(dqn_agent, 'replay') and len(getattr(dqn_agent, 'memory', [])) > 32:
                    dqn_agent.replay(batch_size=32)
                    logger.info("DQN training step completed")

                return True
            else:
                logger.warning("DQN agent doesn't support experience storage")
                return False

        except Exception as e:
            logger.error(f"Error training DQN on trade outcome: {e}")
            return False
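    # NOTE: the experience above reuses the entry state as next_state and marks
    # the transition as done, so each closed trade is learned as a single-step
    # (bandit-style) update rather than a full market transition. Capturing the
    # post-trade market state would give a more faithful next_state.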
    def _train_cnn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train CNN on trade outcome with real implementation"""
        try:
            if not self.orchestrator:
                return False

            # Check if a CNN model is available
            cnn_model = None
            if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                cnn_model = self.orchestrator.cnn_model
            elif hasattr(self.orchestrator, 'williams_cnn') and self.orchestrator.williams_cnn:
                cnn_model = self.orchestrator.williams_cnn

            if not cnn_model:
                logger.debug("CNN not available for training")
                return False

            # Get CNN features from model inputs captured at trade entry
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            cnn_features = model_inputs.get('cnn_features')
            if cnn_features is None or len(cnn_features) == 0:
                logger.debug("No CNN features available for training")
                return False

            # Determine the classification target from the trade outcome:
            # profitable trades reinforce the action taken, losing trades are
            # labeled with the opposite action
            pnl = trade_record.get('pnl', 0)
            action = trade_record.get('side', 'HOLD').upper()

            if pnl > 0:
                if action == 'BUY':
                    target = 0  # Successful BUY
                elif action == 'SELL':
                    target = 1  # Successful SELL
                else:
                    target = 2  # HOLD
            else:
                if action == 'BUY':
                    target = 1  # Should have been SELL
                elif action == 'SELL':
                    target = 0  # Should have been BUY
                else:
                    target = 2  # HOLD

            # Lazily attach an optimizer if the model doesn't have one yet
            import torch
            if not hasattr(cnn_model, 'optimizer'):
                cnn_model.optimizer = torch.optim.Adam(cnn_model.parameters(), lr=0.001)

            # Perform the actual CNN training step
            try:
                # Prepare a fixed-size feature vector (pad or truncate to 50)
                features = np.array(cnn_features, dtype=np.float32)
                if len(features) < 50:
                    padded_features = np.zeros(50, dtype=np.float32)
                    padded_features[:len(features)] = features
                    features = padded_features
                elif len(features) > 50:
                    features = features[:50]

                # Put tensors on the same device as the model
                model_device = next(cnn_model.parameters()).device
                features_tensor = torch.FloatTensor(features).unsqueeze(0).to(model_device)
                target_tensor = torch.LongTensor([target]).to(model_device)

                # Training step
                cnn_model.train()
                cnn_model.optimizer.zero_grad()
                outputs = cnn_model(features_tensor)

                # Handle different output formats
                if isinstance(outputs, dict):
                    if 'main_output' in outputs:
                        logits = outputs['main_output']
                    elif 'action_logits' in outputs:
                        logits = outputs['action_logits']
                    else:
                        logits = list(outputs.values())[0]
                else:
                    logits = outputs

                # Calculate cross-entropy loss, weighted by reward magnitude
                loss_fn = torch.nn.CrossEntropyLoss()
                loss = loss_fn(logits, target_tensor)
                weighted_loss = loss * abs(reward)

                # Backward pass and optimizer step
                weighted_loss.backward()
                cnn_model.optimizer.step()

                logger.info(f"CNN trained on trade outcome: P&L=${pnl:.2f}, loss={loss.item():.4f}")
                return True

            except Exception as e:
                logger.error(f"Error in CNN training step: {e}")
                return False

        except Exception as e:
            logger.error(f"Error in CNN training: {e}")
            return False
    def _train_cob_rl_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train COB RL agent on trade outcome with real implementation"""
        try:
            if not self.orchestrator:
                return False

            # Check if a COB RL agent is available
            cob_rl_agent = None
            if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                cob_rl_agent = self.orchestrator.rl_agent
            elif hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
                cob_rl_agent = self.orchestrator.cob_rl_agent

            if not cob_rl_agent:
                logger.debug("COB RL agent not available for training")
                return False

            # Get COB features from model inputs captured at trade entry
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            cob_features = model_inputs.get('cob_features')
            if cob_features is None or len(cob_features) == 0:
                logger.debug("No COB features available for training")
                return False

            # Build the state vector from COB features
            state_features = np.array(cob_features, dtype=np.float32)

            # Pad or truncate to the agent's expected state size
            if hasattr(cob_rl_agent, 'state_shape'):
                expected_size = cob_rl_agent.state_shape if isinstance(cob_rl_agent.state_shape, int) else cob_rl_agent.state_shape[0]
            else:
                expected_size = 100  # Default size

            if len(state_features) < expected_size:
                padded_features = np.zeros(expected_size, dtype=np.float32)
                padded_features[:len(state_features)] = state_features
                state_features = padded_features
            elif len(state_features) > expected_size:
                state_features = state_features[:expected_size]

            state = state_features

            # Determine action index from the trade record
            action_str = trade_record.get('side', 'HOLD').upper()
            if action_str == 'BUY':
                action = 0
            elif action_str == 'SELL':
                action = 1
            else:
                action = 2  # HOLD

            # Create next state (same as the current state for simplicity)
            next_state = state.copy()

            # Use scaled P&L as the reward signal
            pnl = trade_record.get('pnl', 0)
            actual_reward = float(pnl * 100)

            # Store the experience in the agent's memory
            if hasattr(cob_rl_agent, 'remember'):
                cob_rl_agent.remember(state, action, actual_reward, next_state, done=True)
            elif hasattr(cob_rl_agent, 'store_experience'):
                cob_rl_agent.store_experience(state, action, actual_reward, next_state, done=True)

            # Perform a training step if the agent supports replay and has enough samples
            if hasattr(cob_rl_agent, 'replay') and hasattr(cob_rl_agent, 'memory'):
                if len(cob_rl_agent.memory) > 32:
                    loss = cob_rl_agent.replay(batch_size=min(32, len(cob_rl_agent.memory)))
                    if loss is not None:
                        logger.info(f"COB RL trained on trade outcome: P&L=${pnl:.2f}, loss={loss:.4f}")
                        return True

            logger.debug(f"COB RL experience stored: P&L=${pnl:.2f}, reward={actual_reward:.2f}")
            return True

        except Exception as e:
            logger.error(f"Error in COB RL training: {e}")
            return False

    def get_training_status(self) -> Dict[str, Any]:
        """Get current training status"""
        try:
            return {
                'active': self.training_active,
                'last_training_time': self.last_training_time,
                'training_sessions': self.training_sessions if self.training_sessions else {}
            }
        except Exception as e:
            logger.error(f"Error getting training status: {e}")
            return {}

    def start_training_session(self, session_name: str, config: Dict[str, Any] = None) -> str:
        """Start a new training session"""
        try:
            session_id = f"{session_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            self.training_sessions[session_id] = {
                'name': session_name,
                'start_time': datetime.now(),
                'config': config if config else {},
                'trades_processed': 0,
                'training_attempts': 0,
                'successful_trainings': 0,
                'failed_trainings': 0  # Required by end_training_session and update_session_stats
            }
            logger.info(f"Started training session: {session_id}")
            return session_id
        except Exception as e:
            logger.error(f"Error starting training session: {e}")
            return ""

    def end_training_session(self, session_id: str) -> Dict[str, Any]:
        """End a training session and return its summary"""
        try:
            if session_id not in self.training_sessions:
                logger.warning(f"Training session not found: {session_id}")
                return {}

            session_data = self.training_sessions[session_id]

            # Calculate session duration (start_time was stored as a datetime)
            start_time = session_data['start_time']
            end_time = datetime.now()
            session_data['start_time'] = start_time.isoformat()
            session_data['end_time'] = end_time.isoformat()
            session_data['duration_seconds'] = (end_time - start_time).total_seconds()

            # Calculate success rate
            total_attempts = session_data['successful_trainings'] + session_data['failed_trainings']
            session_data['success_rate'] = session_data['successful_trainings'] / total_attempts if total_attempts > 0 else 0

            logger.info(f"Ended training session: {session_id}")
            logger.info(f"  Duration: {session_data['duration_seconds']:.1f}s")
            logger.info(f"  Trades processed: {session_data['trades_processed']}")
            logger.info(f"  Success rate: {session_data['success_rate']:.2%}")

            # Remove from active sessions and return the completed summary
            completed_session = self.training_sessions.pop(session_id)
            return completed_session

        except Exception as e:
            logger.error(f"Error ending training session: {e}")
            return {}

    def update_session_stats(self, session_id: str, trade_processed: bool = True, training_success: bool = False):
        """Update training session statistics"""
        try:
            if session_id not in self.training_sessions:
                return

            session = self.training_sessions[session_id]
            if trade_processed:
                session['trades_processed'] += 1

            # Every call records one training outcome, so count the attempt here
            session['training_attempts'] += 1
            if training_success:
                session['successful_trainings'] += 1
            else:
                session['failed_trainings'] += 1

        except Exception as e:
            logger.error(f"Error updating session stats: {e}")
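
# Minimal usage sketch (illustrative only): it assumes no live orchestrator is
# wired in, so the per-model training paths simply return False and the example
# mainly exercises the session-management API. The trade_record fields mirror
# the keys this module reads (pnl, confidence, side, model_inputs_at_entry);
# all values shown are hypothetical.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    integration = TrainingIntegration(orchestrator=None)
    session_id = integration.start_training_session("cold_start_demo")

    sample_trade = {
        'pnl': -2.5,
        'confidence': 0.8,
        'side': 'BUY',
        # A real record would carry dqn_state/cnn_features/cob_features
        # captured at entry; this hypothetical state vector is zeros.
        'model_inputs_at_entry': {'dqn_state': {'state_vector': [0.0] * 10}},
    }

    trained = integration.trigger_cold_start_training(sample_trade)
    integration.update_session_stats(session_id, trade_processed=True, training_success=trained)

    summary = integration.end_training_session(session_id)
    logger.info(f"Demo session summary: {summary}")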