""" Enhanced Multi-Modal Trading System - Main Application This is the main launcher for the sophisticated trading system featuring: 1. Enhanced orchestrator coordinating CNN and RL modules 2. Multi-timeframe, multi-symbol (ETH, BTC) trading decisions 3. Perfect move marking for CNN training with known outcomes 4. Continuous RL learning from trading action evaluations 5. Market environment adaptation and coordinated decision making """ import asyncio import logging import signal import sys from datetime import datetime from pathlib import Path from typing import Dict, Any, Optional import argparse # Core components from core.config import get_config from core.data_provider import DataProvider from core.enhanced_orchestrator import EnhancedTradingOrchestrator from models import get_model_registry # Training components from training.enhanced_cnn_trainer import EnhancedCNNTrainer, EnhancedCNNModel from training.enhanced_rl_trainer import EnhancedRLTrainer, EnhancedDQNAgent # Utilities import torch # Setup logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.StreamHandler(), logging.FileHandler('logs/enhanced_trading.log') ] ) logger = logging.getLogger(__name__) class EnhancedTradingSystem: """Main enhanced trading system coordinator""" def __init__(self, config_path: str = None): """Initialize the enhanced trading system""" self.config = get_config(config_path) self.running = False # Core components self.data_provider = DataProvider(self.config) self.orchestrator = EnhancedTradingOrchestrator(self.data_provider) self.model_registry = get_model_registry() # Training components self.cnn_trainer = EnhancedCNNTrainer(self.config, self.orchestrator) self.rl_trainer = EnhancedRLTrainer(self.config, self.orchestrator) # Models self.cnn_models = {} self.rl_agents = {} # Performance tracking self.performance_metrics = { 'decisions_made': 0, 'perfect_moves_marked': 0, 'rl_experiences_added': 0, 'training_sessions': 0 } logger.info("Enhanced Trading System initialized") logger.info(f"Symbols: {self.config.symbols}") logger.info(f"Timeframes: {self.config.timeframes}") async def initialize_models(self, load_existing: bool = True): """Initialize and register all models""" logger.info("Initializing models...") # Initialize CNN models if load_existing: # Try to load existing CNN model if self.cnn_trainer.load_model('best_model.pt'): logger.info("Loaded existing CNN model") self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model() else: logger.info("No existing CNN model found, using fresh model") self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model() else: logger.info("Creating fresh CNN model") self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model() # Initialize RL agents if load_existing: # Try to load existing RL agents if self.rl_trainer.load_models(): logger.info("Loaded existing RL models") else: logger.info("No existing RL models found, using fresh agents") self.rl_agents = self.rl_trainer.get_agents() # Register models with the orchestrator for model_name, model in self.cnn_models.items(): if self.model_registry.register_model(model): logger.info(f"Registered CNN model: {model_name}") for symbol, agent in self.rl_agents.items(): if self.model_registry.register_model(agent): logger.info(f"Registered RL agent for {symbol}") # Display memory usage memory_stats = self.model_registry.get_memory_stats() logger.info(f"Total memory usage: {memory_stats['total_used_mb']:.1f}MB / " f"{memory_stats['total_limit_mb']:.1f}MB " f"({memory_stats['utilization_percent']:.1f}%)") async def start_trading_loop(self): """Start the main trading decision loop""" logger.info("Starting enhanced trading loop...") self.running = True decision_count = 0 while self.running: try: # Make coordinated decisions for all symbols decisions = await self.orchestrator.make_coordinated_decisions() # Process decisions for symbol, decision in decisions.items(): if decision: decision_count += 1 self.performance_metrics['decisions_made'] += 1 logger.info(f"Trading Decision #{decision_count}") logger.info(f"Symbol: {symbol}") logger.info(f"Action: {decision.action}") logger.info(f"Confidence: {decision.confidence:.3f}") logger.info(f"Price: ${decision.price:.2f}") logger.info(f"Quantity: {decision.quantity:.6f}") # Log timeframe analysis for tf_pred in decision.timeframe_analysis: logger.info(f" {tf_pred.timeframe}: {tf_pred.action} " f"(conf: {tf_pred.confidence:.3f})") # Here you would integrate with actual trading execution # For now, we just log the decision # Evaluate past actions with RL await self.orchestrator.evaluate_actions_with_rl() # Check for perfect moves to mark perfect_moves = self.orchestrator.get_perfect_moves_for_training(limit=10) if perfect_moves: self.performance_metrics['perfect_moves_marked'] = len(perfect_moves) # Log performance metrics every 10 decisions if decision_count % 10 == 0 and decision_count > 0: await self._log_performance_metrics() # Wait before next decision cycle await asyncio.sleep(self.orchestrator.decision_frequency) except Exception as e: logger.error(f"Error in trading loop: {e}") await asyncio.sleep(30) # Wait 30 seconds on error async def start_training_loops(self): """Start continuous training loops""" logger.info("Starting continuous training loops...") # Start RL continuous learning rl_task = asyncio.create_task(self.rl_trainer.continuous_learning_loop()) # Start periodic CNN training cnn_task = asyncio.create_task(self._periodic_cnn_training()) return rl_task, cnn_task async def _periodic_cnn_training(self): """Periodic CNN training on accumulated perfect moves""" while self.running: try: # Wait for 6 hours between training sessions await asyncio.sleep(6 * 3600) # Check if we have enough perfect moves for training perfect_moves = [] for symbol in self.config.symbols: symbol_moves = self.orchestrator.get_perfect_moves_for_training(symbol=symbol) perfect_moves.extend(symbol_moves) if len(perfect_moves) >= 200: # Minimum 200 perfect moves logger.info(f"Starting CNN training on {len(perfect_moves)} perfect moves") # Train the CNN model training_report = self.cnn_trainer.train_on_perfect_moves(min_samples=200) if training_report.get('training_completed'): self.performance_metrics['training_sessions'] += 1 logger.info("CNN training completed successfully") logger.info(f"Final validation accuracy: " f"{training_report['final_metrics']['val_accuracy']:.4f}") # Update the registered model updated_model = self.cnn_trainer.get_model() self.model_registry.unregister_model('enhanced_cnn') self.model_registry.register_model(updated_model) else: logger.warning(f"CNN training failed: {training_report}") else: logger.info(f"Not enough perfect moves for training: {len(perfect_moves)} < 200") except Exception as e: logger.error(f"Error in periodic CNN training: {e}") async def _log_performance_metrics(self): """Log system performance metrics""" logger.info("=== SYSTEM PERFORMANCE METRICS ===") logger.info(f"Decisions made: {self.performance_metrics['decisions_made']}") logger.info(f"Perfect moves marked: {self.performance_metrics['perfect_moves_marked']}") logger.info(f"Training sessions: {self.performance_metrics['training_sessions']}") # Model registry stats memory_stats = self.model_registry.get_memory_stats() logger.info(f"Memory usage: {memory_stats['total_used_mb']:.1f}MB / " f"{memory_stats['total_limit_mb']:.1f}MB") # RL performance rl_report = self.rl_trainer.get_performance_report() for symbol, agent_data in rl_report['agents'].items(): logger.info(f"{symbol} RL: Epsilon={agent_data['epsilon']:.3f}, " f"Experiences={agent_data['experiences_stored']}, " f"Avg Reward={agent_data['avg_recent_reward']:.4f}") # CNN model info for model_name, model in self.cnn_models.items(): logger.info(f"{model_name}: Memory={model.get_memory_usage()}MB, " f"Device={model.device}") async def shutdown(self): """Graceful shutdown of the system""" logger.info("Shutting down Enhanced Trading System...") self.running = False # Save models logger.info("Saving models...") self.cnn_trainer._save_model('shutdown_model.pt') self.rl_trainer._save_all_models() # Clean up memory self.model_registry.cleanup_all_models() # Generate final reports logger.info("Generating final reports...") # CNN training plots if self.cnn_trainer.training_history['train_loss']: self.cnn_trainer._plot_training_history() # RL training plots self.rl_trainer.plot_training_metrics() logger.info("Enhanced Trading System shutdown complete") def setup_signal_handlers(trading_system: EnhancedTradingSystem): """Setup signal handlers for graceful shutdown""" def signal_handler(signum, frame): logger.info(f"Received signal {signum}, initiating shutdown...") asyncio.create_task(trading_system.shutdown()) sys.exit(0) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) async def main(): """Main application entry point""" parser = argparse.ArgumentParser(description='Enhanced Multi-Modal Trading System') parser.add_argument('--config', type=str, help='Configuration file path') parser.add_argument('--mode', type=str, choices=['trade', 'train', 'backtest'], default='trade', help='Operation mode') parser.add_argument('--load-models', action='store_true', default=True, help='Load existing models') parser.add_argument('--no-load-models', action='store_false', dest='load_models', help="Don't load existing models") args = parser.parse_args() # Create logs directory Path('logs').mkdir(exist_ok=True) logger.info("=== ENHANCED MULTI-MODAL TRADING SYSTEM ===") logger.info(f"Mode: {args.mode}") logger.info(f"Load existing models: {args.load_models}") logger.info(f"PyTorch version: {torch.__version__}") logger.info(f"CUDA available: {torch.cuda.is_available()}") # Initialize trading system trading_system = EnhancedTradingSystem(args.config) # Setup signal handlers setup_signal_handlers(trading_system) try: # Initialize models await trading_system.initialize_models(load_existing=args.load_models) if args.mode == 'trade': # Start training loops rl_task, cnn_task = await trading_system.start_training_loops() # Start main trading loop trading_task = asyncio.create_task(trading_system.start_trading_loop()) # Wait for any task to complete (or error) done, pending = await asyncio.wait( [trading_task, rl_task, cnn_task], return_when=asyncio.FIRST_COMPLETED ) # Cancel remaining tasks for task in pending: task.cancel() elif args.mode == 'train': # Training-only mode logger.info("Running in training-only mode...") # Train CNN if we have perfect moves perfect_moves = [] for symbol in trading_system.config.symbols: symbol_moves = trading_system.orchestrator.get_perfect_moves_for_training(symbol=symbol) perfect_moves.extend(symbol_moves) if len(perfect_moves) >= 100: logger.info(f"Training CNN on {len(perfect_moves)} perfect moves") training_report = trading_system.cnn_trainer.train_on_perfect_moves(min_samples=100) logger.info(f"CNN training report: {training_report}") else: logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)}") # Train RL agents if they have experiences await trading_system.rl_trainer._train_all_agents() elif args.mode == 'backtest': # Backtesting mode logger.info("Backtesting mode not implemented yet") return except KeyboardInterrupt: logger.info("Received keyboard interrupt") except Exception as e: logger.error(f"Unexpected error: {e}", exc_info=True) finally: await trading_system.shutdown() if __name__ == "__main__": # Run the main application try: asyncio.run(main()) except KeyboardInterrupt: logger.info("Application terminated by user") except Exception as e: logger.error(f"Fatal error: {e}", exc_info=True) sys.exit(1)