370 lines
15 KiB
Python
370 lines
15 KiB
Python
"""
|
|
Enhanced Multi-Modal Trading System - Main Application

This is the main launcher for the sophisticated trading system featuring:
1. Enhanced orchestrator coordinating CNN and RL modules
2. Multi-timeframe, multi-symbol (ETH, BTC) trading decisions
3. Perfect move marking for CNN training with known outcomes
4. Continuous RL learning from trading action evaluations
5. Market environment adaptation and coordinated decision making
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import signal
|
|
import sys
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, Any, Optional
|
|
import argparse
|
|
|
|
# Core components
|
|
from core.config import get_config
|
|
from core.data_provider import DataProvider
|
|
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
|
from models import get_model_registry
|
|
|
|
# Training components
|
|
from training.enhanced_cnn_trainer import EnhancedCNNTrainer, EnhancedCNNModel
|
|
from training.enhanced_rl_trainer import EnhancedRLTrainer, EnhancedDQNAgent
|
|
|
|
# Utilities
|
|
import torch
|
|
|
|
# Setup logging
|
|
# The file handler below opens 'logs/enhanced_trading.log' as soon as it is
# constructed, so the directory has to exist before logging is configured;
# otherwise importing this module fails with FileNotFoundError.
Path('logs').mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('logs/enhanced_trading.log'),
    ],
)
logger = logging.getLogger(__name__)
|
|
|
|
class EnhancedTradingSystem:
    """Main enhanced trading system coordinator.

    Owns the data provider, the orchestrator, the model registry and the
    CNN/RL trainers, and runs the trading decision loop together with the
    background training loops.
    """

    # Seconds to wait between periodic CNN training attempts (6 hours).
    CNN_TRAINING_INTERVAL_SECONDS = 6 * 3600
    # Minimum number of perfect moves required before periodic CNN training.
    MIN_PERFECT_MOVES_FOR_TRAINING = 200
    # Performance metrics are logged each time this many decisions accumulate.
    METRICS_LOG_EVERY_N_DECISIONS = 10
    # Seconds to back off after an unexpected error in the trading loop.
    TRADING_LOOP_ERROR_DELAY_SECONDS = 30

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the enhanced trading system.

        Args:
            config_path: Value forwarded to ``get_config``; ``None`` is
                passed through unchanged.
        """
        self.config = get_config(config_path)
        self.running = False

        # Core components
        self.data_provider = DataProvider(self.config)
        self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
        self.model_registry = get_model_registry()

        # Training components
        self.cnn_trainer = EnhancedCNNTrainer(self.config, self.orchestrator)
        self.rl_trainer = EnhancedRLTrainer(self.config, self.orchestrator)

        # Models (filled in by initialize_models)
        self.cnn_models = {}
        self.rl_agents = {}

        # Performance tracking
        self.performance_metrics = {
            'decisions_made': 0,
            'perfect_moves_marked': 0,
            'rl_experiences_added': 0,
            'training_sessions': 0,
        }

        logger.info("Enhanced Trading System initialized")
        logger.info(f"Symbols: {self.config.symbols}")
        logger.info(f"Timeframes: {self.config.timeframes}")

    async def initialize_models(self, load_existing: bool = True):
        """Initialize the CNN model and RL agents and register them.

        Args:
            load_existing: When true, try to load saved CNN/RL weights first;
                whatever the trainers hold afterwards is used either way.
        """
        logger.info("Initializing models...")

        # CNN model: every path ends up using the trainer's current model;
        # only the log message differs.
        if not load_existing:
            logger.info("Creating fresh CNN model")
        elif self.cnn_trainer.load_model('best_model.pt'):
            logger.info("Loaded existing CNN model")
        else:
            logger.info("No existing CNN model found, using fresh model")
        self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()

        # RL agents
        if load_existing:
            if self.rl_trainer.load_models():
                logger.info("Loaded existing RL models")
            else:
                logger.info("No existing RL models found, using fresh agents")
        self.rl_agents = self.rl_trainer.get_agents()

        # Register models with the registry; a falsy return is treated as a
        # failed registration and reported instead of being silently ignored.
        for model_name, model in self.cnn_models.items():
            if self.model_registry.register_model(model):
                logger.info(f"Registered CNN model: {model_name}")
            else:
                logger.warning(f"Failed to register CNN model: {model_name}")

        for symbol, agent in self.rl_agents.items():
            if self.model_registry.register_model(agent):
                logger.info(f"Registered RL agent for {symbol}")
            else:
                logger.warning(f"Failed to register RL agent for {symbol}")

        # Display memory usage
        memory_stats = self.model_registry.get_memory_stats()
        logger.info(f"Total memory usage: {memory_stats['total_used_mb']:.1f}MB / "
                    f"{memory_stats['total_limit_mb']:.1f}MB "
                    f"({memory_stats['utilization_percent']:.1f}%)")

    async def start_trading_loop(self):
        """Run the main trading decision loop until ``self.running`` is cleared.

        Each cycle asks the orchestrator for coordinated decisions, logs every
        non-empty decision, triggers RL evaluation of past actions, refreshes
        the perfect-move metric and then sleeps for the orchestrator's
        ``decision_frequency``. Unexpected errors are logged and followed by a
        fixed back-off so one bad cycle does not end the loop.
        """
        logger.info("Starting enhanced trading loop...")
        self.running = True

        decision_count = 0
        # Number of METRICS_LOG_EVERY_N_DECISIONS milestones already reported.
        # Tracking this avoids re-logging on idle cycles and avoids missing a
        # milestone when several decisions land in a single cycle.
        reported_milestones = 0

        while self.running:
            try:
                # Make coordinated decisions for all symbols
                decisions = await self.orchestrator.make_coordinated_decisions()

                for symbol, decision in decisions.items():
                    if not decision:
                        continue

                    decision_count += 1
                    self.performance_metrics['decisions_made'] += 1

                    logger.info(f"Trading Decision #{decision_count}")
                    logger.info(f"Symbol: {symbol}")
                    logger.info(f"Action: {decision.action}")
                    logger.info(f"Confidence: {decision.confidence:.3f}")
                    logger.info(f"Price: ${decision.price:.2f}")
                    logger.info(f"Quantity: {decision.quantity:.6f}")

                    # Log timeframe analysis
                    for tf_pred in decision.timeframe_analysis:
                        logger.info(f" {tf_pred.timeframe}: {tf_pred.action} "
                                    f"(conf: {tf_pred.confidence:.3f})")

                    # Here you would integrate with actual trading execution
                    # For now, we just log the decision

                # Evaluate past actions with RL
                await self.orchestrator.evaluate_actions_with_rl()

                # Check for perfect moves to mark
                # NOTE(review): this stores the size of the latest batch
                # (capped by limit=10), not a running total - confirm intent.
                perfect_moves = self.orchestrator.get_perfect_moves_for_training(limit=10)
                if perfect_moves:
                    self.performance_metrics['perfect_moves_marked'] = len(perfect_moves)

                # Log performance metrics once per milestone of decisions
                milestones = decision_count // self.METRICS_LOG_EVERY_N_DECISIONS
                if milestones > reported_milestones:
                    reported_milestones = milestones
                    await self._log_performance_metrics()

                # Wait before next decision cycle
                await asyncio.sleep(self.orchestrator.decision_frequency)

            except Exception as e:
                logger.error(f"Error in trading loop: {e}", exc_info=True)
                await asyncio.sleep(self.TRADING_LOOP_ERROR_DELAY_SECONDS)

    async def start_training_loops(self):
        """Start the continuous training loops as background tasks.

        Returns:
            A ``(rl_task, cnn_task)`` tuple of the created asyncio tasks.
        """
        logger.info("Starting continuous training loops...")

        # The periodic CNN loop is gated on ``self.running``. Mark the system
        # as running before the tasks are scheduled; otherwise the CNN task
        # can take its first step while the flag is still False and finish
        # immediately.
        self.running = True

        # Start RL continuous learning
        rl_task = asyncio.create_task(self.rl_trainer.continuous_learning_loop())

        # Start periodic CNN training
        cnn_task = asyncio.create_task(self._periodic_cnn_training())

        return rl_task, cnn_task

    def _collect_perfect_moves(self):
        """Gather the orchestrator's perfect moves for every configured symbol."""
        perfect_moves = []
        for symbol in self.config.symbols:
            perfect_moves.extend(
                self.orchestrator.get_perfect_moves_for_training(symbol=symbol)
            )
        return perfect_moves

    async def _periodic_cnn_training(self):
        """Periodically train the CNN on accumulated perfect moves.

        Sleeps for ``CNN_TRAINING_INTERVAL_SECONDS`` and then trains only if
        at least ``MIN_PERFECT_MOVES_FOR_TRAINING`` perfect moves exist. After
        a successful run the retrained model replaces the registered one.
        """
        min_moves = self.MIN_PERFECT_MOVES_FOR_TRAINING

        while self.running:
            try:
                # Wait for 6 hours between training sessions
                await asyncio.sleep(self.CNN_TRAINING_INTERVAL_SECONDS)

                # The system may have been stopped while this task slept.
                if not self.running:
                    break

                # Check if we have enough perfect moves for training
                perfect_moves = self._collect_perfect_moves()
                if len(perfect_moves) < min_moves:
                    logger.info(f"Not enough perfect moves for training: "
                                f"{len(perfect_moves)} < {min_moves}")
                    continue

                logger.info(f"Starting CNN training on {len(perfect_moves)} perfect moves")

                # Train the CNN model
                # NOTE(review): this call is synchronous and therefore holds
                # the event loop for the duration of training.
                training_report = self.cnn_trainer.train_on_perfect_moves(min_samples=min_moves)

                if not training_report.get('training_completed'):
                    logger.warning(f"CNN training failed: {training_report}")
                    continue

                self.performance_metrics['training_sessions'] += 1
                logger.info("CNN training completed successfully")
                logger.info(f"Final validation accuracy: "
                            f"{training_report['final_metrics']['val_accuracy']:.4f}")

                # Update the registered model and keep the local mapping in
                # step so metrics report on the model that is actually live.
                updated_model = self.cnn_trainer.get_model()
                self.model_registry.unregister_model('enhanced_cnn')
                self.model_registry.register_model(updated_model)
                self.cnn_models['enhanced_cnn'] = updated_model

            except Exception as e:
                logger.error(f"Error in periodic CNN training: {e}", exc_info=True)

    async def _log_performance_metrics(self):
        """Log decision counters, registry memory usage and per-model stats."""
        logger.info("=== SYSTEM PERFORMANCE METRICS ===")
        logger.info(f"Decisions made: {self.performance_metrics['decisions_made']}")
        logger.info(f"Perfect moves marked: {self.performance_metrics['perfect_moves_marked']}")
        logger.info(f"Training sessions: {self.performance_metrics['training_sessions']}")

        # Model registry stats
        memory_stats = self.model_registry.get_memory_stats()
        logger.info(f"Memory usage: {memory_stats['total_used_mb']:.1f}MB / "
                    f"{memory_stats['total_limit_mb']:.1f}MB")

        # RL performance
        rl_report = self.rl_trainer.get_performance_report()
        for symbol, agent_data in rl_report['agents'].items():
            logger.info(f"{symbol} RL: Epsilon={agent_data['epsilon']:.3f}, "
                        f"Experiences={agent_data['experiences_stored']}, "
                        f"Avg Reward={agent_data['avg_recent_reward']:.4f}")

        # CNN model info
        for model_name, model in self.cnn_models.items():
            logger.info(f"{model_name}: Memory={model.get_memory_usage()}MB, "
                        f"Device={model.device}")

    async def shutdown(self):
        """Stop the loops, save models, release memory and emit final plots."""
        logger.info("Shutting down Enhanced Trading System...")
        self.running = False

        # Save models
        logger.info("Saving models...")
        self.cnn_trainer._save_model('shutdown_model.pt')
        self.rl_trainer._save_all_models()

        # Clean up memory
        self.model_registry.cleanup_all_models()

        # Generate final reports
        logger.info("Generating final reports...")

        # CNN training plots (only when some training loss was recorded)
        if self.cnn_trainer.training_history['train_loss']:
            self.cnn_trainer._plot_training_history()

        # RL training plots
        self.rl_trainer.plot_training_metrics()

        logger.info("Enhanced Trading System shutdown complete")
|
|
|
|
def setup_signal_handlers(trading_system: EnhancedTradingSystem):
    """Install SIGINT/SIGTERM handlers that stop the system and exit.

    The handler clears ``trading_system.running`` so the loops gated on that
    flag stop iterating, then raises ``SystemExit(0)`` via ``sys.exit``.
    It deliberately does not schedule ``shutdown()`` as an asyncio task: the
    exit raised right afterwards tears the event loop down before such a task
    could take its first step, and ``asyncio.create_task`` fails outright when
    no loop is running. Saving models is therefore left to the cleanup code
    that runs while ``SystemExit`` unwinds (for example a ``finally`` block
    around the main coroutine).

    Args:
        trading_system: The system whose ``running`` flag is cleared when a
            signal arrives.
    """
    def signal_handler(signum, frame):
        logger.info(f"Received signal {signum}, initiating shutdown...")
        trading_system.running = False
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
|
|
|
|
async def main():
    """Main application entry point.

    Parses the command line, builds the trading system and runs it in one of
    three modes:

    * ``trade``: start the training loops and the trading loop, then wait
      until any one of them finishes.
    * ``train``: train the CNN once if enough perfect moves exist, then train
      the RL agents.
    * ``backtest``: not implemented; logs a message and returns.

    ``shutdown()`` is always awaited on the way out.
    """
    parser = argparse.ArgumentParser(description='Enhanced Multi-Modal Trading System')
    parser.add_argument('--config', type=str, help='Configuration file path')
    parser.add_argument('--mode', type=str, choices=['trade', 'train', 'backtest'],
                        default='trade', help='Operation mode')
    parser.add_argument('--load-models', action='store_true', default=True,
                        help='Load existing models')
    parser.add_argument('--no-load-models', action='store_false', dest='load_models',
                        help="Don't load existing models")

    args = parser.parse_args()

    # Create logs directory
    Path('logs').mkdir(exist_ok=True)

    logger.info("=== ENHANCED MULTI-MODAL TRADING SYSTEM ===")
    logger.info(f"Mode: {args.mode}")
    logger.info(f"Load existing models: {args.load_models}")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")

    # Initialize trading system
    trading_system = EnhancedTradingSystem(args.config)

    # Setup signal handlers
    setup_signal_handlers(trading_system)

    try:
        # Initialize models
        await trading_system.initialize_models(load_existing=args.load_models)

        if args.mode == 'trade':
            # Start training loops
            rl_task, cnn_task = await trading_system.start_training_loops()

            # Start main trading loop
            trading_task = asyncio.create_task(trading_system.start_trading_loop())

            # Wait for any task to complete (or error)
            done, pending = await asyncio.wait(
                [trading_task, rl_task, cnn_task],
                return_when=asyncio.FIRST_COMPLETED
            )

            # Surface failures of finished tasks; asyncio.wait does not
            # re-raise them, so they would otherwise go unnoticed.
            for task in done:
                if task.cancelled():
                    continue
                task_error = task.exception()
                if task_error is not None:
                    logger.error(f"Background task failed: {task_error}",
                                 exc_info=task_error)

            # Cancel remaining tasks and wait for them to unwind so nothing
            # is still running when shutdown() saves and frees the models.
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)

        elif args.mode == 'train':
            # Training-only mode
            logger.info("Running in training-only mode...")

            # Train CNN if we have perfect moves
            perfect_moves = []
            for symbol in trading_system.config.symbols:
                symbol_moves = trading_system.orchestrator.get_perfect_moves_for_training(symbol=symbol)
                perfect_moves.extend(symbol_moves)

            if len(perfect_moves) >= 100:
                logger.info(f"Training CNN on {len(perfect_moves)} perfect moves")
                training_report = trading_system.cnn_trainer.train_on_perfect_moves(min_samples=100)
                logger.info(f"CNN training report: {training_report}")
            else:
                logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)}")

            # Train RL agents if they have experiences
            await trading_system.rl_trainer._train_all_agents()

        elif args.mode == 'backtest':
            # Backtesting mode
            logger.info("Backtesting mode not implemented yet")
            return

    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt")
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
    finally:
        await trading_system.shutdown()
|
|
|
|
if __name__ == "__main__":
    # Script entry point: drive the async main() to completion.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is a normal way to stop; exit without an error status.
        logger.info("Application terminated by user")
    except Exception as fatal_error:
        # Anything else is fatal: log with traceback and exit with status 1.
        logger.exception(f"Fatal error: {fatal_error}")
        raise SystemExit(1)