new overhaul
This commit is contained in:
370
enhanced_trading_main.py
Normal file
370
enhanced_trading_main.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
Enhanced Multi-Modal Trading System - Main Application
|
||||
|
||||
This is the main launcher for the sophisticated trading system featuring:
|
||||
1. Enhanced orchestrator coordinating CNN and RL modules
|
||||
2. Multi-timeframe, multi-symbol (ETH, BTC) trading decisions
|
||||
3. Perfect move marking for CNN training with known outcomes
|
||||
4. Continuous RL learning from trading action evaluations
|
||||
5. Market environment adaptation and coordinated decision making
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
import argparse
|
||||
|
||||
# Core components
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from models import get_model_registry
|
||||
|
||||
# Training components
|
||||
from training.enhanced_cnn_trainer import EnhancedCNNTrainer, EnhancedCNNModel
|
||||
from training.enhanced_rl_trainer import EnhancedRLTrainer, EnhancedDQNAgent
|
||||
|
||||
# Utilities
|
||||
import torch
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler('logs/enhanced_trading.log')
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EnhancedTradingSystem:
|
||||
"""Main enhanced trading system coordinator"""
|
||||
|
||||
def __init__(self, config_path: str = None):
|
||||
"""Initialize the enhanced trading system"""
|
||||
self.config = get_config(config_path)
|
||||
self.running = False
|
||||
|
||||
# Core components
|
||||
self.data_provider = DataProvider(self.config)
|
||||
self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
|
||||
self.model_registry = get_model_registry()
|
||||
|
||||
# Training components
|
||||
self.cnn_trainer = EnhancedCNNTrainer(self.config, self.orchestrator)
|
||||
self.rl_trainer = EnhancedRLTrainer(self.config, self.orchestrator)
|
||||
|
||||
# Models
|
||||
self.cnn_models = {}
|
||||
self.rl_agents = {}
|
||||
|
||||
# Performance tracking
|
||||
self.performance_metrics = {
|
||||
'decisions_made': 0,
|
||||
'perfect_moves_marked': 0,
|
||||
'rl_experiences_added': 0,
|
||||
'training_sessions': 0
|
||||
}
|
||||
|
||||
logger.info("Enhanced Trading System initialized")
|
||||
logger.info(f"Symbols: {self.config.symbols}")
|
||||
logger.info(f"Timeframes: {self.config.timeframes}")
|
||||
|
||||
async def initialize_models(self, load_existing: bool = True):
|
||||
"""Initialize and register all models"""
|
||||
logger.info("Initializing models...")
|
||||
|
||||
# Initialize CNN models
|
||||
if load_existing:
|
||||
# Try to load existing CNN model
|
||||
if self.cnn_trainer.load_model('best_model.pt'):
|
||||
logger.info("Loaded existing CNN model")
|
||||
self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()
|
||||
else:
|
||||
logger.info("No existing CNN model found, using fresh model")
|
||||
self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()
|
||||
else:
|
||||
logger.info("Creating fresh CNN model")
|
||||
self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()
|
||||
|
||||
# Initialize RL agents
|
||||
if load_existing:
|
||||
# Try to load existing RL agents
|
||||
if self.rl_trainer.load_models():
|
||||
logger.info("Loaded existing RL models")
|
||||
else:
|
||||
logger.info("No existing RL models found, using fresh agents")
|
||||
|
||||
self.rl_agents = self.rl_trainer.get_agents()
|
||||
|
||||
# Register models with the orchestrator
|
||||
for model_name, model in self.cnn_models.items():
|
||||
if self.model_registry.register_model(model):
|
||||
logger.info(f"Registered CNN model: {model_name}")
|
||||
|
||||
for symbol, agent in self.rl_agents.items():
|
||||
if self.model_registry.register_model(agent):
|
||||
logger.info(f"Registered RL agent for {symbol}")
|
||||
|
||||
# Display memory usage
|
||||
memory_stats = self.model_registry.get_memory_stats()
|
||||
logger.info(f"Total memory usage: {memory_stats['total_used_mb']:.1f}MB / "
|
||||
f"{memory_stats['total_limit_mb']:.1f}MB "
|
||||
f"({memory_stats['utilization_percent']:.1f}%)")
|
||||
|
||||
async def start_trading_loop(self):
|
||||
"""Start the main trading decision loop"""
|
||||
logger.info("Starting enhanced trading loop...")
|
||||
self.running = True
|
||||
|
||||
decision_count = 0
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Make coordinated decisions for all symbols
|
||||
decisions = await self.orchestrator.make_coordinated_decisions()
|
||||
|
||||
# Process decisions
|
||||
for symbol, decision in decisions.items():
|
||||
if decision:
|
||||
decision_count += 1
|
||||
self.performance_metrics['decisions_made'] += 1
|
||||
|
||||
logger.info(f"Trading Decision #{decision_count}")
|
||||
logger.info(f"Symbol: {symbol}")
|
||||
logger.info(f"Action: {decision.action}")
|
||||
logger.info(f"Confidence: {decision.confidence:.3f}")
|
||||
logger.info(f"Price: ${decision.price:.2f}")
|
||||
logger.info(f"Quantity: {decision.quantity:.6f}")
|
||||
|
||||
# Log timeframe analysis
|
||||
for tf_pred in decision.timeframe_analysis:
|
||||
logger.info(f" {tf_pred.timeframe}: {tf_pred.action} "
|
||||
f"(conf: {tf_pred.confidence:.3f})")
|
||||
|
||||
# Here you would integrate with actual trading execution
|
||||
# For now, we just log the decision
|
||||
|
||||
# Evaluate past actions with RL
|
||||
await self.orchestrator.evaluate_actions_with_rl()
|
||||
|
||||
# Check for perfect moves to mark
|
||||
perfect_moves = self.orchestrator.get_perfect_moves_for_training(limit=10)
|
||||
if perfect_moves:
|
||||
self.performance_metrics['perfect_moves_marked'] = len(perfect_moves)
|
||||
|
||||
# Log performance metrics every 10 decisions
|
||||
if decision_count % 10 == 0 and decision_count > 0:
|
||||
await self._log_performance_metrics()
|
||||
|
||||
# Wait before next decision cycle
|
||||
await asyncio.sleep(self.orchestrator.decision_frequency)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in trading loop: {e}")
|
||||
await asyncio.sleep(30) # Wait 30 seconds on error
|
||||
|
||||
async def start_training_loops(self):
|
||||
"""Start continuous training loops"""
|
||||
logger.info("Starting continuous training loops...")
|
||||
|
||||
# Start RL continuous learning
|
||||
rl_task = asyncio.create_task(self.rl_trainer.continuous_learning_loop())
|
||||
|
||||
# Start periodic CNN training
|
||||
cnn_task = asyncio.create_task(self._periodic_cnn_training())
|
||||
|
||||
return rl_task, cnn_task
|
||||
|
||||
async def _periodic_cnn_training(self):
|
||||
"""Periodic CNN training on accumulated perfect moves"""
|
||||
while self.running:
|
||||
try:
|
||||
# Wait for 6 hours between training sessions
|
||||
await asyncio.sleep(6 * 3600)
|
||||
|
||||
# Check if we have enough perfect moves for training
|
||||
perfect_moves = []
|
||||
for symbol in self.config.symbols:
|
||||
symbol_moves = self.orchestrator.get_perfect_moves_for_training(symbol=symbol)
|
||||
perfect_moves.extend(symbol_moves)
|
||||
|
||||
if len(perfect_moves) >= 200: # Minimum 200 perfect moves
|
||||
logger.info(f"Starting CNN training on {len(perfect_moves)} perfect moves")
|
||||
|
||||
# Train the CNN model
|
||||
training_report = self.cnn_trainer.train_on_perfect_moves(min_samples=200)
|
||||
|
||||
if training_report.get('training_completed'):
|
||||
self.performance_metrics['training_sessions'] += 1
|
||||
logger.info("CNN training completed successfully")
|
||||
logger.info(f"Final validation accuracy: "
|
||||
f"{training_report['final_metrics']['val_accuracy']:.4f}")
|
||||
|
||||
# Update the registered model
|
||||
updated_model = self.cnn_trainer.get_model()
|
||||
self.model_registry.unregister_model('enhanced_cnn')
|
||||
self.model_registry.register_model(updated_model)
|
||||
|
||||
else:
|
||||
logger.warning(f"CNN training failed: {training_report}")
|
||||
else:
|
||||
logger.info(f"Not enough perfect moves for training: {len(perfect_moves)} < 200")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in periodic CNN training: {e}")
|
||||
|
||||
async def _log_performance_metrics(self):
|
||||
"""Log system performance metrics"""
|
||||
logger.info("=== SYSTEM PERFORMANCE METRICS ===")
|
||||
logger.info(f"Decisions made: {self.performance_metrics['decisions_made']}")
|
||||
logger.info(f"Perfect moves marked: {self.performance_metrics['perfect_moves_marked']}")
|
||||
logger.info(f"Training sessions: {self.performance_metrics['training_sessions']}")
|
||||
|
||||
# Model registry stats
|
||||
memory_stats = self.model_registry.get_memory_stats()
|
||||
logger.info(f"Memory usage: {memory_stats['total_used_mb']:.1f}MB / "
|
||||
f"{memory_stats['total_limit_mb']:.1f}MB")
|
||||
|
||||
# RL performance
|
||||
rl_report = self.rl_trainer.get_performance_report()
|
||||
for symbol, agent_data in rl_report['agents'].items():
|
||||
logger.info(f"{symbol} RL: Epsilon={agent_data['epsilon']:.3f}, "
|
||||
f"Experiences={agent_data['experiences_stored']}, "
|
||||
f"Avg Reward={agent_data['avg_recent_reward']:.4f}")
|
||||
|
||||
# CNN model info
|
||||
for model_name, model in self.cnn_models.items():
|
||||
logger.info(f"{model_name}: Memory={model.get_memory_usage()}MB, "
|
||||
f"Device={model.device}")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Graceful shutdown of the system"""
|
||||
logger.info("Shutting down Enhanced Trading System...")
|
||||
self.running = False
|
||||
|
||||
# Save models
|
||||
logger.info("Saving models...")
|
||||
self.cnn_trainer._save_model('shutdown_model.pt')
|
||||
self.rl_trainer._save_all_models()
|
||||
|
||||
# Clean up memory
|
||||
self.model_registry.cleanup_all_models()
|
||||
|
||||
# Generate final reports
|
||||
logger.info("Generating final reports...")
|
||||
|
||||
# CNN training plots
|
||||
if self.cnn_trainer.training_history['train_loss']:
|
||||
self.cnn_trainer._plot_training_history()
|
||||
|
||||
# RL training plots
|
||||
self.rl_trainer.plot_training_metrics()
|
||||
|
||||
logger.info("Enhanced Trading System shutdown complete")
|
||||
|
||||
def setup_signal_handlers(trading_system: EnhancedTradingSystem):
|
||||
"""Setup signal handlers for graceful shutdown"""
|
||||
def signal_handler(signum, frame):
|
||||
logger.info(f"Received signal {signum}, initiating shutdown...")
|
||||
asyncio.create_task(trading_system.shutdown())
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
async def main():
|
||||
"""Main application entry point"""
|
||||
parser = argparse.ArgumentParser(description='Enhanced Multi-Modal Trading System')
|
||||
parser.add_argument('--config', type=str, help='Configuration file path')
|
||||
parser.add_argument('--mode', type=str, choices=['trade', 'train', 'backtest'],
|
||||
default='trade', help='Operation mode')
|
||||
parser.add_argument('--load-models', action='store_true', default=True,
|
||||
help='Load existing models')
|
||||
parser.add_argument('--no-load-models', action='store_false', dest='load_models',
|
||||
help="Don't load existing models")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create logs directory
|
||||
Path('logs').mkdir(exist_ok=True)
|
||||
|
||||
logger.info("=== ENHANCED MULTI-MODAL TRADING SYSTEM ===")
|
||||
logger.info(f"Mode: {args.mode}")
|
||||
logger.info(f"Load existing models: {args.load_models}")
|
||||
logger.info(f"PyTorch version: {torch.__version__}")
|
||||
logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
||||
|
||||
# Initialize trading system
|
||||
trading_system = EnhancedTradingSystem(args.config)
|
||||
|
||||
# Setup signal handlers
|
||||
setup_signal_handlers(trading_system)
|
||||
|
||||
try:
|
||||
# Initialize models
|
||||
await trading_system.initialize_models(load_existing=args.load_models)
|
||||
|
||||
if args.mode == 'trade':
|
||||
# Start training loops
|
||||
rl_task, cnn_task = await trading_system.start_training_loops()
|
||||
|
||||
# Start main trading loop
|
||||
trading_task = asyncio.create_task(trading_system.start_trading_loop())
|
||||
|
||||
# Wait for any task to complete (or error)
|
||||
done, pending = await asyncio.wait(
|
||||
[trading_task, rl_task, cnn_task],
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
# Cancel remaining tasks
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
elif args.mode == 'train':
|
||||
# Training-only mode
|
||||
logger.info("Running in training-only mode...")
|
||||
|
||||
# Train CNN if we have perfect moves
|
||||
perfect_moves = []
|
||||
for symbol in trading_system.config.symbols:
|
||||
symbol_moves = trading_system.orchestrator.get_perfect_moves_for_training(symbol=symbol)
|
||||
perfect_moves.extend(symbol_moves)
|
||||
|
||||
if len(perfect_moves) >= 100:
|
||||
logger.info(f"Training CNN on {len(perfect_moves)} perfect moves")
|
||||
training_report = trading_system.cnn_trainer.train_on_perfect_moves(min_samples=100)
|
||||
logger.info(f"CNN training report: {training_report}")
|
||||
else:
|
||||
logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)}")
|
||||
|
||||
# Train RL agents if they have experiences
|
||||
await trading_system.rl_trainer._train_all_agents()
|
||||
|
||||
elif args.mode == 'backtest':
|
||||
# Backtesting mode
|
||||
logger.info("Backtesting mode not implemented yet")
|
||||
return
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received keyboard interrupt")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}", exc_info=True)
|
||||
finally:
|
||||
await trading_system.shutdown()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the main application
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Application terminated by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
Reference in New Issue
Block a user