#!/usr/bin/env python3
"""
Multi-Horizon Trainer

This module trains models using stored prediction snapshots once their
outcomes are known. It handles training for different time horizons and
model types.
"""

import logging
import threading
import time
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple

import numpy as np
import torch
from collections import defaultdict

from .prediction_snapshot_storage import PredictionSnapshotStorage
from .multi_horizon_prediction_manager import PredictionSnapshot

logger = logging.getLogger(__name__)


class MultiHorizonTrainer:
    """Trainer for multi-horizon predictions using stored snapshots.

    A background thread (see :meth:`start`) periodically pulls batches of
    snapshots from ``snapshot_storage`` for every configured
    (horizon, symbol) pair. It uses them to update the orchestrator's
    ``cnn_model`` and ``rl_agent`` when those attributes exist.
    """

    def __init__(self, orchestrator=None, snapshot_storage: Optional[PredictionSnapshotStorage] = None):
        """Initialize the multi-horizon trainer.

        Args:
            orchestrator: Object holding the models to train. Only its
                optional ``cnn_model`` and ``rl_agent`` attributes are used.
            snapshot_storage: Source of prediction snapshots. A default
                ``PredictionSnapshotStorage()`` is created when omitted.
        """
        self.orchestrator = orchestrator
        self.snapshot_storage = snapshot_storage or PredictionSnapshotStorage()

        # What to train: every horizon (in minutes) is trained for every symbol.
        self.horizons = [1, 5, 15, 60]
        self.symbols = ['ETH/USDT', 'BTC/USDT']

        # Training configuration
        self.batch_size = 32
        self.min_batch_size = 10
        self.min_confidence = 0.3  # forwarded to snapshot_storage.get_training_batch
        self.training_interval_seconds = 300  # 5 minutes
        # Don't train on predictions older than 24 hours.
        # NOTE(review): not read anywhere in this module; confirm whether storage should apply it.
        self.max_training_age_hours = 24

        # Model training settings
        self.learning_rate = 0.001
        self.epochs_per_batch = 5
        self.validation_split = 0.2
        self.cnn_overlap_threshold = 0.3  # range IoU above which a CNN target is 1
        self.rl_direction_threshold = 0.002  # +/-0.2% band around current price maps to HOLD

        # Training state
        self.training_active = False
        self.training_thread = None
        self.last_training_time = 0.0
        # Set by stop() so the loop wakes immediately instead of finishing a long sleep.
        self._stop_event = threading.Event()

        # Performance tracking
        self.training_stats = {
            'total_training_sessions': 0,
            'models_trained': defaultdict(int),
            'training_accuracy': defaultdict(list),
            'loss_history': defaultdict(list),
            'last_training_time': None
        }

        logger.info("MultiHorizonTrainer initialized")

    def start(self):
        """Start the background training thread (no-op with a warning if already active)."""
        if self.training_active:
            logger.warning("Training system already active")
            return

        # Each thread gets its own event. A previous thread that has not yet
        # exited keeps its own (already set) event and therefore still shuts down.
        self._stop_event = threading.Event()
        self.training_active = True
        self.training_thread = threading.Thread(
            target=self._training_loop,
            args=(self._stop_event,),
            daemon=True,
            name="MultiHorizonTrainer"
        )
        self.training_thread.start()

        logger.info("MultiHorizonTrainer started")

    def stop(self):
        """Signal the training thread to stop and wait up to 10 seconds for it to exit."""
        self.training_active = False
        self._stop_event.set()
        if self.training_thread and self.training_thread.is_alive():
            self.training_thread.join(timeout=10)
        logger.info("MultiHorizonTrainer stopped")

    def _training_loop(self, stop_event: Optional[threading.Event] = None):
        """Run a training session every ``training_interval_seconds`` until stopped.

        Args:
            stop_event: Event that ends the loop when set. Defaults to the
                trainer's current stop event.
        """
        if stop_event is None:
            stop_event = self._stop_event

        while self.training_active and not stop_event.is_set():
            try:
                current_time = time.time()

                # Check if it's time for training
                if current_time - self.last_training_time >= self.training_interval_seconds:
                    self._run_training_session()
                    self.last_training_time = current_time

                # Re-check once a minute; wait() returns early as soon as stop() is called.
                stop_event.wait(60)

            except Exception as e:
                logger.error(f"Error in training loop: {e}")
                stop_event.wait(300)  # Longer back-off on error

    def _run_training_session(self):
        """Run one complete training pass over every configured horizon and symbol."""
        try:
            logger.info("Starting multi-horizon training session")

            training_results = self._train_targets(self.horizons, self.symbols)

            # Update statistics
            self.training_stats['total_training_sessions'] += 1
            self.training_stats['last_training_time'] = datetime.now()

            if training_results:
                logger.info(f"Training session completed: {len(training_results)} model updates")
                for key, results in training_results.items():
                    logger.info(f"  {key}: {results}")
            else:
                logger.debug("No models were trained in this session")

        except Exception as e:
            logger.error(f"Error in training session: {e}")

    def _train_targets(self, horizons: List[int], symbols: List[str]) -> Dict[str, Any]:
        """Train every (horizon, symbol) pair and collect the non-empty results.

        Results are keyed ``"{horizon}m_{symbol}"``. A failure for one pair
        is logged and does not stop the remaining pairs.
        """
        results = {}
        for horizon in horizons:
            for symbol in symbols:
                try:
                    pair_results = self._train_horizon_models(horizon, symbol)
                except Exception as e:
                    logger.error(f"Error training {horizon}m models for {symbol}: {e}")
                    continue
                if pair_results:
                    results[f"{horizon}m_{symbol}"] = pair_results
        return results

    def _train_horizon_models(self, horizon_minutes: int, symbol: str) -> Dict[str, Any]:
        """Train all available models for one horizon and symbol.

        Returns:
            A dict with optional ``'cnn'`` and ``'rl'`` entries holding each
            trainer's result (which may be an ``{'error': ...}`` dict). It is
            empty when fewer than ``min_batch_size`` snapshots are available.
        """
        results = {}

        # Get training batch
        snapshots = self.snapshot_storage.get_training_batch(
            horizon_minutes=horizon_minutes,
            symbol=symbol,
            batch_size=self.batch_size,
            min_confidence=self.min_confidence
        ) or []

        if len(snapshots) < self.min_batch_size:
            logger.debug(f"Insufficient training data for {horizon_minutes}m {symbol}: {len(snapshots)} snapshots")
            return results

        logger.info(f"Training {horizon_minutes}m models for {symbol} with {len(snapshots)} snapshots")

        # (result key, orchestrator attribute, training method)
        trainers = (
            ('cnn', 'cnn_model', self._train_cnn_model),
            ('rl', 'rl_agent', self._train_rl_model),
        )
        for model_key, attr_name, train_fn in trainers:
            if not (self.orchestrator and hasattr(self.orchestrator, attr_name)):
                continue
            try:
                model_results = train_fn(snapshots, horizon_minutes, symbol)
            except Exception as e:
                logger.error(f"{model_key.upper()} training failed for {horizon_minutes}m {symbol}: {e}")
                continue
            if model_results:
                results[model_key] = model_results
                # Error dicts are still reported, but only successful runs count as "trained".
                if 'error' not in model_results:
                    self.training_stats['models_trained'][model_key] += 1

        return results

    def _train_cnn_model(self, snapshots: List[PredictionSnapshot],
                         horizon_minutes: int, symbol: str) -> Optional[Dict[str, Any]]:
        """Fine-tune the orchestrator's CNN as a binary "was the range accurate" classifier.

        Returns:
            ``None`` if there is no CNN model, ``{'error': ...}`` on
            insufficient data or failure, otherwise training metrics.
        """
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model'):
                return None

            cnn_model = self.orchestrator.cnn_model

            features_list, targets_list = self._build_cnn_dataset(snapshots)
            if len(features_list) < self.min_batch_size:
                return {'error': 'Insufficient training data'}

            features_array = np.array(features_list, dtype=np.float32)
            targets_array = np.array(targets_list, dtype=np.float32)

            # Split into train/validation (the tail becomes validation);
            # always keep at least one training sample.
            n_samples = len(features_array)
            split_idx = min(max(int(n_samples * (1 - self.validation_split)), 1), n_samples)

            # Train where the shared model already lives instead of moving it,
            # so concurrent inference elsewhere is not disturbed.
            device = self._model_device(cnn_model)

            if not hasattr(cnn_model, 'optimizer'):
                cnn_model.optimizer = torch.optim.Adam(cnn_model.parameters(), lr=self.learning_rate)

            # Same objective as BCELoss(sigmoid(logits)), but numerically stable.
            criterion = torch.nn.BCEWithLogitsLoss()

            # Loop-invariant tensors are built once.
            train_inputs = torch.as_tensor(features_array[:split_idx], dtype=torch.float32, device=device)
            train_targets = torch.as_tensor(targets_array[:split_idx], dtype=torch.float32, device=device)
            has_validation = split_idx < n_samples
            if has_validation:
                val_inputs = torch.as_tensor(features_array[split_idx:], dtype=torch.float32, device=device)
                val_targets = torch.as_tensor(targets_array[split_idx:], dtype=torch.float32, device=device)

            train_losses = []
            val_accuracies = []
            was_training = getattr(cnn_model, 'training', False)
            try:
                for _ in range(self.epochs_per_batch):
                    # Training step (full batch)
                    cnn_model.train()
                    cnn_model.optimizer.zero_grad()
                    logits = self._extract_logits(cnn_model(train_inputs)).reshape(-1)
                    loss = criterion(logits, train_targets)
                    loss.backward()
                    cnn_model.optimizer.step()
                    train_losses.append(loss.item())

                    # Validation step
                    if has_validation:
                        cnn_model.eval()
                        with torch.no_grad():
                            val_logits = self._extract_logits(cnn_model(val_inputs)).reshape(-1)
                            val_binary_preds = (torch.sigmoid(val_logits) > 0.5).float()
                            val_accuracies.append((val_binary_preds == val_targets).float().mean().item())
            finally:
                # Leave the shared model in the mode the caller had it in.
                cnn_model.train(was_training)

            # 'final_loss' is reported as the mean loss over all epochs.
            avg_train_loss = float(np.mean(train_losses))
            final_val_accuracy = val_accuracies[-1] if val_accuracies else 0.0

            self.training_stats['loss_history']['cnn'].append(avg_train_loss)
            self.training_stats['training_accuracy']['cnn'].append(final_val_accuracy)

            results = {
                'epochs': self.epochs_per_batch,
                'final_loss': avg_train_loss,
                'validation_accuracy': final_val_accuracy,
                'samples_used': len(features_list)
            }

            logger.info(f"CNN training completed: loss={avg_train_loss:.4f}, val_acc={final_val_accuracy:.2f}")
            return results

        except Exception as e:
            logger.error(f"Error training CNN model: {e}")
            return {'error': str(e)}

    def _build_cnn_dataset(self, snapshots: List[PredictionSnapshot]) -> Tuple[List[Any], List[int]]:
        """Collect (features, target) pairs from snapshots whose outcome is known.

        The target is 1 when the predicted and actual price ranges overlap
        (intersection over union) by more than ``cnn_overlap_threshold``,
        and 0 otherwise. Snapshots lacking ``cnn_features`` or actual
        min/max prices are skipped.
        """
        features_list = []
        targets_list = []
        for snapshot in snapshots:
            features = snapshot.model_inputs.get('cnn_features')
            if features is None:
                continue
            if snapshot.actual_min_price is None or snapshot.actual_max_price is None:
                continue

            range_overlap = self._calculate_range_overlap(
                (snapshot.predicted_min_price, snapshot.predicted_max_price),
                (snapshot.actual_min_price, snapshot.actual_max_price)
            )
            features_list.append(features)
            targets_list.append(1 if range_overlap > self.cnn_overlap_threshold else 0)
        return features_list, targets_list

    @staticmethod
    def _extract_logits(outputs):
        """Pick the tensor to train on from a model output.

        Dict outputs use ``'main_output'`` when present, otherwise their
        first value. Non-dict outputs are returned unchanged.
        """
        if isinstance(outputs, dict):
            if 'main_output' in outputs:
                return outputs['main_output']
            return list(outputs.values())[0]
        return outputs

    @staticmethod
    def _model_device(model) -> torch.device:
        """Return the device of the model's first parameter (CPU if it has none)."""
        try:
            return next(model.parameters()).device
        except StopIteration:
            return torch.device('cpu')

    def _train_rl_model(self, snapshots: List[PredictionSnapshot],
                        horizon_minutes: int, symbol: str) -> Optional[Dict[str, Any]]:
        """Feed snapshot-derived transitions to the RL agent and run a few replay steps.

        Returns:
            ``None`` if there is no RL agent, ``{'error': ...}`` on
            insufficient data or failure, otherwise training metrics.
        """
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent'):
                return None

            rl_agent = self.orchestrator.rl_agent

            experiences = self._build_rl_experiences(snapshots)
            if len(experiences) < self.min_batch_size:
                return {'error': 'Insufficient training data'}

            # Add experiences to RL agent memory
            experiences_added = 0
            for state, action, reward, next_state, done in experiences:
                if self._store_rl_experience(rl_agent, state, action, reward, next_state, done):
                    experiences_added += 1

            # Perform training steps
            training_losses = []
            if hasattr(rl_agent, 'replay') and experiences_added > 0:
                try:
                    # Conservative: at most 5 replay steps, about one per 8 new experiences.
                    for _ in range(min(5, experiences_added // 8)):
                        loss = rl_agent.replay(batch_size=min(32, experiences_added))
                        if loss is not None:
                            # float() also accepts scalar tensors, which np.mean may not.
                            training_losses.append(float(loss))
                except Exception as e:
                    logger.debug(f"RL training step failed: {e}")

            avg_loss = float(np.mean(training_losses)) if training_losses else 0.0
            if training_losses:
                # Mirror the CNN path so get_training_stats() reports a real rl_avg_loss.
                self.training_stats['loss_history']['rl'].append(avg_loss)

            results = {
                'experiences_added': experiences_added,
                'training_steps': len(training_losses),
                'avg_loss': avg_loss,
                'samples_used': len(experiences)
            }

            logger.info(f"RL training completed: {experiences_added} experiences, avg_loss={avg_loss:.4f}")
            return results

        except Exception as e:
            logger.error(f"Error training RL model: {e}")
            return {'error': str(e)}

    def _build_rl_experiences(self, snapshots: List[PredictionSnapshot]) -> List[Tuple[Any, int, float, Any, bool]]:
        """Convert snapshots whose outcome is known into terminal RL transitions.

        Action: the predicted range centre compared with ``current_price``.
        It is 0 (BUY) when above ``+rl_direction_threshold``, 1 (SELL) when
        below ``-rl_direction_threshold``, and 2 (HOLD) otherwise.

        Reward: ``+confidence`` if the predicted direction of the range
        centre matches the actual one, ``-confidence`` otherwise. A bonus of
        ``0.5 * range_overlap`` is then added.

        Every transition has ``done=True`` and ``next_state`` set to a copy
        of ``state``.
        """
        experiences = []
        for snapshot in snapshots:
            state = snapshot.model_inputs.get('rl_state')
            if state is None:
                continue
            if snapshot.actual_min_price is None or snapshot.actual_max_price is None:
                continue

            current_price = snapshot.current_price
            range_center = (snapshot.predicted_min_price + snapshot.predicted_max_price) / 2

            if range_center > current_price * (1 + self.rl_direction_threshold):
                action = 0  # BUY
            elif range_center < current_price * (1 - self.rl_direction_threshold):
                action = 1  # SELL
            else:
                action = 2  # HOLD

            actual_center = (snapshot.actual_min_price + snapshot.actual_max_price) / 2
            predicted_direction = 1 if range_center > current_price else -1 if range_center < current_price else 0
            actual_direction = 1 if actual_center > current_price else -1 if actual_center < current_price else 0

            # Direction reward scaled by the prediction's confidence
            reward = snapshot.confidence if predicted_direction == actual_direction else -snapshot.confidence

            # Bonus for accurate range prediction
            range_overlap = self._calculate_range_overlap(
                (snapshot.predicted_min_price, snapshot.predicted_max_price),
                (snapshot.actual_min_price, snapshot.actual_max_price)
            )
            reward += range_overlap * 0.5

            # Simplified: each snapshot is a one-step episode.
            next_state = state.copy()
            experiences.append((state, action, reward, next_state, True))
        return experiences

    @staticmethod
    def _store_rl_experience(rl_agent, state, action, reward, next_state, done) -> bool:
        """Push one transition into the agent's memory and return True if it was stored.

        Prefers ``store_experience`` (keyword arguments) and falls back to
        ``remember`` (positional). Errors are logged at debug level.
        """
        try:
            if hasattr(rl_agent, 'store_experience'):
                rl_agent.store_experience(
                    state=np.array(state),
                    action=action,
                    reward=reward,
                    next_state=np.array(next_state),
                    done=done
                )
                return True
            if hasattr(rl_agent, 'remember'):
                rl_agent.remember(np.array(state), action, reward, np.array(next_state), done)
                return True
        except Exception as e:
            logger.debug(f"Error adding RL experience: {e}")
        return False

    def _calculate_range_overlap(self, range1: Tuple[float, float], range2: Tuple[float, float]) -> float:
        """Return the intersection-over-union of two ``(min, max)`` price ranges (0.0 to 1.0).

        Returns 0.0 for disjoint or touching ranges, a zero-width union, or
        inputs that cannot be compared.
        """
        try:
            min1, max1 = range1
            min2, max2 = range2

            # Find overlap
            overlap_min = max(min1, min2)
            overlap_max = min(max1, max2)

            if overlap_max <= overlap_min:
                return 0.0

            overlap_size = overlap_max - overlap_min
            union_size = max(max1, max2) - min(min1, min2)

            return overlap_size / union_size if union_size > 0 else 0.0

        except Exception:
            return 0.0

    def force_training_session(self, horizon_minutes: Optional[int] = None,
                               symbol: Optional[str] = None) -> Dict[str, Any]:
        """Immediately train the given horizon and/or symbol.

        Args:
            horizon_minutes: Horizon to train; all configured horizons when ``None``.
            symbol: Symbol to train; all configured symbols when ``None``.

        Returns:
            Results keyed ``"{horizon}m_{symbol}"``, or ``{'error': ...}`` on failure.
        """
        try:
            logger.info(f"Forcing training session: horizon={horizon_minutes}, symbol={symbol}")

            horizons = [horizon_minutes] if horizon_minutes is not None else list(self.horizons)
            symbols = [symbol] if symbol is not None else list(self.symbols)

            return self._train_targets(horizons, symbols)

        except Exception as e:
            logger.error(f"Error in forced training session: {e}")
            return {'error': str(e)}

    def get_training_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the training statistics plus per-model averages.

        The nested containers are copies, so callers cannot mutate the
        trainer's history. Computing the averages does not insert keys into
        the internal dicts.
        """
        source = self.training_stats
        stats = dict(source)
        stats['models_trained'] = defaultdict(int, source['models_trained'])
        stats['training_accuracy'] = defaultdict(
            list, {key: list(values) for key, values in source['training_accuracy'].items()})
        stats['loss_history'] = defaultdict(
            list, {key: list(values) for key, values in source['loss_history'].items()})
        stats['is_training_active'] = self.training_active

        # Calculate averages
        for model_type in ('cnn', 'rl'):
            accuracies = stats['training_accuracy'].get(model_type)
            losses = stats['loss_history'].get(model_type)
            stats[f'{model_type}_avg_accuracy'] = float(np.mean(accuracies)) if accuracies else 0.0
            stats[f'{model_type}_avg_loss'] = float(np.mean(losses)) if losses else 0.0

        return stats

    def validate_recent_predictions(self):
        """Record outcomes for snapshots the storage reports as pending validation.

        Snapshots are grouped by symbol, and each group is validated
        independently, so one failing symbol does not block the others.
        """
        try:
            # Get pending snapshots
            pending_snapshots = self.snapshot_storage.get_pending_validation_snapshots()

            if not pending_snapshots:
                return

            logger.info(f"Validating {len(pending_snapshots)} pending predictions")

            # Group by symbol for efficient data access
            by_symbol = defaultdict(list)
            for snapshot in pending_snapshots:
                by_symbol[snapshot.symbol].append(snapshot)

            # Validate each symbol
            for symbol, snapshots in by_symbol.items():
                try:
                    self._validate_symbol_predictions(symbol, snapshots)
                except Exception as e:
                    logger.error(f"Error validating predictions for {symbol}: {e}")

        except Exception as e:
            logger.error(f"Error validating recent predictions: {e}")

    def _validate_symbol_predictions(self, symbol: str, snapshots: List[PredictionSnapshot]):
        """Write an outcome for each snapshot of one symbol.

        NOTE(review): this is a placeholder. It records the snapshot's own
        prediction-time ``current_price`` as both the actual min and max.
        A zero-width actual range gives a range overlap of 0.0 for any
        non-degenerate prediction, so every CNN target becomes 0 and RL
        rewards are skewed. Replace it with the real min/max price over the
        prediction horizon before relying on the trained models.
        """
        try:
            for snapshot in snapshots:
                try:
                    # A real implementation would query historical data for the
                    # exact horizon window and compute the actual min/max prices.
                    current_time = datetime.now()
                    current_price = snapshot.current_price  # Placeholder

                    # Update snapshot with "outcome"
                    self.snapshot_storage.update_snapshot_outcome(
                        snapshot.prediction_id,
                        current_price,  # actual_min
                        current_price,  # actual_max
                        current_time
                    )

                    logger.debug(f"Validated prediction {snapshot.prediction_id}")

                except Exception as e:
                    logger.error(f"Error validating snapshot {snapshot.prediction_id}: {e}")

        except Exception as e:
            logger.error(f"Error validating symbol predictions for {symbol}: {e}")