# File: gogo2/enhanced_trading_main.py
# Last commit: 2f50ed920f "new overhaul" by Dobromir Popov (2025-05-24 11:00:40 +03:00)
# Original listing: 370 lines, 15 KiB, Python

"""
Enhanced Multi-Modal Trading System - Main Application
This is the main launcher for the sophisticated trading system featuring:
1. Enhanced orchestrator coordinating CNN and RL modules
2. Multi-timeframe, multi-symbol (ETH, BTC) trading decisions
3. Perfect move marking for CNN training with known outcomes
4. Continuous RL learning from trading action evaluations
5. Market environment adaptation and coordinated decision making
"""
import asyncio
import logging
import signal
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional
import argparse
# Core components
from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from models import get_model_registry
# Training components
from training.enhanced_cnn_trainer import EnhancedCNNTrainer, EnhancedCNNModel
from training.enhanced_rl_trainer import EnhancedRLTrainer, EnhancedDQNAgent
# Utilities
import torch
# Setup logging. The log directory must exist before logging.FileHandler is
# constructed: the handler opens its file immediately, so a missing 'logs/'
# directory would make importing this module fail with FileNotFoundError.
Path('logs').mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('logs/enhanced_trading.log'),
    ],
)
logger = logging.getLogger(__name__)
class EnhancedTradingSystem:
    """Main enhanced trading system coordinator.

    Owns the data provider, the enhanced orchestrator, the model registry and
    the CNN/RL trainers. Exposes the async loops that make trading decisions
    and keep the models training, plus a graceful ``shutdown``.
    """

    # Seconds between periodic CNN training attempts (6 hours).
    CNN_TRAINING_INTERVAL_SECONDS = 6 * 3600
    # Minimum number of collected perfect moves before CNN training is
    # attempted; also passed to the trainer as ``min_samples``.
    CNN_MIN_PERFECT_MOVES = 200
    # Performance metrics are logged each time the decision count crosses a
    # multiple of this value.
    METRICS_LOG_INTERVAL = 10
    # Seconds to back off after an unexpected error in the trading loop.
    TRADING_ERROR_BACKOFF_SECONDS = 30

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the enhanced trading system.

        Args:
            config_path: Optional configuration file path, forwarded to
                ``get_config``.
        """
        self.config = get_config(config_path)
        # True while the trading/training loops should keep running.
        self.running = False
        # Guards ``shutdown`` so its save/cleanup work runs only once.
        self._shutdown_done = False

        # Core components
        self.data_provider = DataProvider(self.config)
        self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
        self.model_registry = get_model_registry()

        # Training components
        self.cnn_trainer = EnhancedCNNTrainer(self.config, self.orchestrator)
        self.rl_trainer = EnhancedRLTrainer(self.config, self.orchestrator)

        # Models: CNN models keyed by name, RL agents keyed by symbol
        self.cnn_models = {}
        self.rl_agents = {}

        # Performance tracking
        self.performance_metrics = {
            'decisions_made': 0,
            'perfect_moves_marked': 0,
            # NOTE(review): not updated anywhere in this file.
            'rl_experiences_added': 0,
            'training_sessions': 0,
        }

        logger.info("Enhanced Trading System initialized")
        logger.info(f"Symbols: {self.config.symbols}")
        logger.info(f"Timeframes: {self.config.timeframes}")

    async def initialize_models(self, load_existing: bool = True):
        """Initialize the CNN model and RL agents and register them.

        Args:
            load_existing: When True, try to restore the CNN model from
                'best_model.pt' and the RL agents via
                ``rl_trainer.load_models()``. Whether or not that succeeds,
                the trainers' current model/agents are used afterwards.
        """
        logger.info("Initializing models...")

        # CNN model: the trainer holds the model to use in every case.
        if not load_existing:
            logger.info("Creating fresh CNN model")
        elif self.cnn_trainer.load_model('best_model.pt'):
            logger.info("Loaded existing CNN model")
        else:
            logger.info("No existing CNN model found, using fresh model")
        self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()

        # RL agents
        if load_existing:
            if self.rl_trainer.load_models():
                logger.info("Loaded existing RL models")
            else:
                logger.info("No existing RL models found, using fresh agents")
        self.rl_agents = self.rl_trainer.get_agents()

        # Register models with the model registry
        for model_name, model in self.cnn_models.items():
            if self.model_registry.register_model(model):
                logger.info(f"Registered CNN model: {model_name}")
        for symbol, agent in self.rl_agents.items():
            if self.model_registry.register_model(agent):
                logger.info(f"Registered RL agent for {symbol}")

        # Display memory usage
        memory_stats = self.model_registry.get_memory_stats()
        logger.info(f"Total memory usage: {memory_stats['total_used_mb']:.1f}MB / "
                    f"{memory_stats['total_limit_mb']:.1f}MB "
                    f"({memory_stats['utilization_percent']:.1f}%)")

    def _log_decision(self, decision_number: int, symbol: str, decision) -> None:
        """Log one trading decision and its per-timeframe breakdown."""
        logger.info(f"Trading Decision #{decision_number}")
        logger.info(f"Symbol: {symbol}")
        logger.info(f"Action: {decision.action}")
        logger.info(f"Confidence: {decision.confidence:.3f}")
        logger.info(f"Price: ${decision.price:.2f}")
        logger.info(f"Quantity: {decision.quantity:.6f}")
        for tf_pred in decision.timeframe_analysis:
            logger.info(f" {tf_pred.timeframe}: {tf_pred.action} "
                        f"(conf: {tf_pred.confidence:.3f})")

    async def start_trading_loop(self):
        """Run the main trading decision loop until ``self.running`` is False.

        Each cycle does the following:

        * asks the orchestrator for coordinated decisions and logs every
          non-empty one;
        * lets the orchestrator evaluate past actions with RL;
        * records the latest perfect-move batch size;
        * sleeps ``orchestrator.decision_frequency`` seconds.

        Unexpected errors are logged with their traceback, followed by a
        back-off sleep.
        """
        logger.info("Starting enhanced trading loop...")
        self.running = True
        decision_count = 0
        # Number of completed METRICS_LOG_INTERVAL buckets already reported.
        # Comparing buckets (instead of ``count % interval == 0``) logs metrics
        # once per interval. It does not repeat every cycle while the count
        # sits on a multiple, and it does not skip a multiple when a single
        # cycle adds several decisions.
        reported_buckets = 0

        while self.running:
            try:
                # Make coordinated decisions for all symbols
                decisions = await self.orchestrator.make_coordinated_decisions()

                for symbol, decision in decisions.items():
                    if not decision:
                        continue
                    decision_count += 1
                    self.performance_metrics['decisions_made'] += 1
                    self._log_decision(decision_count, symbol, decision)
                    # Here you would integrate with actual trading execution;
                    # for now the decision is only logged.

                # Evaluate past actions with RL
                await self.orchestrator.evaluate_actions_with_rl()

                # Stores the size of the latest batch requested with limit=10,
                # not a running total.
                perfect_moves = self.orchestrator.get_perfect_moves_for_training(limit=10)
                if perfect_moves:
                    self.performance_metrics['perfect_moves_marked'] = len(perfect_moves)

                buckets = decision_count // self.METRICS_LOG_INTERVAL
                if buckets > reported_buckets:
                    reported_buckets = buckets
                    await self._log_performance_metrics()

                # Wait before next decision cycle
                await asyncio.sleep(self.orchestrator.decision_frequency)
            except Exception as e:
                logger.error(f"Error in trading loop: {e}", exc_info=True)
                await asyncio.sleep(self.TRADING_ERROR_BACKOFF_SECONDS)

    async def start_training_loops(self):
        """Start the continuous RL and periodic CNN training loops as tasks.

        ``self.running`` is set to True *before* the tasks are created.
        ``_periodic_cnn_training`` loops ``while self.running``. If the flag
        were still False when that task first runs (for example because the
        trading loop, which also sets it, has not started yet), the task would
        finish immediately.

        Returns:
            Tuple ``(rl_task, cnn_task)`` of the created asyncio tasks.
        """
        logger.info("Starting continuous training loops...")
        self.running = True

        # Start RL continuous learning
        rl_task = asyncio.create_task(self.rl_trainer.continuous_learning_loop())
        # Start periodic CNN training
        cnn_task = asyncio.create_task(self._periodic_cnn_training())
        return rl_task, cnn_task

    async def _periodic_cnn_training(self):
        """Periodically train the CNN on accumulated perfect moves.

        Every ``CNN_TRAINING_INTERVAL_SECONDS`` this gathers perfect moves for
        all configured symbols. When at least ``CNN_MIN_PERFECT_MOVES`` are
        available, it retrains the CNN and re-registers the updated model.
        Errors are logged with their traceback and the loop continues.
        """
        min_moves = self.CNN_MIN_PERFECT_MOVES
        while self.running:
            try:
                await asyncio.sleep(self.CNN_TRAINING_INTERVAL_SECONDS)
                # The system may have been shut down during the long sleep.
                if not self.running:
                    break

                perfect_moves = []
                for symbol in self.config.symbols:
                    perfect_moves.extend(
                        self.orchestrator.get_perfect_moves_for_training(symbol=symbol))

                if len(perfect_moves) < min_moves:
                    logger.info(f"Not enough perfect moves for training: "
                                f"{len(perfect_moves)} < {min_moves}")
                    continue

                logger.info(f"Starting CNN training on {len(perfect_moves)} perfect moves")
                training_report = self.cnn_trainer.train_on_perfect_moves(min_samples=min_moves)
                if not training_report.get('training_completed'):
                    logger.warning(f"CNN training failed: {training_report}")
                    continue

                self.performance_metrics['training_sessions'] += 1
                logger.info("CNN training completed successfully")
                # A missing metric must not prevent the updated model from
                # being registered below.
                final_metrics = training_report.get('final_metrics') or {}
                val_accuracy = final_metrics.get('val_accuracy')
                if val_accuracy is not None:
                    logger.info(f"Final validation accuracy: {val_accuracy:.4f}")

                # Update the registered model
                updated_model = self.cnn_trainer.get_model()
                self.model_registry.unregister_model('enhanced_cnn')
                self.model_registry.register_model(updated_model)
                # Keep the local mapping in sync so metrics report the new model.
                self.cnn_models['enhanced_cnn'] = updated_model
            except Exception as e:
                logger.error(f"Error in periodic CNN training: {e}", exc_info=True)

    async def _log_performance_metrics(self):
        """Log the counters, registry memory usage, per-symbol RL stats and CNN model info."""
        logger.info("=== SYSTEM PERFORMANCE METRICS ===")
        logger.info(f"Decisions made: {self.performance_metrics['decisions_made']}")
        logger.info(f"Perfect moves marked: {self.performance_metrics['perfect_moves_marked']}")
        logger.info(f"Training sessions: {self.performance_metrics['training_sessions']}")

        # Model registry stats
        memory_stats = self.model_registry.get_memory_stats()
        logger.info(f"Memory usage: {memory_stats['total_used_mb']:.1f}MB / "
                    f"{memory_stats['total_limit_mb']:.1f}MB")

        # RL performance
        rl_report = self.rl_trainer.get_performance_report()
        for symbol, agent_data in rl_report['agents'].items():
            logger.info(f"{symbol} RL: Epsilon={agent_data['epsilon']:.3f}, "
                        f"Experiences={agent_data['experiences_stored']}, "
                        f"Avg Reward={agent_data['avg_recent_reward']:.4f}")

        # CNN model info
        for model_name, model in self.cnn_models.items():
            logger.info(f"{model_name}: Memory={model.get_memory_usage()}MB, "
                        f"Device={model.device}")

    @staticmethod
    def _shutdown_step(description: str, action) -> None:
        """Run one shutdown action, logging (not raising) any exception."""
        try:
            action()
        except Exception:
            logger.exception(f"Error during shutdown while {description}")

    def _plot_cnn_history_if_any(self) -> None:
        """Plot the CNN training history, but only if any training loss was recorded."""
        if self.cnn_trainer.training_history['train_loss']:
            self.cnn_trainer._plot_training_history()

    async def shutdown(self):
        """Gracefully shut down the system; safe to call more than once.

        The steps are:

        1. Stop the loops (``self.running = False``).
        2. Save the CNN ('shutdown_model.pt') and RL models.
        3. Call ``model_registry.cleanup_all_models()``.
        4. Write the final training plots.

        Each step is attempted even if an earlier one fails. Failures are
        logged with their traceback instead of aborting the remaining steps.
        Calls after the first one return immediately.
        """
        if self._shutdown_done:
            return
        self._shutdown_done = True

        logger.info("Shutting down Enhanced Trading System...")
        self.running = False

        # Save models. Lambdas defer attribute lookups into the guarded call.
        logger.info("Saving models...")
        self._shutdown_step("saving CNN model",
                            lambda: self.cnn_trainer._save_model('shutdown_model.pt'))
        self._shutdown_step("saving RL models",
                            lambda: self.rl_trainer._save_all_models())

        # Clean up memory
        self._shutdown_step("cleaning up models",
                            lambda: self.model_registry.cleanup_all_models())

        # Generate final reports
        logger.info("Generating final reports...")
        self._shutdown_step("plotting CNN training history", self._plot_cnn_history_if_any)
        self._shutdown_step("plotting RL training metrics",
                            lambda: self.rl_trainer.plot_training_metrics())

        logger.info("Enhanced Trading System shutdown complete")
def setup_signal_handlers(trading_system: EnhancedTradingSystem):
    """Install SIGINT/SIGTERM handlers that stop the system and exit.

    The handler does two things:

    * sets ``trading_system.running`` to False, so the loops stop scheduling
      new work;
    * raises ``SystemExit`` via ``sys.exit(0)``.

    Saving and cleanup are left to the ``await trading_system.shutdown()`` in
    ``main``'s ``finally`` clause, which runs as the exit unwinds.

    The handler deliberately does not schedule ``shutdown()`` with
    ``asyncio.create_task``. That call raises RuntimeError when no event loop
    is running. Otherwise the task races the ``finally`` call, and
    ``asyncio.run`` typically cancels it on exit before it ever starts.
    """
    def signal_handler(signum, frame):
        logger.info(f"Received signal {signum}, initiating shutdown...")
        trading_system.running = False
        sys.exit(0)

    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, signal_handler)
async def _run_trade_mode(trading_system):
    """Run the trading loop and both training loops until the first one ends.

    Once one task finishes, its outcome is logged, so a crashed loop is
    reported instead of surfacing later as "Task exception was never
    retrieved". The remaining tasks are then cancelled and awaited, so none
    outlives the caller's shutdown.
    """
    # Start training loops
    rl_task, cnn_task = await trading_system.start_training_loops()
    # Start main trading loop
    trading_task = asyncio.create_task(trading_system.start_trading_loop())
    task_names = {
        trading_task: 'Trading loop',
        rl_task: 'RL training loop',
        cnn_task: 'CNN training loop',
    }

    # Wait for any task to complete (or error)
    done, pending = await asyncio.wait(list(task_names),
                                       return_when=asyncio.FIRST_COMPLETED)

    for task in done:
        name = task_names[task]
        if task.cancelled():
            logger.warning(f"{name} was cancelled")
        elif task.exception() is not None:
            exc = task.exception()
            logger.error(f"{name} failed: {exc}", exc_info=exc)
        else:
            logger.info(f"{name} finished")

    # Cancel remaining tasks and wait until they have actually stopped.
    for task in pending:
        task.cancel()
    await asyncio.gather(*pending, return_exceptions=True)


async def _run_train_mode(trading_system):
    """Training-only mode: train the CNN if enough perfect moves exist, then the RL agents."""
    logger.info("Running in training-only mode...")

    # Train CNN if we have perfect moves
    perfect_moves = []
    for symbol in trading_system.config.symbols:
        perfect_moves.extend(
            trading_system.orchestrator.get_perfect_moves_for_training(symbol=symbol))

    if len(perfect_moves) >= 100:
        logger.info(f"Training CNN on {len(perfect_moves)} perfect moves")
        training_report = trading_system.cnn_trainer.train_on_perfect_moves(min_samples=100)
        logger.info(f"CNN training report: {training_report}")
    else:
        logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)}")

    # Train RL agents if they have experiences
    await trading_system.rl_trainer._train_all_agents()


async def main():
    """Main application entry point.

    Parses the command line, builds an ``EnhancedTradingSystem``, installs the
    signal handlers, initializes the models and runs the selected mode:

    * ``trade``    - run the trading loop alongside the RL/CNN training loops
                     until one of them finishes or fails;
    * ``train``    - one-off CNN training (if enough perfect moves exist)
                     followed by an RL training pass;
    * ``backtest`` - not implemented yet.

    ``trading_system.shutdown()`` runs in a ``finally`` clause once model
    initialization has been attempted, whatever the outcome.
    """
    parser = argparse.ArgumentParser(description='Enhanced Multi-Modal Trading System')
    parser.add_argument('--config', type=str, help='Configuration file path')
    parser.add_argument('--mode', type=str, choices=['trade', 'train', 'backtest'],
                        default='trade', help='Operation mode')
    # Loading is the default. '--load-models' is accepted for explicitness;
    # '--no-load-models' turns loading off.
    parser.add_argument('--load-models', action='store_true', default=True,
                        help='Load existing models')
    parser.add_argument('--no-load-models', action='store_false', dest='load_models',
                        help="Don't load existing models")
    args = parser.parse_args()

    # Create logs directory
    Path('logs').mkdir(exist_ok=True)

    logger.info("=== ENHANCED MULTI-MODAL TRADING SYSTEM ===")
    logger.info(f"Mode: {args.mode}")
    logger.info(f"Load existing models: {args.load_models}")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")

    # Initialize trading system
    trading_system = EnhancedTradingSystem(args.config)

    # Setup signal handlers
    setup_signal_handlers(trading_system)

    try:
        # Initialize models
        await trading_system.initialize_models(load_existing=args.load_models)

        if args.mode == 'trade':
            await _run_trade_mode(trading_system)
        elif args.mode == 'train':
            await _run_train_mode(trading_system)
        elif args.mode == 'backtest':
            # Backtesting mode
            logger.info("Backtesting mode not implemented yet")
            return
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt")
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
    finally:
        await trading_system.shutdown()
if __name__ == "__main__":
    # Script entry point: drive the async application to completion and turn
    # an unexpected failure into a non-zero exit status.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C that reached the top level: exit quietly.
        logger.info("Application terminated by user")
    except Exception as fatal_error:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(f"Fatal error: {fatal_error}")
        sys.exit(1)