308 lines
13 KiB
Python
308 lines
13 KiB
Python
"""
|
|
Enhanced Multi-Modal Trading System - Main Application
|
|
|
|
This is the main launcher for the sophisticated trading system featuring:
|
|
1. Enhanced orchestrator coordinating CNN and RL modules
|
|
2. Multi-timeframe, multi-symbol (ETH, BTC) trading decisions
|
|
3. Perfect move marking for CNN training with known outcomes
|
|
4. Continuous RL learning from trading action evaluations
|
|
5. Market environment adaptation and coordinated decision making
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import signal
|
|
import sys
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, Any, Optional
|
|
import argparse
|
|
|
|
# Core components
|
|
from core.config import get_config
|
|
from core.data_provider import DataProvider
|
|
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
|
from models import get_model_registry
|
|
|
|
# Training components
|
|
from training.enhanced_cnn_trainer import EnhancedCNNTrainer, EnhancedCNNModel
|
|
from training.enhanced_rl_trainer import EnhancedRLTrainer, EnhancedDQNAgent
|
|
|
|
# Utilities
|
|
import torch
|
|
|
|
# Setup logging.
# FileHandler opens its file as soon as it is constructed, and this code runs at
# import time, so the log directory has to exist *before* the handler is built.
Path('logs').mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        # Explicit encoding so the log file does not depend on the platform locale.
        logging.FileHandler('logs/enhanced_trading.log', encoding='utf-8'),
    ],
)

logger = logging.getLogger(__name__)
|
|
|
|
class EnhancedTradingSystem:
    """Main enhanced trading system coordinator.

    Owns the data provider, the enhanced orchestrator and the CNN/RL trainers,
    and runs concurrent asyncio loops for trading decisions, continuous
    training (RL + periodic CNN) and health monitoring.
    """

    # Upper bound (seconds) on how long shutdown waits for cancelled loops to
    # unwind before saving models anyway.
    SHUTDOWN_TIMEOUT = 30.0

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the enhanced trading system.

        Args:
            config_path: Optional path to a configuration file, forwarded to
                ``get_config``.
        """
        self.config = get_config(config_path)

        # Initialize core components
        self.data_provider = DataProvider(self.config)
        self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)

        # Initialize training components (both share the orchestrator)
        self.cnn_trainer = EnhancedCNNTrainer(self.config, self.orchestrator)
        self.rl_trainer = EnhancedRLTrainer(self.config, self.orchestrator)

        # Performance tracking
        self.performance_metrics: Dict[str, Any] = {
            'total_decisions': 0,
            'profitable_decisions': 0,
            'perfect_moves_marked': 0,
            'cnn_training_sessions': 0,
            'rl_training_steps': 0,
            'start_time': datetime.now(),
        }

        # System state
        self.running = False
        self.tasks = []  # asyncio.Task objects started by start()
        # Single task performing the shutdown work; created on first shutdown().
        self._shutdown_task: Optional[asyncio.Task] = None

        logger.info("Enhanced Trading System initialized")
        logger.info(f"Symbols: {self.config.symbols}")
        logger.info(f"Timeframes: {self.config.timeframes}")
        logger.info("LEARNING SYSTEMS ACTIVE:")
        logger.info("- RL agents learning from every trading decision")
        logger.info("- CNN training on perfect moves with known outcomes")
        logger.info("- Continuous pattern recognition and adaptation")

    async def start(self):
        """Start all loops and run until they stop, then shut down.

        Returns once every loop has stopped, either because shutdown() was
        requested (it cancels them) or because one of them failed. In every
        case shutdown() has completed before this coroutine returns.
        """
        logger.info("Starting Enhanced Multi-Modal Trading System...")
        self.running = True

        try:
            # Start all system components
            trading_task = asyncio.create_task(self.start_trading_loop())
            training_tasks = await self.start_training_loops()
            monitoring_task = asyncio.create_task(self.start_monitoring_loop())

            # Store tasks so shutdown() can cancel them
            self.tasks = [trading_task, monitoring_task] + list(training_tasks)

            # Wait until all loops are finished or one of them has failed.
            # Unlike gather(), wait() does not raise CancelledError here when
            # shutdown() cancels the loops, so a requested stop ends cleanly.
            done, _ = await asyncio.wait(self.tasks, return_when=asyncio.FIRST_EXCEPTION)
            for task in done:
                if not task.cancelled() and task.exception() is not None:
                    exc = task.exception()
                    logger.error(f"System error: {exc}", exc_info=exc)

        except KeyboardInterrupt:
            logger.info("Shutdown signal received...")
        except Exception as e:
            logger.exception(f"System error: {e}")
        finally:
            # Idempotent: a no-op wait if a signal handler already triggered it.
            await self.shutdown()

    async def start_trading_loop(self):
        """Run the main trading decision loop until ``self.running`` is cleared.

        Each cycle asks the orchestrator for coordinated decisions, executes
        them, queues them for RL evaluation, records how many recent perfect
        moves the orchestrator reports, then sleeps for
        ``orchestrator.decision_frequency`` seconds.
        """
        logger.info("Starting enhanced trading decision loop...")
        decision_count = 0
        metrics_interval = 10
        next_metrics_at = metrics_interval

        while self.running:
            try:
                # Get coordinated decisions for all symbols
                decisions = await self.orchestrator.make_coordinated_decisions()

                for decision in decisions:
                    decision_count += 1
                    self.performance_metrics['total_decisions'] = decision_count

                    logger.info(f"DECISION #{decision_count}: {decision.action} {decision.symbol} "
                                f"@ ${decision.price:.2f} (Confidence: {decision.confidence:.1%})")

                    # Execute decision (this would connect to broker in live trading)
                    await self._execute_decision(decision)

                    # Add to RL evaluation queue for future learning
                    await self.orchestrator.queue_action_for_evaluation(decision)

                # Check for perfect moves to train CNN
                perfect_moves = self.orchestrator.get_recent_perfect_moves()
                if perfect_moves:
                    self.performance_metrics['perfect_moves_marked'] = len(perfect_moves)
                    logger.info(f"CNN LEARNING: {len(perfect_moves)} perfect moves identified for training")

                # Log performance metrics every 10 decisions. A threshold is used
                # (not `% 10 == 0`) because one cycle may yield several decisions
                # and jump straight over a multiple of 10.
                if decision_count >= next_metrics_at:
                    await self._log_performance_metrics()
                    next_metrics_at = (decision_count // metrics_interval + 1) * metrics_interval

                # Wait before next decision cycle
                await asyncio.sleep(self.orchestrator.decision_frequency)

            except Exception as e:
                logger.exception(f"Error in trading loop: {e}")
                await asyncio.sleep(30)  # Wait 30 seconds on error

    async def start_training_loops(self):
        """Start the RL and CNN training loops and return their two tasks."""
        logger.info("Starting continuous learning systems...")

        # Start RL continuous learning
        logger.info("STARTING RL CONTINUOUS LEARNING:")
        logger.info("- Learning from every trading decision outcome")
        logger.info("- Adapting to market regime changes")
        logger.info("- Prioritized experience replay")
        rl_task = asyncio.create_task(self.rl_trainer.continuous_learning_loop())

        # Start periodic CNN training
        logger.info("STARTING CNN PATTERN LEARNING:")
        logger.info("- Training on perfect moves with known outcomes")
        logger.info("- Multi-timeframe pattern recognition")
        logger.info("- Retrospective learning from market data")
        cnn_task = asyncio.create_task(self._periodic_cnn_training())

        return rl_task, cnn_task

    async def _periodic_cnn_training(self):
        """Periodically train the CNN once enough perfect moves are available.

        Interval and threshold come from ``config.training``
        (``cnn_training_interval``, default 21600 s; ``min_perfect_moves``,
        default 200).
        """
        training_interval = self.config.training.get('cnn_training_interval', 21600)  # 6 hours
        min_perfect_moves = self.config.training.get('min_perfect_moves', 200)

        while self.running:
            try:
                # Check if we have enough perfect moves for training
                perfect_moves = self.orchestrator.get_perfect_moves_for_training()

                if len(perfect_moves) >= min_perfect_moves:
                    logger.info(f"CNN TRAINING: Starting with {len(perfect_moves)} perfect moves")

                    # NOTE(review): this call is synchronous, so it presumably
                    # blocks the event loop (and every other loop) for the whole
                    # training run. Consider running it in an executor once the
                    # trainer/orchestrator are confirmed thread-safe.
                    training_results = self.cnn_trainer.train_on_perfect_moves(min_samples=min_perfect_moves)

                    if 'error' not in training_results:
                        self.performance_metrics['cnn_training_sessions'] += 1
                        logger.info(f"CNN TRAINING COMPLETED: Session #{self.performance_metrics['cnn_training_sessions']}")
                        logger.info(f"Training accuracy: {training_results.get('final_accuracy', 'N/A')}")
                        logger.info(f"Confidence accuracy: {training_results.get('confidence_accuracy', 'N/A')}")
                    else:
                        logger.warning(f"CNN training failed: {training_results['error']}")
                else:
                    logger.info(f"CNN WAITING: Need {min_perfect_moves - len(perfect_moves)} more perfect moves for training")

                # Wait for next training cycle
                await asyncio.sleep(training_interval)

            except Exception as e:
                logger.exception(f"Error in CNN training loop: {e}")
                await asyncio.sleep(3600)  # Wait 1 hour on error

    async def start_monitoring_loop(self):
        """Log GPU memory, per-model memory and RL agent state every 5 minutes."""
        while self.running:
            try:
                # Monitor memory usage
                if torch.cuda.is_available():
                    gpu_memory = torch.cuda.memory_allocated() / (1024**3)  # GB
                    logger.info(f"SYSTEM HEALTH: GPU Memory: {gpu_memory:.2f}GB")

                # Monitor model performance
                model_registry = get_model_registry()
                for model_name, model in model_registry.models.items():
                    if hasattr(model, 'get_memory_usage'):
                        memory_mb = model.get_memory_usage()
                        logger.info(f"MODEL MEMORY: {model_name}: {memory_mb}MB")

                # Monitor RL training progress
                for symbol, agent in self.rl_trainer.agents.items():
                    buffer_size = len(agent.replay_buffer)
                    epsilon = agent.epsilon
                    logger.info(f"RL AGENT {symbol}: Buffer={buffer_size}, Epsilon={epsilon:.3f}")

                await asyncio.sleep(300)  # Monitor every 5 minutes

            except Exception as e:
                logger.exception(f"Error in monitoring loop: {e}")
                await asyncio.sleep(60)

    async def _execute_decision(self, decision):
        """Execute a trading decision (placeholder for broker integration)."""
        # This is where we would connect to a real broker API.
        # For now, we just log the decision.
        logger.info(f"EXECUTING: {decision.action} {decision.symbol} @ ${decision.price:.2f}")

        # Simulate execution delay
        await asyncio.sleep(0.1)

        # Demo only: counts high-confidence decisions as "profitable"; real
        # trading would use the actual outcome.
        if decision.confidence > 0.7:
            self.performance_metrics['profitable_decisions'] += 1

    async def _log_performance_metrics(self):
        """Log runtime, decision counters and the (demo) success rate."""
        runtime = datetime.now() - self.performance_metrics['start_time']

        logger.info("PERFORMANCE METRICS:")
        logger.info(f"Runtime: {runtime}")
        logger.info(f"Total Decisions: {self.performance_metrics['total_decisions']}")
        logger.info(f"Profitable Decisions: {self.performance_metrics['profitable_decisions']}")
        logger.info(f"Perfect Moves Marked: {self.performance_metrics['perfect_moves_marked']}")
        logger.info(f"CNN Training Sessions: {self.performance_metrics['cnn_training_sessions']}")

        # Calculate success rate (guarded against division by zero)
        if self.performance_metrics['total_decisions'] > 0:
            success_rate = self.performance_metrics['profitable_decisions'] / self.performance_metrics['total_decisions']
            logger.info(f"Success Rate: {success_rate:.1%}")

    async def shutdown(self):
        """Gracefully shut down the system; the work runs at most once.

        Safe to call repeatedly or concurrently (e.g. from a signal handler and
        from start()'s cleanup). The first call schedules the shutdown work in
        its own task, and every caller awaits that same task.
        """
        if self._shutdown_task is None:
            self._shutdown_task = asyncio.ensure_future(self._shutdown_once())
        # shield(): cancelling one waiting caller must not abort model saving.
        await asyncio.shield(self._shutdown_task)

    async def _shutdown_once(self):
        """Stop loops, save models and log a final report (scheduled by shutdown())."""
        logger.info("Shutting down Enhanced Trading System...")
        self.running = False

        # Cancel all tasks, then let them unwind before models are saved. The
        # timeout keeps a loop that ignores cancellation from hanging shutdown.
        pending = [task for task in self.tasks if not task.done()]
        for task in pending:
            task.cancel()
        if pending:
            _, still_running = await asyncio.wait(pending, timeout=self.SHUTDOWN_TIMEOUT)
            if still_running:
                logger.warning(f"{len(still_running)} task(s) still running after {self.SHUTDOWN_TIMEOUT}s")

        # Save models
        try:
            self.cnn_trainer._save_model('shutdown_model.pt')
            self.rl_trainer._save_all_models()
            logger.info("Models saved successfully")
        except Exception as e:
            logger.exception(f"Error saving models: {e}")

        # Final performance report
        await self._log_performance_metrics()
        logger.info("Enhanced Trading System shutdown complete")
|
|
|
|
async def main():
    """Main entry point: parse options, build the system and run it.

    SIGINT/SIGTERM schedule a single graceful ``system.shutdown()``. Further
    signals while it is running are logged but do not start another shutdown.
    """
    parser = argparse.ArgumentParser(description='Enhanced Multi-Modal Trading System')
    parser.add_argument('--config', type=str, help='Path to configuration file')
    # NOTE(review): --symbols/--timeframes are not passed to the system below;
    # symbols and timeframes come from the configuration. Defaults are None so
    # an explicitly given (and ignored) value can be reported, not silently dropped.
    parser.add_argument('--symbols', nargs='+', default=None,
                        help='Trading symbols (currently ignored; set them in the config file)')
    parser.add_argument('--timeframes', nargs='+', default=None,
                        help='Trading timeframes (currently ignored; set them in the config file)')

    args = parser.parse_args()

    for option, value in (('--symbols', args.symbols), ('--timeframes', args.timeframes)):
        if value is not None:
            logger.warning(f"{option} {' '.join(value)} is ignored; configure it in the config file")

    # Create and start the enhanced trading system
    system = EnhancedTradingSystem(args.config)

    loop = asyncio.get_running_loop()
    shutdown_task: Optional[asyncio.Task] = None

    def request_shutdown(signum):
        """Schedule at most one graceful shutdown; runs on the event-loop thread."""
        nonlocal shutdown_task
        logger.info(f"Received signal {signum}")
        if shutdown_task is None:
            # Keep a strong reference: the event loop only holds weak references to tasks.
            shutdown_task = loop.create_task(system.shutdown())

    # Setup signal handlers for graceful shutdown
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            # Preferred: the callback runs inside the loop, which is woken promptly.
            loop.add_signal_handler(sig, request_shutdown, int(sig))
        except (NotImplementedError, RuntimeError):
            # e.g. Windows event loops: use a plain handler and hop into the loop safely.
            signal.signal(sig, lambda signum, frame: loop.call_soon_threadsafe(request_shutdown, signum))

    # Start the system
    try:
        await system.start()
    except asyncio.CancelledError:
        # Shutdown cancels the system's tasks, which may surface here as
        # CancelledError. Absorb it only if this function initiated the shutdown.
        if shutdown_task is None:
            raise
    finally:
        if shutdown_task is not None:
            await shutdown_task
|
|
|
|
if __name__ == "__main__":
|
|
# Ensure logs directory exists
|
|
Path('logs').mkdir(exist_ok=True)
|
|
|
|
# Run the enhanced trading system
|
|
asyncio.run(main()) |