#!/usr/bin/env python3
"""
Unified Training Runner

CRITICAL POLICY: NO SYNTHETIC DATA ALLOWED
This module MUST ONLY use real market data from exchanges.
NEVER use np.random.*, mock/fake/synthetic data, or placeholder values.
If data is unavailable: return None/0/empty, log errors, raise exceptions.
See: reports/REAL_MARKET_DATA_POLICY.md

Consolidated training system supporting both realtime and backtesting modes.

Modes:
1. REALTIME: Live market data training with continuous learning
2. BACKTEST: Historical data with sliding window simulation for fast training

Features:
- Multi-horizon predictions (1m, 5m, 15m, 60m)
- CNN, DQN, and COB RL model training
- Checkpoint management with model rotation
- Performance tracking and reporting
- Resumable training sessions
"""

import logging
import time
import json
import argparse
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Any, Optional
from collections import deque
import asyncio

# Core components
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
from core.multi_horizon_backtester import MultiHorizonBacktester
from core.multi_horizon_prediction_manager import MultiHorizonPredictionManager
from core.prediction_snapshot_storage import PredictionSnapshotStorage
from core.multi_horizon_trainer import MultiHorizonTrainer

# Model management
from NN.training.model_manager import create_model_manager
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem

# Setup logging.
# FIX: FileHandler raises FileNotFoundError on a fresh checkout if logs/
# does not exist, so create it before configuring the handler.
Path('logs').mkdir(parents=True, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


class UnifiedTrainingRunner:
    """Unified training system supporting both realtime and backtesting modes"""

    def __init__(self, mode: str = "realtime", symbol: str = "ETH/USDT"):
        """
        Initialize the unified training runner

        Args:
            mode: "realtime" for live training or "backtest" for historical training
            symbol: Trading symbol to train on
        """
        self.mode = mode
        self.symbol = symbol
        self.start_time = datetime.now()

        logger.info(f"Initializing Unified Training Runner - Mode: {mode.upper()}")

        # Initialize core components
        self.data_provider = DataProvider()
        self.orchestrator = TradingOrchestrator(
            data_provider=self.data_provider,
            enhanced_rl_training=True
        )

        # Initialize training components
        self.backtester = MultiHorizonBacktester(self.data_provider)
        self.prediction_manager = MultiHorizonPredictionManager(
            data_provider=self.data_provider
        )
        self.snapshot_storage = PredictionSnapshotStorage()
        self.trainer = MultiHorizonTrainer(
            orchestrator=self.orchestrator,
            snapshot_storage=self.snapshot_storage
        )

        # Enhanced real-time training system, if the orchestrator exposes one
        # (used in both modes).
        self.enhanced_training = None
        if hasattr(self.orchestrator, 'enhanced_training_system'):
            self.enhanced_training = self.orchestrator.enhanced_training_system

        # Model checkpoint manager (handles rotation of saved checkpoints)
        self.checkpoint_manager = create_model_manager()

        # Training configuration, keyed by mode. Intervals are in the units
        # their names state (minutes/hours).
        self.config = {
            'realtime': {
                'checkpoint_interval_minutes': 30,
                'backtest_interval_minutes': 60,
                'performance_check_minutes': 15
            },
            'backtest': {
                'window_size_hours': 24,
                'step_size_hours': 1,
                'batch_size': 64,
                'save_interval_hours': 2
            }
        }

        # Performance tracking. prediction_accuracy is bounded so long
        # realtime sessions don't grow without limit.
        self.metrics = {
            'training_sessions': [],
            'backtest_results': [],
            'model_checkpoints': [],
            'prediction_accuracy': deque(maxlen=1000),
            'training_losses': {'cnn': [], 'dqn': [], 'cob_rl': []}
        }

        # Training state
        self.is_running = False
        self.progress_file = Path('training_progress.json')

        logger.info(f"Unified Training Runner initialized for {symbol}")
        logger.info(f"Mode: {mode}, Enhanced Training: {self.enhanced_training is not None}")

    def run_realtime_training(self, duration_hours: Optional[float] = None):
        """
        Run continuous real-time training on live market data

        Args:
            duration_hours: How long to train (None = indefinite)
        """
        logger.info("=" * 70)
        logger.info("STARTING REALTIME TRAINING")
        logger.info("=" * 70)
        logger.info(f"Duration: {'indefinite' if duration_hours is None else f'{duration_hours} hours'}")

        self.is_running = True
        config = self.config['realtime']

        last_checkpoint = time.time()
        last_backtest = time.time()
        last_perf_check = time.time()

        try:
            # Start enhanced training if available
            if self.enhanced_training and hasattr(self.orchestrator, 'start_enhanced_training'):
                self.orchestrator.start_enhanced_training()
                logger.info("Enhanced real-time training started")

            # Start multi-horizon prediction and training
            self.prediction_manager.start()
            self.trainer.start()
            logger.info("Multi-horizon prediction and training started")

            while self.is_running:
                current_time = time.time()
                elapsed_hours = (datetime.now() - self.start_time).total_seconds() / 3600

                # Check duration limit
                if duration_hours and elapsed_hours >= duration_hours:
                    logger.info(f"Training duration completed: {elapsed_hours:.1f} hours")
                    break

                # Periodic checkpoint save
                if current_time - last_checkpoint > config['checkpoint_interval_minutes'] * 60:
                    self._save_checkpoint()
                    last_checkpoint = current_time

                # Periodic backtest validation
                if current_time - last_backtest > config['backtest_interval_minutes'] * 60:
                    accuracy = self._run_backtest_validation()
                    if accuracy is not None:
                        self.metrics['prediction_accuracy'].append(accuracy)
                        logger.info(f"Backtest accuracy at {elapsed_hours:.1f}h: {accuracy:.3%}")
                    last_backtest = current_time

                # Performance check
                if current_time - last_perf_check > config['performance_check_minutes'] * 60:
                    self._log_performance_metrics()
                    last_perf_check = current_time

                # Sleep to reduce CPU usage
                time.sleep(60)

        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
        finally:
            self._cleanup_training()
            self._generate_final_report()

    def run_backtest_training(self, start_date: datetime, end_date: datetime):
        """
        Run fast backtesting with sliding window for bulk training

        Args:
            start_date: Start date for backtesting
            end_date: End date for backtesting
        """
        logger.info("=" * 70)
        logger.info("STARTING BACKTEST TRAINING")
        logger.info("=" * 70)
        logger.info(f"Period: {start_date} to {end_date}")

        config = self.config['backtest']
        window_hours = config['window_size_hours']
        step_hours = config['step_size_hours']

        current_date = start_date
        batch_count = 0
        total_samples = 0

        try:
            while current_date < end_date:
                window_end = current_date + timedelta(hours=window_hours)
                if window_end > end_date:
                    break

                batch_count += 1
                logger.info(f"Batch {batch_count}: {current_date} to {window_end}")

                # Fetch historical data for window
                data = self._fetch_window_data(current_date, window_end)

                if data and len(data) > 0:
                    # Simulate real-time data flow through sliding window
                    samples_trained = self._train_on_window(data)
                    total_samples += samples_trained
                    logger.info(f"Trained on {samples_trained} samples in window")

                # Save checkpoint periodically.
                # NOTE(review): float modulo only hits exactly 0 because the
                # default window/step sizes are whole hours; a fractional
                # step_size_hours would silently skip every checkpoint.
                elapsed_hours = (window_end - start_date).total_seconds() / 3600
                if elapsed_hours % config['save_interval_hours'] == 0:
                    self._save_checkpoint()
                    logger.info(f"Checkpoint saved at {elapsed_hours:.1f}h")

                # Move window forward
                current_date += timedelta(hours=step_hours)

            logger.info(f"Backtest training complete: {batch_count} batches, {total_samples} samples")

        except Exception as e:
            logger.error(f"Error in backtest training: {e}")
            raise
        finally:
            self._generate_final_report()

    def _fetch_window_data(self, start: datetime, end: datetime) -> List[Dict]:
        """Fetch historical data for a time window.

        Returns an empty list (never synthetic data) when no real market
        data is available, per the module's data policy.
        """
        try:
            # Fetch from data provider with real market data
            data = self.data_provider.get_historical_data(
                symbol=self.symbol,
                timeframe='1m',
                start_time=start,
                end_time=end
            )

            if data is None or len(data) == 0:
                logger.warning(f"No data available for {start} to {end}")
                return []

            return data

        except Exception as e:
            logger.error(f"Error fetching window data: {e}")
            return []

    def _train_on_window(self, data: List[Dict]) -> int:
        """
        Train models on a sliding window of data

        Args:
            data: List of market data points

        Returns:
            Number of samples trained on
        """
        samples_trained = 0

        # Simulate real-time flow through data (the last point has no
        # successor, hence len(data) - 1).
        for i in range(len(data) - 1):
            current = data[i]

            # Create prediction snapshot
            snapshot = {
                'timestamp': current.get('timestamp'),
                'price': current.get('close', 0),
                'volume': current.get('volume', 0),
                'symbol': self.symbol
            }

            # Store snapshot for later training
            self.snapshot_storage.store_snapshot(snapshot)

            # When we have outcome, train the models
            if i > 0:
                # NOTE(review): this passes the RAW market-data dict as the
                # "previous snapshot", not the reduced snapshot dict stored
                # above — confirm MultiHorizonTrainer.train_on_outcome
                # accepts this shape.
                prev_snapshot = data[i - 1]
                outcome = {
                    'actual_price': current.get('close', 0),
                    'timestamp': current.get('timestamp')
                }

                # Train via multi-horizon trainer
                self.trainer.train_on_outcome(prev_snapshot, outcome)
                samples_trained += 1

        return samples_trained

    def _run_backtest_validation(self) -> Optional[float]:
        """Run backtest on recent data to validate model performance.

        Returns the accuracy over the last 24 hours, or None if the
        backtest produced no usable result.
        """
        try:
            end_date = datetime.now()
            start_date = end_date - timedelta(hours=24)

            results = self.backtester.run_backtest(
                symbol=self.symbol,
                start_date=start_date,
                end_date=end_date,
                horizons=[1, 5, 15, 60]  # minutes
            )

            if results and 'accuracy' in results:
                return results['accuracy']
            return None

        except Exception as e:
            logger.error(f"Error in backtest validation: {e}")
            return None

    def _save_checkpoint(self):
        """Save model checkpoints with rotation.

        Best-effort: failures are logged, never raised, so a bad save
        cannot abort a long training run.
        """
        try:
            checkpoint_data = {
                'timestamp': datetime.now().isoformat(),
                'mode': self.mode,
                'elapsed_hours': (datetime.now() - self.start_time).total_seconds() / 3600,
                'metrics': {
                    'prediction_accuracy': list(self.metrics['prediction_accuracy'])[-10:],
                    'total_training_samples': sum(
                        len(losses) for losses in self.metrics['training_losses'].values()
                    )
                }
            }

            # Use model manager for checkpoint rotation (keeps best 5)
            self.checkpoint_manager.save_checkpoint(
                model=self.orchestrator,
                metadata=checkpoint_data
            )

            self.metrics['model_checkpoints'].append(checkpoint_data)
            logger.info("Checkpoint saved successfully")

        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")

    def _log_performance_metrics(self):
        """Log current performance metrics"""
        elapsed_hours = (datetime.now() - self.start_time).total_seconds() / 3600

        avg_accuracy = 0
        if self.metrics['prediction_accuracy']:
            avg_accuracy = sum(self.metrics['prediction_accuracy']) / len(self.metrics['prediction_accuracy'])

        logger.info("=" * 50)
        logger.info(f"Performance Metrics @ {elapsed_hours:.1f}h")
        logger.info(f"  Avg Prediction Accuracy: {avg_accuracy:.3%}")
        logger.info(f"  Total Checkpoints: {len(self.metrics['model_checkpoints'])}")
        logger.info(f"  CNN Training Samples: {len(self.metrics['training_losses']['cnn'])}")
        logger.info(f"  DQN Training Samples: {len(self.metrics['training_losses']['dqn'])}")
        logger.info("=" * 50)

    def _cleanup_training(self):
        """Clean up training resources"""
        logger.info("Cleaning up training resources...")

        # Stop prediction and training
        if hasattr(self.prediction_manager, 'stop'):
            self.prediction_manager.stop()
        if hasattr(self.trainer, 'stop'):
            self.trainer.stop()

        # Save final checkpoint
        self._save_checkpoint()

        logger.info("Training cleanup complete")

    def _generate_final_report(self):
        """Generate final training report and persist it as JSON."""
        report = {
            'mode': self.mode,
            'symbol': self.symbol,
            'start_time': self.start_time.isoformat(),
            'end_time': datetime.now().isoformat(),
            'duration_hours': (datetime.now() - self.start_time).total_seconds() / 3600,
            'metrics': {
                'total_checkpoints': len(self.metrics['model_checkpoints']),
                'total_backtest_runs': len(self.metrics['backtest_results']),
                'final_accuracy': list(self.metrics['prediction_accuracy'])[-1] if self.metrics['prediction_accuracy'] else 0,
                'avg_accuracy': sum(self.metrics['prediction_accuracy']) / len(self.metrics['prediction_accuracy']) if self.metrics['prediction_accuracy'] else 0
            }
        }

        report_file = Path(f'training_report_{self.mode}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json')
        with open(report_file, 'w') as f:
            json.dump(report, f, indent=2)

        logger.info("=" * 70)
        logger.info("TRAINING COMPLETE")
        logger.info("=" * 70)
        logger.info(f"Mode: {self.mode}")
        logger.info(f"Duration: {report['duration_hours']:.2f} hours")
        logger.info(f"Final Accuracy: {report['metrics']['final_accuracy']:.3%}")
        logger.info(f"Avg Accuracy: {report['metrics']['avg_accuracy']:.3%}")
        logger.info(f"Report saved to: {report_file}")
        logger.info("=" * 70)


def main():
    """Main entry point for training runner"""
    parser = argparse.ArgumentParser(description="Unified Training Runner")
    parser.add_argument(
        '--mode', type=str, choices=['realtime', 'backtest'], default='realtime',
        help='Training mode: realtime or backtest'
    )
    parser.add_argument(
        '--symbol', type=str, default='ETH/USDT',
        help='Trading symbol'
    )
    parser.add_argument(
        '--duration', type=float, default=None,
        help='Training duration in hours (realtime mode only)'
    )
    parser.add_argument(
        '--start-date', type=str, default=None,
        help='Start date for backtest (YYYY-MM-DD)'
    )
    parser.add_argument(
        '--end-date', type=str, default=None,
        help='End date for backtest (YYYY-MM-DD)'
    )

    args = parser.parse_args()

    # Create training runner
    runner = UnifiedTrainingRunner(mode=args.mode, symbol=args.symbol)

    if args.mode == 'realtime':
        runner.run_realtime_training(duration_hours=args.duration)
    else:  # backtest
        if not args.start_date or not args.end_date:
            logger.error("Backtest mode requires --start-date and --end-date")
            return
        start = datetime.strptime(args.start_date, '%Y-%m-%d')
        end = datetime.strptime(args.end_date, '%Y-%m-%d')
        runner.run_backtest_training(start_date=start, end_date=end)


if __name__ == '__main__':
    main()