""" CNN Training Pipeline with Comprehensive Data Storage and Replay This module implements a robust CNN training pipeline that: 1. Integrates with the comprehensive training data collection system 2. Stores all backpropagation data for gradient replay 3. Enables retraining on most profitable setups 4. Maintains training episode profitability tracking 5. Supports both real-time and batch training modes Key Features: - Integration with TrainingDataCollector for data validation - Gradient and loss storage for each training step - Profitable episode prioritization and replay - Comprehensive training metrics and validation - Real-time pivot point prediction with outcome tracking """ import asyncio import logging import numpy as np import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from datetime import datetime, timedelta from pathlib import Path from typing import Dict, List, Optional, Tuple, Any, Callable from dataclasses import dataclass, field import json import pickle from collections import deque, defaultdict import threading from concurrent.futures import ThreadPoolExecutor from .training_data_collector import ( TrainingDataCollector, TrainingEpisode, ModelInputPackage, get_training_data_collector ) logger = logging.getLogger(__name__) @dataclass class CNNTrainingStep: """Single CNN training step with complete backpropagation data""" step_id: str timestamp: datetime episode_id: str # Input data input_features: torch.Tensor target_labels: torch.Tensor # Forward pass results model_outputs: Dict[str, torch.Tensor] predictions: Dict[str, Any] confidence_scores: torch.Tensor # Loss components total_loss: float pivot_prediction_loss: float confidence_loss: float regularization_loss: float # Backpropagation data gradients: Dict[str, torch.Tensor] # Gradients for each parameter gradient_norms: Dict[str, float] # Gradient norms for monitoring # Model state model_state_dict: Optional[Dict[str, torch.Tensor]] = None optimizer_state: Optional[Dict[str, Any]] = None # Training metadata learning_rate: float = 0.001 batch_size: int = 32 epoch: int = 0 # Profitability tracking actual_profitability: Optional[float] = None prediction_accuracy: Optional[float] = None training_value: float = 0.0 # Value of this training step for replay @dataclass class CNNTrainingSession: """Complete CNN training session with multiple steps""" session_id: str start_timestamp: datetime end_timestamp: Optional[datetime] = None # Session configuration training_mode: str = 'real_time' # 'real_time', 'batch', 'replay' symbol: str = '' # Training steps training_steps: List[CNNTrainingStep] = field(default_factory=list) # Session metrics total_steps: int = 0 average_loss: float = 0.0 best_loss: float = float('inf') convergence_achieved: bool = False # Profitability metrics profitable_predictions: int = 0 total_predictions: int = 0 profitability_rate: float = 0.0 # Session value for replay prioritization session_value: float = 0.0 class CNNPivotPredictor(nn.Module): """CNN model for pivot point prediction with comprehensive output""" def __init__(self, input_channels: int = 10, # Multiple timeframes sequence_length: int = 300, # 300 bars hidden_dim: int = 256, num_pivot_classes: int = 3, # high, low, none dropout_rate: float = 0.2): super(CNNPivotPredictor, self).__init__() self.input_channels = input_channels self.sequence_length = sequence_length self.hidden_dim = hidden_dim # Convolutional layers for pattern extraction self.conv_layers = nn.Sequential( # First conv block nn.Conv1d(input_channels, 64, kernel_size=7, padding=3), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(dropout_rate), # Second conv block nn.Conv1d(64, 128, kernel_size=5, padding=2), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout_rate), # Third conv block nn.Conv1d(128, 256, kernel_size=3, padding=1), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout_rate), ) # LSTM for temporal dependencies self.lstm = nn.LSTM( input_size=256, hidden_size=hidden_dim, num_layers=2, batch_first=True, dropout=dropout_rate, bidirectional=True ) # Attention mechanism self.attention = nn.MultiheadAttention( embed_dim=hidden_dim * 2, # Bidirectional LSTM num_heads=8, dropout=dropout_rate, batch_first=True ) # Output heads self.pivot_classifier = nn.Sequential( nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate), nn.Linear(hidden_dim, num_pivot_classes) ) self.pivot_price_regressor = nn.Sequential( nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate), nn.Linear(hidden_dim, 1) ) self.confidence_head = nn.Sequential( nn.Linear(hidden_dim * 2, hidden_dim // 2), nn.ReLU(), nn.Linear(hidden_dim // 2, 1), nn.Sigmoid() ) # Initialize weights self.apply(self._init_weights) def _init_weights(self, module): """Initialize weights with proper scaling""" if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Conv1d): torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') def forward(self, x): """ Forward pass through CNN pivot predictor Args: x: Input tensor [batch_size, input_channels, sequence_length] Returns: Dict containing predictions and hidden states """ batch_size = x.size(0) # Convolutional feature extraction conv_features = self.conv_layers(x) # [batch, 256, sequence_length] # Prepare for LSTM (transpose to [batch, sequence, features]) lstm_input = conv_features.transpose(1, 2) # [batch, sequence_length, 256] # LSTM processing lstm_output, (hidden, cell) = self.lstm(lstm_input) # [batch, sequence_length, hidden_dim*2] # Attention mechanism attended_output, attention_weights = self.attention( lstm_output, lstm_output, lstm_output ) # Use the last timestep for predictions final_features = attended_output[:, -1, :] # [batch, hidden_dim*2] # Generate predictions pivot_logits = self.pivot_classifier(final_features) pivot_price = self.pivot_price_regressor(final_features) confidence = self.confidence_head(final_features) return { 'pivot_logits': pivot_logits, 'pivot_price': pivot_price, 'confidence': confidence, 'hidden_states': final_features, 'attention_weights': attention_weights, 'conv_features': conv_features, 'lstm_output': lstm_output } class CNNTrainingDataset(Dataset): """Dataset for CNN training with training episodes""" def __init__(self, training_episodes: List[TrainingEpisode]): self.episodes = training_episodes self.valid_episodes = self._validate_episodes() def _validate_episodes(self) -> List[TrainingEpisode]: """Validate and filter episodes for training""" valid = [] for episode in self.episodes: try: # Check if episode has required data if (episode.input_package.cnn_features is not None and episode.actual_outcome.outcome_validated): valid.append(episode) except Exception as e: logger.warning(f"Invalid episode {episode.episode_id}: {e}") logger.info(f"Validated {len(valid)}/{len(self.episodes)} episodes for training") return valid def __len__(self): return len(self.valid_episodes) def __getitem__(self, idx): episode = self.valid_episodes[idx] # Extract features features = torch.from_numpy(episode.input_package.cnn_features).float() # Create labels from actual outcomes pivot_class = self._determine_pivot_class(episode.actual_outcome) pivot_price = episode.actual_outcome.optimal_exit_price confidence_target = episode.actual_outcome.profitability_score return { 'features': features, 'pivot_class': torch.tensor(pivot_class, dtype=torch.long), 'pivot_price': torch.tensor(pivot_price, dtype=torch.float), 'confidence_target': torch.tensor(confidence_target, dtype=torch.float), 'episode_id': episode.episode_id, 'profitability': episode.actual_outcome.profitability_score } def _determine_pivot_class(self, outcome) -> int: """Determine pivot class from outcome""" if outcome.price_change_15m > 0.5: # Significant upward movement return 0 # High pivot elif outcome.price_change_15m < -0.5: # Significant downward movement return 1 # Low pivot else: return 2 # No significant pivot class CNNTrainer: """CNN trainer with comprehensive data storage and replay capabilities""" def __init__(self, model: CNNPivotPredictor, device: str = 'cuda', learning_rate: float = 0.001, storage_dir: str = "cnn_training_storage"): self.model = model.to(device) self.device = device self.learning_rate = learning_rate # Storage self.storage_dir = Path(storage_dir) self.storage_dir.mkdir(parents=True, exist_ok=True) # Optimizer self.optimizer = torch.optim.AdamW( self.model.parameters(), lr=learning_rate, weight_decay=1e-5 ) # Learning rate scheduler self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, mode='min', patience=10, factor=0.5 ) # Training data collector self.data_collector = get_training_data_collector() # Training sessions storage self.training_sessions: List[CNNTrainingSession] = [] self.current_session: Optional[CNNTrainingSession] = None # Training statistics self.training_stats = { 'total_sessions': 0, 'total_steps': 0, 'best_validation_loss': float('inf'), 'profitable_predictions': 0, 'total_predictions': 0, 'replay_sessions': 0 } # Background training self.is_training = False self.training_thread = None logger.info(f"CNN Trainer initialized") logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}") logger.info(f"Storage directory: {self.storage_dir}") def start_real_time_training(self, symbol: str): """Start real-time training for a symbol""" if self.is_training: logger.warning("CNN training already running") return self.is_training = True self.training_thread = threading.Thread( target=self._real_time_training_worker, args=(symbol,), daemon=True ) self.training_thread.start() logger.info(f"Started real-time CNN training for {symbol}") def stop_training(self): """Stop training""" self.is_training = False if self.training_thread: self.training_thread.join(timeout=10) if self.current_session: self._finalize_training_session() logger.info("CNN training stopped") def _real_time_training_worker(self, symbol: str): """Real-time training worker""" logger.info(f"Real-time CNN training worker started for {symbol}") while self.is_training: try: # Get high-priority episodes for training episodes = self.data_collector.get_high_priority_episodes( symbol=symbol, limit=100, min_priority=0.3 ) if len(episodes) >= 32: # Minimum batch size self._train_on_episodes(episodes, training_mode='real_time') # Wait before next training cycle threading.Event().wait(300) # Train every 5 minutes except Exception as e: logger.error(f"Error in real-time training worker: {e}") threading.Event().wait(60) # Wait before retrying logger.info(f"Real-time CNN training worker stopped for {symbol}") def train_on_profitable_episodes(self, symbol: str, min_profitability: float = 0.7, max_episodes: int = 500) -> Dict[str, Any]: """Train specifically on most profitable episodes""" try: # Get all episodes for symbol all_episodes = self.data_collector.training_episodes.get(symbol, []) # Filter for profitable episodes profitable_episodes = [ ep for ep in all_episodes if (ep.actual_outcome.is_profitable and ep.actual_outcome.profitability_score >= min_profitability) ] # Sort by profitability and limit profitable_episodes.sort( key=lambda x: x.actual_outcome.profitability_score, reverse=True ) profitable_episodes = profitable_episodes[:max_episodes] if len(profitable_episodes) < 10: logger.warning(f"Insufficient profitable episodes for {symbol}: {len(profitable_episodes)}") return {'status': 'insufficient_data', 'episodes_found': len(profitable_episodes)} # Train on profitable episodes results = self._train_on_episodes( profitable_episodes, training_mode='profitable_replay' ) logger.info(f"Trained on {len(profitable_episodes)} profitable episodes for {symbol}") return results except Exception as e: logger.error(f"Error training on profitable episodes: {e}") return {'status': 'error', 'error': str(e)} def _train_on_episodes(self, episodes: List[TrainingEpisode], training_mode: str = 'batch') -> Dict[str, Any]: """Train on a batch of episodes with comprehensive data storage""" try: # Start new training session session = CNNTrainingSession( session_id=f"{training_mode}_{datetime.now().strftime('%Y%m%d_%H%M%S')}", start_timestamp=datetime.now(), training_mode=training_mode, symbol=episodes[0].input_package.symbol if episodes else 'unknown' ) self.current_session = session # Create dataset and dataloader dataset = CNNTrainingDataset(episodes) dataloader = DataLoader( dataset, batch_size=32, shuffle=True, num_workers=2 ) # Training loop self.model.train() total_loss = 0.0 num_batches = 0 for batch_idx, batch in enumerate(dataloader): # Move to device features = batch['features'].to(self.device) pivot_class = batch['pivot_class'].to(self.device) pivot_price = batch['pivot_price'].to(self.device) confidence_target = batch['confidence_target'].to(self.device) # Forward pass self.optimizer.zero_grad() outputs = self.model(features) # Calculate losses classification_loss = F.cross_entropy(outputs['pivot_logits'], pivot_class) regression_loss = F.mse_loss(outputs['pivot_price'].squeeze(), pivot_price) confidence_loss = F.binary_cross_entropy( outputs['confidence'].squeeze(), confidence_target ) # Combined loss total_batch_loss = classification_loss + 0.5 * regression_loss + 0.3 * confidence_loss # Backward pass total_batch_loss.backward() # Gradient clipping torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # Store gradients before optimizer step gradients = {} gradient_norms = {} for name, param in self.model.named_parameters(): if param.grad is not None: gradients[name] = param.grad.clone().detach() gradient_norms[name] = param.grad.norm().item() # Optimizer step self.optimizer.step() # Create training step record step = CNNTrainingStep( step_id=f"{session.session_id}_step_{batch_idx}", timestamp=datetime.now(), episode_id=f"batch_{batch_idx}", input_features=features.detach().cpu(), target_labels=pivot_class.detach().cpu(), model_outputs={k: v.detach().cpu() for k, v in outputs.items()}, predictions=self._extract_predictions(outputs), confidence_scores=outputs['confidence'].detach().cpu(), total_loss=total_batch_loss.item(), pivot_prediction_loss=classification_loss.item(), confidence_loss=confidence_loss.item(), regularization_loss=0.0, gradients=gradients, gradient_norms=gradient_norms, learning_rate=self.optimizer.param_groups[0]['lr'], batch_size=features.size(0) ) # Calculate training value for this step step.training_value = self._calculate_step_training_value(step, batch) # Add to session session.training_steps.append(step) total_loss += total_batch_loss.item() num_batches += 1 # Log progress if batch_idx % 10 == 0: logger.debug(f"Batch {batch_idx}: Loss = {total_batch_loss.item():.4f}") # Finalize session session.end_timestamp = datetime.now() session.total_steps = num_batches session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0 session.best_loss = min(step.total_loss for step in session.training_steps) # Calculate session value session.session_value = self._calculate_session_value(session) # Update scheduler self.scheduler.step(session.average_loss) # Save session self._save_training_session(session) # Update statistics self.training_stats['total_sessions'] += 1 self.training_stats['total_steps'] += session.total_steps if training_mode == 'profitable_replay': self.training_stats['replay_sessions'] += 1 logger.info(f"Training session completed: {session.session_id}") logger.info(f"Average loss: {session.average_loss:.4f}") logger.info(f"Session value: {session.session_value:.3f}") return { 'status': 'success', 'session_id': session.session_id, 'average_loss': session.average_loss, 'total_steps': session.total_steps, 'session_value': session.session_value } except Exception as e: logger.error(f"Error in training session: {e}") return {'status': 'error', 'error': str(e)} finally: self.current_session = None def _extract_predictions(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: """Extract human-readable predictions from model outputs""" try: pivot_probs = F.softmax(outputs['pivot_logits'], dim=1) predicted_class = torch.argmax(pivot_probs, dim=1) return { 'pivot_class': predicted_class.cpu().numpy().tolist(), 'pivot_probabilities': pivot_probs.cpu().numpy().tolist(), 'pivot_price': outputs['pivot_price'].cpu().numpy().tolist(), 'confidence': outputs['confidence'].cpu().numpy().tolist() } except Exception as e: logger.warning(f"Error extracting predictions: {e}") return {} def _calculate_step_training_value(self, step: CNNTrainingStep, batch: Dict[str, Any]) -> float: """Calculate the training value of a step for replay prioritization""" try: value = 0.0 # Base value from loss (lower loss = higher value) if step.total_loss > 0: value += 1.0 / (1.0 + step.total_loss) # Bonus for high profitability episodes in batch avg_profitability = torch.mean(batch['profitability']).item() value += avg_profitability * 0.3 # Bonus for gradient magnitude (indicates learning) avg_grad_norm = np.mean(list(step.gradient_norms.values())) value += min(avg_grad_norm / 10.0, 0.2) # Cap at 0.2 return min(value, 1.0) except Exception as e: logger.warning(f"Error calculating step training value: {e}") return 0.0 def _calculate_session_value(self, session: CNNTrainingSession) -> float: """Calculate overall session value for replay prioritization""" try: if not session.training_steps: return 0.0 # Average step values avg_step_value = np.mean([step.training_value for step in session.training_steps]) # Bonus for convergence convergence_bonus = 0.0 if len(session.training_steps) > 10: early_loss = np.mean([s.total_loss for s in session.training_steps[:5]]) late_loss = np.mean([s.total_loss for s in session.training_steps[-5:]]) if early_loss > late_loss: convergence_bonus = min((early_loss - late_loss) / early_loss, 0.3) # Bonus for profitable replay sessions mode_bonus = 0.2 if session.training_mode == 'profitable_replay' else 0.0 return min(avg_step_value + convergence_bonus + mode_bonus, 1.0) except Exception as e: logger.warning(f"Error calculating session value: {e}") return 0.0 def _save_training_session(self, session: CNNTrainingSession): """Save training session to disk""" try: session_dir = self.storage_dir / session.symbol / 'sessions' session_dir.mkdir(parents=True, exist_ok=True) # Save full session data session_file = session_dir / f"{session.session_id}.pkl" with open(session_file, 'wb') as f: pickle.dump(session, f) # Save session metadata metadata = { 'session_id': session.session_id, 'start_timestamp': session.start_timestamp.isoformat(), 'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None, 'training_mode': session.training_mode, 'symbol': session.symbol, 'total_steps': session.total_steps, 'average_loss': session.average_loss, 'best_loss': session.best_loss, 'session_value': session.session_value } metadata_file = session_dir / f"{session.session_id}_metadata.json" with open(metadata_file, 'w') as f: json.dump(metadata, f, indent=2) logger.debug(f"Saved training session: {session.session_id}") except Exception as e: logger.error(f"Error saving training session: {e}") def _finalize_training_session(self): """Finalize current training session""" if self.current_session: self.current_session.end_timestamp = datetime.now() self._save_training_session(self.current_session) self.training_sessions.append(self.current_session) self.current_session = None def get_training_statistics(self) -> Dict[str, Any]: """Get comprehensive training statistics""" stats = self.training_stats.copy() # Add recent session information if self.training_sessions: recent_sessions = sorted( self.training_sessions, key=lambda x: x.start_timestamp, reverse=True )[:10] stats['recent_sessions'] = [ { 'session_id': s.session_id, 'timestamp': s.start_timestamp.isoformat(), 'mode': s.training_mode, 'average_loss': s.average_loss, 'session_value': s.session_value } for s in recent_sessions ] # Calculate profitability rate if stats['total_predictions'] > 0: stats['profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions'] else: stats['profitability_rate'] = 0.0 return stats def replay_high_value_sessions(self, symbol: str, min_session_value: float = 0.7, max_sessions: int = 10) -> Dict[str, Any]: """Replay high-value training sessions""" try: # Find high-value sessions high_value_sessions = [ s for s in self.training_sessions if (s.symbol == symbol and s.session_value >= min_session_value) ] # Sort by value and limit high_value_sessions.sort(key=lambda x: x.session_value, reverse=True) high_value_sessions = high_value_sessions[:max_sessions] if not high_value_sessions: return {'status': 'no_high_value_sessions', 'sessions_found': 0} # Replay sessions total_replayed = 0 for session in high_value_sessions: # Extract episodes from session steps episode_ids = list(set(step.episode_id for step in session.training_steps)) # Get corresponding episodes episodes = [] for episode_id in episode_ids: # Find episode in data collector for ep in self.data_collector.training_episodes.get(symbol, []): if ep.episode_id == episode_id: episodes.append(ep) break if episodes: self._train_on_episodes(episodes, training_mode='high_value_replay') total_replayed += 1 logger.info(f"Replayed {total_replayed} high-value sessions for {symbol}") return { 'status': 'success', 'sessions_replayed': total_replayed, 'sessions_found': len(high_value_sessions) } except Exception as e: logger.error(f"Error replaying high-value sessions: {e}") return {'status': 'error', 'error': str(e)} # Global instance cnn_trainer = None def get_cnn_trainer(model: CNNPivotPredictor = None) -> CNNTrainer: """Get global CNN trainer instance""" global cnn_trainer if cnn_trainer is None: if model is None: model = CNNPivotPredictor() cnn_trainer = CNNTrainer(model) return cnn_trainer