#!/usr/bin/env python3 """ Comprehensive Checkpoint Management Integration This script demonstrates how to integrate the checkpoint management system across all training pipelines in the gogo2 project. Features: - DQN Agent training with automatic checkpointing - CNN Model training with checkpoint management - ExtremaTrainer with checkpoint persistence - NegativeCaseTrainer with checkpoint integration - Unified training orchestration with checkpoint coordination """ import asyncio import logging import time import signal import sys import numpy as np from datetime import datetime from pathlib import Path from typing import Dict, Any, List # Setup logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler('logs/checkpoint_integration.log'), logging.StreamHandler() ] ) logger = logging.getLogger(__name__) # Import checkpoint management from utils.checkpoint_manager import get_checkpoint_manager, get_checkpoint_stats from utils.training_integration import get_training_integration # Import training components from NN.models.dqn_agent import DQNAgent from NN.models.cnn_model import CNNModelTrainer, create_enhanced_cnn_model from core.extrema_trainer import ExtremaTrainer from core.negative_case_trainer import NegativeCaseTrainer from core.data_provider import DataProvider from core.config import get_config class CheckpointIntegratedTrainingSystem: """Unified training system with comprehensive checkpoint management""" def __init__(self): """Initialize the checkpoint-integrated training system""" self.config = get_config() self.running = False # Checkpoint management self.checkpoint_manager = get_checkpoint_manager() self.training_integration = get_training_integration() # Data provider self.data_provider = DataProvider( symbols=['ETH/USDT', 'BTC/USDT'], timeframes=['1s', '1m', '1h', '1d'] ) # Training components with checkpoint management self.dqn_agent = None self.cnn_trainer = None self.extrema_trainer = None self.negative_case_trainer = None # Training statistics self.training_stats = { 'start_time': None, 'total_training_sessions': 0, 'checkpoints_saved': 0, 'models_loaded': 0, 'best_performances': {} } logger.info("Checkpoint-Integrated Training System initialized") async def initialize_components(self): """Initialize all training components with checkpoint management""" try: logger.info("Initializing training components with checkpoint management...") # Initialize data provider await self.data_provider.start_real_time_streaming() logger.info("Data provider streaming started") # Initialize DQN Agent with checkpoint management logger.info("Initializing DQN Agent with checkpoints...") self.dqn_agent = DQNAgent( state_shape=(100,), # Example state shape n_actions=3, model_name="integrated_dqn_agent", enable_checkpoints=True ) logger.info("✅ DQN Agent initialized with checkpoint management") # Initialize CNN Model with checkpoint management logger.info("Initializing CNN Model with checkpoints...") cnn_model, self.cnn_trainer = create_enhanced_cnn_model( input_size=60, feature_dim=50, output_size=3 ) # Update trainer with checkpoint management self.cnn_trainer.model_name = "integrated_cnn_model" self.cnn_trainer.enable_checkpoints = True self.cnn_trainer.training_integration = self.training_integration logger.info("✅ CNN Model initialized with checkpoint management") # Initialize ExtremaTrainer with checkpoint management logger.info("Initializing ExtremaTrainer with checkpoints...") self.extrema_trainer = ExtremaTrainer( data_provider=self.data_provider, symbols=['ETH/USDT', 'BTC/USDT'], model_name="integrated_extrema_trainer", enable_checkpoints=True ) await self.extrema_trainer.initialize_context_data() logger.info("✅ ExtremaTrainer initialized with checkpoint management") # Initialize NegativeCaseTrainer with checkpoint management logger.info("Initializing NegativeCaseTrainer with checkpoints...") self.negative_case_trainer = NegativeCaseTrainer( model_name="integrated_negative_case_trainer", enable_checkpoints=True ) logger.info("✅ NegativeCaseTrainer initialized with checkpoint management") # Load existing checkpoints for all components self.training_stats['models_loaded'] = await self._load_all_checkpoints() logger.info("All training components initialized successfully") except Exception as e: logger.error(f"Error initializing components: {e}") raise async def _load_all_checkpoints(self) -> int: """Load checkpoints for all training components""" loaded_count = 0 try: # DQN Agent checkpoint loading is handled in __init__ if hasattr(self.dqn_agent, 'episode_count') and self.dqn_agent.episode_count > 0: loaded_count += 1 logger.info(f"DQN Agent resumed from episode {self.dqn_agent.episode_count}") # CNN Trainer checkpoint loading is handled in __init__ if hasattr(self.cnn_trainer, 'epoch_count') and self.cnn_trainer.epoch_count > 0: loaded_count += 1 logger.info(f"CNN Trainer resumed from epoch {self.cnn_trainer.epoch_count}") # ExtremaTrainer checkpoint loading is handled in __init__ if hasattr(self.extrema_trainer, 'training_session_count') and self.extrema_trainer.training_session_count > 0: loaded_count += 1 logger.info(f"ExtremaTrainer resumed from session {self.extrema_trainer.training_session_count}") # NegativeCaseTrainer checkpoint loading is handled in __init__ if hasattr(self.negative_case_trainer, 'training_session_count') and self.negative_case_trainer.training_session_count > 0: loaded_count += 1 logger.info(f"NegativeCaseTrainer resumed from session {self.negative_case_trainer.training_session_count}") return loaded_count except Exception as e: logger.error(f"Error loading checkpoints: {e}") return 0 async def run_integrated_training_loop(self): """Run the integrated training loop with checkpoint coordination""" logger.info("Starting integrated training loop with checkpoint management...") self.running = True self.training_stats['start_time'] = datetime.now() training_cycle = 0 try: while self.running: training_cycle += 1 cycle_start = time.time() logger.info(f"=== Training Cycle {training_cycle} ===") # DQN Training dqn_results = await self._train_dqn_agent() # CNN Training cnn_results = await self._train_cnn_model() # Extrema Detection Training extrema_results = await self._train_extrema_detector() # Negative Case Training (runs in background) negative_results = await self._process_negative_cases() # Coordinate checkpoint saving await self._coordinate_checkpoint_saving( dqn_results, cnn_results, extrema_results, negative_results ) # Update statistics self.training_stats['total_training_sessions'] += 1 # Log cycle summary cycle_duration = time.time() - cycle_start logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s") # Wait before next cycle await asyncio.sleep(60) # 1-minute cycles except KeyboardInterrupt: logger.info("Training interrupted by user") except Exception as e: logger.error(f"Error in training loop: {e}") finally: await self.shutdown() async def _train_dqn_agent(self) -> Dict[str, Any]: """Train DQN agent with automatic checkpointing""" try: if not self.dqn_agent: return {'status': 'skipped', 'reason': 'no_agent'} # Simulate DQN training episode episode_reward = 0.0 # Add some training experiences (simulate real training) for _ in range(10): # Simulate 10 training steps state = np.random.randn(100).astype(np.float32) action = np.random.randint(0, 3) reward = np.random.randn() * 0.1 next_state = np.random.randn(100).astype(np.float32) done = np.random.random() < 0.1 self.dqn_agent.remember(state, action, reward, next_state, done) episode_reward += reward # Train if enough experiences loss = 0.0 if len(self.dqn_agent.memory) >= self.dqn_agent.batch_size: loss = self.dqn_agent.replay() # Save checkpoint (automatic based on performance) checkpoint_saved = self.dqn_agent.save_checkpoint(episode_reward) if checkpoint_saved: self.training_stats['checkpoints_saved'] += 1 return { 'status': 'completed', 'episode_reward': episode_reward, 'loss': loss, 'checkpoint_saved': checkpoint_saved, 'episode': self.dqn_agent.episode_count } except Exception as e: logger.error(f"Error training DQN agent: {e}") return {'status': 'error', 'error': str(e)} async def _train_cnn_model(self) -> Dict[str, Any]: """Train CNN model with automatic checkpointing""" try: if not self.cnn_trainer: return {'status': 'skipped', 'reason': 'no_trainer'} # Simulate CNN training step import torch import numpy as np batch_size = 32 input_size = 60 feature_dim = 50 # Generate synthetic training data x = torch.randn(batch_size, input_size, feature_dim) y = torch.randint(0, 3, (batch_size,)) # Training step results = self.cnn_trainer.train_step(x, y) # Simulate validation val_x = torch.randn(16, input_size, feature_dim) val_y = torch.randint(0, 3, (16,)) val_results = self.cnn_trainer.train_step(val_x, val_y) # Save checkpoint (automatic based on performance) checkpoint_saved = self.cnn_trainer.save_checkpoint( train_accuracy=results.get('accuracy', 0.5), val_accuracy=val_results.get('accuracy', 0.5), train_loss=results.get('total_loss', 1.0), val_loss=val_results.get('total_loss', 1.0) ) if checkpoint_saved: self.training_stats['checkpoints_saved'] += 1 return { 'status': 'completed', 'train_accuracy': results.get('accuracy', 0.5), 'val_accuracy': val_results.get('accuracy', 0.5), 'train_loss': results.get('total_loss', 1.0), 'val_loss': val_results.get('total_loss', 1.0), 'checkpoint_saved': checkpoint_saved, 'epoch': self.cnn_trainer.epoch_count } except Exception as e: logger.error(f"Error training CNN model: {e}") return {'status': 'error', 'error': str(e)} async def _train_extrema_detector(self) -> Dict[str, Any]: """Train extrema detector with automatic checkpointing""" try: if not self.extrema_trainer: return {'status': 'skipped', 'reason': 'no_trainer'} # Update context data and detect extrema update_results = self.extrema_trainer.update_context_data() # Get training data extrema_data = self.extrema_trainer.get_extrema_training_data(count=10) # Simulate training accuracy improvement if extrema_data: self.extrema_trainer.training_stats['total_extrema_detected'] += len(extrema_data) self.extrema_trainer.training_stats['successful_predictions'] += len(extrema_data) // 2 self.extrema_trainer.training_stats['failed_predictions'] += len(extrema_data) // 2 # Save checkpoint (automatic based on performance) checkpoint_saved = self.extrema_trainer.save_checkpoint() if checkpoint_saved: self.training_stats['checkpoints_saved'] += 1 return { 'status': 'completed', 'extrema_detected': len(extrema_data), 'context_updates': sum(1 for success in update_results.values() if success), 'checkpoint_saved': checkpoint_saved, 'session': self.extrema_trainer.training_session_count } except Exception as e: logger.error(f"Error training extrema detector: {e}") return {'status': 'error', 'error': str(e)} async def _process_negative_cases(self) -> Dict[str, Any]: """Process negative cases with automatic checkpointing""" try: if not self.negative_case_trainer: return {'status': 'skipped', 'reason': 'no_trainer'} # Simulate adding a negative case if np.random.random() < 0.1: # 10% chance of negative case trade_info = { 'symbol': 'ETH/USDT', 'action': 'BUY', 'price': 2000.0, 'pnl': -50.0, # Loss 'value': 1000.0, 'confidence': 0.7, 'timestamp': datetime.now() } market_data = { 'exit_price': 1950.0, 'state_before': {}, 'state_after': {}, 'tick_data': [], 'technical_indicators': {} } case_id = self.negative_case_trainer.add_losing_trade(trade_info, market_data) # Simulate loss improvement loss_improvement = np.random.random() * 0.1 # Save checkpoint (automatic based on performance) checkpoint_saved = self.negative_case_trainer.save_checkpoint(loss_improvement) if checkpoint_saved: self.training_stats['checkpoints_saved'] += 1 return { 'status': 'completed', 'case_added': case_id, 'loss_improvement': loss_improvement, 'checkpoint_saved': checkpoint_saved, 'session': self.negative_case_trainer.training_session_count } else: return {'status': 'no_cases'} except Exception as e: logger.error(f"Error processing negative cases: {e}") return {'status': 'error', 'error': str(e)} async def _coordinate_checkpoint_saving(self, dqn_results: Dict, cnn_results: Dict, extrema_results: Dict, negative_results: Dict): """Coordinate checkpoint saving across all components""" try: # Count successful checkpoints checkpoints_saved = sum([ dqn_results.get('checkpoint_saved', False), cnn_results.get('checkpoint_saved', False), extrema_results.get('checkpoint_saved', False), negative_results.get('checkpoint_saved', False) ]) if checkpoints_saved > 0: logger.info(f"Saved {checkpoints_saved} checkpoints this cycle") # Update best performances if 'episode_reward' in dqn_results: current_best = self.training_stats['best_performances'].get('dqn_reward', float('-inf')) if dqn_results['episode_reward'] > current_best: self.training_stats['best_performances']['dqn_reward'] = dqn_results['episode_reward'] if 'val_accuracy' in cnn_results: current_best = self.training_stats['best_performances'].get('cnn_accuracy', 0.0) if cnn_results['val_accuracy'] > current_best: self.training_stats['best_performances']['cnn_accuracy'] = cnn_results['val_accuracy'] # Log checkpoint statistics every 10 cycles if self.training_stats['total_training_sessions'] % 10 == 0: await self._log_checkpoint_statistics() except Exception as e: logger.error(f"Error coordinating checkpoint saving: {e}") async def _log_checkpoint_statistics(self): """Log comprehensive checkpoint statistics""" try: stats = get_checkpoint_stats() logger.info("=== Checkpoint Statistics ===") logger.info(f"Total checkpoints: {stats['total_checkpoints']}") logger.info(f"Total size: {stats['total_size_mb']:.2f} MB") logger.info(f"Models managed: {len(stats['models'])}") for model_name, model_stats in stats['models'].items(): logger.info(f" {model_name}: {model_stats['checkpoint_count']} checkpoints, " f"{model_stats['total_size_mb']:.2f} MB, " f"best: {model_stats['best_performance']:.4f}") logger.info(f"Training sessions: {self.training_stats['total_training_sessions']}") logger.info(f"Checkpoints saved: {self.training_stats['checkpoints_saved']}") logger.info(f"Best performances: {self.training_stats['best_performances']}") except Exception as e: logger.error(f"Error logging checkpoint statistics: {e}") async def shutdown(self): """Shutdown the training system and save final checkpoints""" logger.info("Shutting down checkpoint-integrated training system...") self.running = False try: # Force save checkpoints for all components if self.dqn_agent: self.dqn_agent.save_checkpoint(0.0, force_save=True) if self.cnn_trainer: self.cnn_trainer.save_checkpoint(0.0, 0.0, 0.0, 0.0, force_save=True) if self.extrema_trainer: self.extrema_trainer.save_checkpoint(force_save=True) if self.negative_case_trainer: self.negative_case_trainer.save_checkpoint(force_save=True) # Final statistics await self._log_checkpoint_statistics() logger.info("Checkpoint-integrated training system shutdown complete") except Exception as e: logger.error(f"Error during shutdown: {e}") async def main(): """Main function to run the checkpoint-integrated training system""" logger.info("🚀 Starting Checkpoint-Integrated Training System") # Create and initialize the training system training_system = CheckpointIntegratedTrainingSystem() # Setup signal handlers for graceful shutdown def signal_handler(signum, frame): logger.info("Received shutdown signal") asyncio.create_task(training_system.shutdown()) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: # Initialize components await training_system.initialize_components() # Run the integrated training loop await training_system.run_integrated_training_loop() except Exception as e: logger.error(f"Error in main: {e}") raise finally: await training_system.shutdown() logger.info("✅ Checkpoint management integration complete!") logger.info("All training pipelines now support automatic checkpointing") if __name__ == "__main__": # Ensure logs directory exists Path("logs").mkdir(exist_ok=True) # Run the checkpoint-integrated training system asyncio.run(main())