#!/usr/bin/env python3
"""
Training Integration - Handles cold start training and model learning integration

Manages:
- Cold start training triggers from trade outcomes
- Reward calculation based on P&L
- Integration with DQN, CNN, and COB RL models
- Training session management
"""

import logging
from datetime import datetime
from typing import Dict, List, Any, Optional

import numpy as np

logger = logging.getLogger(__name__)


class TrainingIntegration:
    """Manages training integration for cold start learning"""

    def __init__(self, orchestrator=None):
        self.orchestrator = orchestrator
        self.training_sessions = {}
        self.min_confidence_threshold = 0.3

        logger.info("TrainingIntegration initialized")

    def trigger_cold_start_training(self, trade_record: Dict[str, Any], case_id: Optional[str] = None) -> bool:
        """Trigger cold start training when trades close with known outcomes"""
        try:
            if not trade_record.get('model_inputs_at_entry'):
                logger.warning("No model inputs captured for training - skipping")
                return False

            pnl = trade_record.get('pnl', 0)
            confidence = trade_record.get('confidence', 0)

            logger.info(f"Triggering cold start training for trade with P&L: ${pnl:.4f}")

            # Calculate training reward based on P&L and confidence
            reward = self._calculate_training_reward(pnl, confidence)

            # Train DQN on trade outcome
            dqn_success = self._train_dqn_on_trade_outcome(trade_record, reward)

            # Train CNN if available (placeholder for now)
            cnn_success = self._train_cnn_on_trade_outcome(trade_record, reward)

            # Train COB RL if available (placeholder for now)
            cob_success = self._train_cob_rl_on_trade_outcome(trade_record, reward)

            # Log training results
            training_success = any([dqn_success, cnn_success, cob_success])
            if training_success:
                logger.info(f"Cold start training completed - DQN: {dqn_success}, CNN: {cnn_success}, COB: {cob_success}")
            else:
                logger.warning("Cold start training failed for all models")

            return training_success

        except Exception as e:
            logger.error(f"Error in cold start training: {e}")
            return False

    def _calculate_training_reward(self, pnl: float, confidence: float) -> float:
        """Calculate training reward based on P&L and confidence"""
        try:
            # Base reward is proportional to P&L
            base_reward = pnl

            # Adjust for confidence - penalize high confidence wrong predictions more
            if pnl < 0 and confidence > 0.7:
                # High confidence loss - significant negative reward
                confidence_adjustment = -confidence * 2
            elif pnl > 0 and confidence > 0.7:
                # High confidence gain - boost reward
                confidence_adjustment = confidence * 1.5
            else:
                # Low confidence - minimal adjustment
                confidence_adjustment = 0

            final_reward = base_reward + confidence_adjustment

            # Normalize to [-1, 1] range for training stability
            normalized_reward = np.tanh(final_reward / 10.0)

            logger.debug(f"Training reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")

            return float(normalized_reward)

        except Exception as e:
            logger.error(f"Error calculating training reward: {e}")
            return 0.0
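    # Worked example for _calculate_training_reward (illustrative numbers only):
    #   a losing trade with pnl = -0.50 at confidence = 0.80 gets
    #   base_reward = -0.50, confidence_adjustment = -0.80 * 2 = -1.60,
    #   final_reward = -2.10, normalized reward = tanh(-2.10 / 10) ≈ -0.207;
    #   a winning trade with pnl = 1.00 at confidence = 0.90 gets
    #   base_reward = 1.00, confidence_adjustment = 0.90 * 1.5 = 1.35,
    #   final_reward = 2.35, normalized reward = tanh(2.35 / 10) ≈ 0.231.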
    def _train_dqn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train DQN agent on trade outcome"""
        try:
            if not self.orchestrator:
                logger.warning("No orchestrator available for DQN training")
                return False

            # Get DQN agent
            if not hasattr(self.orchestrator, 'dqn_agent') or not self.orchestrator.dqn_agent:
                logger.warning("DQN agent not available for training")
                return False

            # Extract DQN state from model inputs
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            dqn_state = model_inputs.get('dqn_state', {}).get('state_vector')

            # Explicit None/empty check - avoids ambiguous truthiness if the state is a numpy array
            if dqn_state is None or len(dqn_state) == 0:
                logger.warning("No DQN state available for training")
                return False

            # Convert action to DQN action index
            action = trade_record.get('side', 'HOLD').upper()
            action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
            action_idx = action_map.get(action, 2)

            # Create next state (simplified - could be current market state)
            next_state = dqn_state  # Placeholder - should be state after trade

            # Store experience in DQN memory
            dqn_agent = self.orchestrator.dqn_agent
            if hasattr(dqn_agent, 'store_experience'):
                dqn_agent.store_experience(
                    state=np.array(dqn_state),
                    action=action_idx,
                    reward=reward,
                    next_state=np.array(next_state),
                    done=True  # Trade is complete
                )

                # Trigger training if enough experiences
                if hasattr(dqn_agent, 'replay') and len(getattr(dqn_agent, 'memory', [])) > 32:
                    dqn_agent.replay(batch_size=32)
                    logger.info("DQN training step completed")

                return True
            else:
                logger.warning("DQN agent doesn't support experience storage")
                return False

        except Exception as e:
            logger.error(f"Error training DQN on trade outcome: {e}")
            return False

    def _train_cnn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train CNN on trade outcome (placeholder)"""
        try:
            if not self.orchestrator:
                return False

            # Check if CNN is available
            if not hasattr(self.orchestrator, 'williams_cnn') or not self.orchestrator.williams_cnn:
                logger.debug("CNN not available for training")
                return False

            # Get CNN features from model inputs
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            cnn_features = model_inputs.get('cnn_features')
            cnn_predictions = model_inputs.get('cnn_predictions')

            # Explicit None checks - avoids ambiguous truthiness if features arrive as numpy arrays
            if cnn_features is None or cnn_predictions is None:
                logger.debug("No CNN features available for training")
                return False

            # CNN training would go here - requires more specific implementation
            # For now, just log that we could train CNN
            logger.debug(f"CNN training opportunity: features={len(cnn_features)}, predictions={len(cnn_predictions)}")

            return True

        except Exception as e:
            logger.debug(f"Error in CNN training: {e}")
            return False

    def _train_cob_rl_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train COB RL on trade outcome (placeholder)"""
        try:
            if not self.orchestrator:
                return False

            # Check if COB integration is available
            if not hasattr(self.orchestrator, 'cob_integration') or not self.orchestrator.cob_integration:
                logger.debug("COB integration not available for training")
                return False

            # Get COB features from model inputs
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            cob_features = model_inputs.get('cob_features')

            if cob_features is None:
                logger.debug("No COB features available for training")
                return False

            # COB RL training would go here - requires more specific implementation
            # For now, just log that we could train COB RL
            logger.debug(f"COB RL training opportunity: features={len(cob_features)}")

            return True

        except Exception as e:
            logger.debug(f"Error in COB RL training: {e}")
            return False
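    # Expected shape of trade_record['model_inputs_at_entry'], inferred from the
    # lookups in the _train_* helpers above. The producer of this dict lives
    # outside this module, so treat this as an assumed example, not a contract:
    #   {
    #       'dqn_state': {'state_vector': [...]},   # flat numeric state for the DQN
    #       'cnn_features': [...],                  # CNN input features captured at entry
    #       'cnn_predictions': [...],               # CNN outputs captured at entry
    #       'cob_features': [...],                  # COB snapshot features at entry
    #   }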
    def get_training_status(self) -> Dict[str, Any]:
        """Get current training integration status"""
        try:
            status = {
                'orchestrator_available': self.orchestrator is not None,
                'training_sessions': len(self.training_sessions),
                'last_update': datetime.now().isoformat()
            }

            if self.orchestrator:
                status['dqn_available'] = hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent is not None
                status['cnn_available'] = hasattr(self.orchestrator, 'williams_cnn') and self.orchestrator.williams_cnn is not None
                status['cob_available'] = hasattr(self.orchestrator, 'cob_integration') and self.orchestrator.cob_integration is not None

            return status

        except Exception as e:
            logger.error(f"Error getting training status: {e}")
            return {'error': str(e)}

    def start_training_session(self, session_name: str, config: Optional[Dict[str, Any]] = None) -> str:
        """Start a new training session"""
        try:
            session_id = f"{session_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

            session_data = {
                'session_id': session_id,
                'session_name': session_name,
                'start_time': datetime.now().isoformat(),
                'config': config or {},
                'trades_processed': 0,
                'successful_trainings': 0,
                'failed_trainings': 0
            }

            self.training_sessions[session_id] = session_data
            logger.info(f"Started training session: {session_id}")

            return session_id

        except Exception as e:
            logger.error(f"Error starting training session: {e}")
            return ""

    def end_training_session(self, session_id: str) -> Dict[str, Any]:
        """End a training session and return summary"""
        try:
            if session_id not in self.training_sessions:
                logger.warning(f"Training session not found: {session_id}")
                return {}

            session_data = self.training_sessions[session_id]
            session_data['end_time'] = datetime.now().isoformat()

            # Calculate session duration
            start_time = datetime.fromisoformat(session_data['start_time'])
            end_time = datetime.fromisoformat(session_data['end_time'])
            duration = (end_time - start_time).total_seconds()
            session_data['duration_seconds'] = duration

            # Calculate success rate
            total_attempts = session_data['successful_trainings'] + session_data['failed_trainings']
            session_data['success_rate'] = session_data['successful_trainings'] / total_attempts if total_attempts > 0 else 0

            logger.info(f"Ended training session: {session_id}")
            logger.info(f"  Duration: {duration:.1f}s")
            logger.info(f"  Trades processed: {session_data['trades_processed']}")
            logger.info(f"  Success rate: {session_data['success_rate']:.2%}")

            # Remove from active sessions
            completed_session = self.training_sessions.pop(session_id)

            return completed_session

        except Exception as e:
            logger.error(f"Error ending training session: {e}")
            return {}

    def update_session_stats(self, session_id: str, trade_processed: bool = True, training_success: bool = False):
        """Update training session statistics"""
        try:
            if session_id not in self.training_sessions:
                return

            session = self.training_sessions[session_id]

            if trade_processed:
                session['trades_processed'] += 1

            if training_success:
                session['successful_trainings'] += 1
            else:
                session['failed_trainings'] += 1

        except Exception as e:
            logger.error(f"Error updating session stats: {e}")
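
# --- Illustrative usage sketch (assumption: not part of the original module) ---
# A minimal, hedged example of how this class is expected to be driven. The mock
# trade record below is hypothetical: its field names mirror the lookups in
# trigger_cold_start_training and the _train_* helpers, but the real producer of
# these records lives elsewhere. With no orchestrator attached, the per-model
# training calls return False and only the session bookkeeping is exercised.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    integration = TrainingIntegration(orchestrator=None)
    session_id = integration.start_training_session("cold_start_demo")

    mock_trade = {
        'side': 'BUY',
        'pnl': -0.42,
        'confidence': 0.81,
        'model_inputs_at_entry': {
            'dqn_state': {'state_vector': [0.1, -0.3, 0.05, 0.7]},
            'cnn_features': [0.2, 0.4, 0.6],
            'cnn_predictions': [0.55, 0.45],
            'cob_features': [0.9, 0.1],
        },
    }

    success = integration.trigger_cold_start_training(mock_trade)
    integration.update_session_stats(session_id, trade_processed=True, training_success=success)

    summary = integration.end_training_session(session_id)
    logger.info(f"Demo session summary: {summary}")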