# gogo2/enhanced_trading_main.py
# 2025-05-26 16:02:40 +03:00
#
# 308 lines
# 13 KiB
# Python
#
"""
Enhanced Multi-Modal Trading System - Main Application
This is the main launcher for the sophisticated trading system featuring:
1. Enhanced orchestrator coordinating CNN and RL modules
2. Multi-timeframe, multi-symbol (ETH, BTC) trading decisions
3. Perfect move marking for CNN training with known outcomes
4. Continuous RL learning from trading action evaluations
5. Market environment adaptation and coordinated decision making
"""
import asyncio
import logging
import signal
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional
import argparse
# Core components
from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from models import get_model_registry
# Training components
from training.enhanced_cnn_trainer import EnhancedCNNTrainer, EnhancedCNNModel
from training.enhanced_rl_trainer import EnhancedRLTrainer, EnhancedDQNAgent
# Utilities
import torch
# Setup logging.
# logging.FileHandler opens its file immediately (delay=False by default), and
# it is constructed here at import time. The log directory must therefore exist
# before this point; otherwise importing the module raises FileNotFoundError on
# a fresh checkout.
Path('logs').mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('logs/enhanced_trading.log'),
    ],
)
logger = logging.getLogger(__name__)
class EnhancedTradingSystem:
    """Main enhanced trading system coordinator.

    Wires together the data provider, the enhanced orchestrator and the
    CNN / RL trainers, then runs several concurrent asyncio tasks:

    * the trading decision loop (``start_trading_loop``),
    * the RL and CNN training loops (``start_training_loops``),
    * the health/performance monitoring loop (``start_monitoring_loop``).

    ``shutdown`` clears the ``running`` flag, cancels those tasks, saves the
    models and logs a final performance report.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the enhanced trading system.

        Args:
            config_path: Optional path to a configuration file; forwarded
                unchanged to ``get_config``.
        """
        self.config = get_config(config_path)

        # Initialize core components
        self.data_provider = DataProvider(self.config)
        self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)

        # Initialize training components; both trainers share the orchestrator
        self.cnn_trainer = EnhancedCNNTrainer(self.config, self.orchestrator)
        self.rl_trainer = EnhancedRLTrainer(self.config, self.orchestrator)

        # Performance tracking
        self.performance_metrics = {
            'total_decisions': 0,
            'profitable_decisions': 0,
            'perfect_moves_marked': 0,
            'cnn_training_sessions': 0,
            'rl_training_steps': 0,
            'start_time': datetime.now()
        }

        # System state: `running` is the cooperative stop flag polled by the
        # loops below; `tasks` keeps the asyncio tasks so shutdown() can
        # cancel them.
        self.running = False
        self.tasks = []

        logger.info("Enhanced Trading System initialized")
        logger.info(f"Symbols: {self.config.symbols}")
        logger.info(f"Timeframes: {self.config.timeframes}")
        logger.info("LEARNING SYSTEMS ACTIVE:")
        logger.info("- RL agents learning from every trading decision")
        logger.info("- CNN training on perfect moves with known outcomes")
        logger.info("- Continuous pattern recognition and adaptation")

    async def start(self):
        """Start the enhanced trading system and run until its tasks end.

        Any unexpected exception triggers ``shutdown``. A cancellation that
        was caused by our own ``shutdown`` (``self.running`` already False)
        is treated as a normal stop. Any other cancellation is re-raised.
        """
        logger.info("Starting Enhanced Multi-Modal Trading System...")
        self.running = True

        try:
            # Start all system components
            trading_task = asyncio.create_task(self.start_trading_loop())
            training_tasks = await self.start_training_loops()
            monitoring_task = asyncio.create_task(self.start_monitoring_loop())

            # Store tasks for cleanup
            self.tasks = [trading_task, monitoring_task] + list(training_tasks)

            # Wait for all tasks
            await asyncio.gather(*self.tasks)

        except asyncio.CancelledError:
            # shutdown() cancels self.tasks, and gather() then raises
            # CancelledError here. CancelledError is a BaseException, so
            # without this clause every normal shutdown ended in a traceback.
            # Only propagate cancellations that did not come from shutdown().
            if self.running:
                raise
            logger.info("System tasks stopped")
        except KeyboardInterrupt:
            logger.info("Shutdown signal received...")
            await self.shutdown()
        except Exception as e:
            logger.exception(f"System error: {e}")
            await self.shutdown()

    async def start_trading_loop(self):
        """Run the main trading decision loop while ``self.running`` is set.

        Each cycle requests coordinated decisions from the orchestrator. Each
        decision is executed and queued for RL evaluation. The cycle then
        sleeps for ``orchestrator.decision_frequency`` seconds, or 30 seconds
        after an error.
        """
        logger.info("Starting enhanced trading decision loop...")
        decision_count = 0
        # Highest multiple-of-10 "bucket" already reported. Comparing buckets,
        # instead of testing `decision_count % 10 == 0` after every cycle,
        # has two effects. Metrics are not re-logged on each cycle that
        # brings no new decisions. A report is not skipped when a batch
        # crosses a multiple of 10 without landing exactly on it.
        last_report_bucket = 0

        while self.running:
            try:
                # Get coordinated decisions for all symbols
                decisions = await self.orchestrator.make_coordinated_decisions()

                for decision in decisions:
                    decision_count += 1
                    self.performance_metrics['total_decisions'] = decision_count

                    logger.info(f"DECISION #{decision_count}: {decision.action} {decision.symbol} "
                                f"@ ${decision.price:.2f} (Confidence: {decision.confidence:.1%})")

                    # Execute decision (this would connect to broker in live trading)
                    await self._execute_decision(decision)

                    # Add to RL evaluation queue for future learning
                    await self.orchestrator.queue_action_for_evaluation(decision)

                # Check for perfect moves to train CNN
                perfect_moves = self.orchestrator.get_recent_perfect_moves()
                if perfect_moves:
                    self.performance_metrics['perfect_moves_marked'] = len(perfect_moves)
                    logger.info(f"CNN LEARNING: {len(perfect_moves)} perfect moves identified for training")

                # Log performance metrics every 10 decisions
                report_bucket = decision_count // 10
                if report_bucket > last_report_bucket:
                    last_report_bucket = report_bucket
                    await self._log_performance_metrics()

                # Wait before next decision cycle
                await asyncio.sleep(self.orchestrator.decision_frequency)

            except Exception as e:
                logger.exception(f"Error in trading loop: {e}")
                await asyncio.sleep(30)  # Wait 30 seconds on error

    async def start_training_loops(self):
        """Start the continuous training loops.

        Returns:
            A ``(rl_task, cnn_task)`` tuple of the created asyncio tasks.
        """
        logger.info("Starting continuous learning systems...")

        # Start RL continuous learning
        logger.info("STARTING RL CONTINUOUS LEARNING:")
        logger.info("- Learning from every trading decision outcome")
        logger.info("- Adapting to market regime changes")
        logger.info("- Prioritized experience replay")
        rl_task = asyncio.create_task(self.rl_trainer.continuous_learning_loop())

        # Start periodic CNN training
        logger.info("STARTING CNN PATTERN LEARNING:")
        logger.info("- Training on perfect moves with known outcomes")
        logger.info("- Multi-timeframe pattern recognition")
        logger.info("- Retrospective learning from market data")
        cnn_task = asyncio.create_task(self._periodic_cnn_training())

        return rl_task, cnn_task

    async def _periodic_cnn_training(self):
        """Periodically train the CNN on perfect moves.

        Training runs only once at least ``training['min_perfect_moves']``
        samples (default 200) are available. The loop then waits
        ``training['cnn_training_interval']`` seconds (default 21600, i.e. 6
        hours), or 1 hour after an error.
        """
        training_interval = self.config.training.get('cnn_training_interval', 21600)  # 6 hours
        min_perfect_moves = self.config.training.get('min_perfect_moves', 200)

        while self.running:
            try:
                # Check if we have enough perfect moves for training
                perfect_moves = self.orchestrator.get_perfect_moves_for_training()

                if len(perfect_moves) >= min_perfect_moves:
                    logger.info(f"CNN TRAINING: Starting with {len(perfect_moves)} perfect moves")

                    # Train CNN on perfect moves.
                    # NOTE(review): this call is synchronous. If training is
                    # long-running it blocks the event loop, stalling the
                    # trading and monitoring loops. Consider asyncio.to_thread
                    # once the trainer is confirmed thread-safe.
                    training_results = self.cnn_trainer.train_on_perfect_moves(min_samples=min_perfect_moves)

                    if 'error' not in training_results:
                        self.performance_metrics['cnn_training_sessions'] += 1
                        logger.info(f"CNN TRAINING COMPLETED: Session #{self.performance_metrics['cnn_training_sessions']}")
                        logger.info(f"Training accuracy: {training_results.get('final_accuracy', 'N/A')}")
                        logger.info(f"Confidence accuracy: {training_results.get('confidence_accuracy', 'N/A')}")
                    else:
                        logger.warning(f"CNN training failed: {training_results['error']}")
                else:
                    logger.info(f"CNN WAITING: Need {min_perfect_moves - len(perfect_moves)} more perfect moves for training")

                # Wait for next training cycle
                await asyncio.sleep(training_interval)

            except Exception as e:
                logger.exception(f"Error in CNN training loop: {e}")
                await asyncio.sleep(3600)  # Wait 1 hour on error

    async def start_monitoring_loop(self):
        """Log GPU memory, model memory and RL agent state every 5 minutes."""
        while self.running:
            try:
                # Monitor memory usage
                if torch.cuda.is_available():
                    gpu_memory = torch.cuda.memory_allocated() / (1024**3)  # GB
                    logger.info(f"SYSTEM HEALTH: GPU Memory: {gpu_memory:.2f}GB")

                # Monitor model performance (only models exposing get_memory_usage)
                model_registry = get_model_registry()
                for model_name, model in model_registry.models.items():
                    if hasattr(model, 'get_memory_usage'):
                        memory_mb = model.get_memory_usage()
                        logger.info(f"MODEL MEMORY: {model_name}: {memory_mb}MB")

                # Monitor RL training progress
                for symbol, agent in self.rl_trainer.agents.items():
                    buffer_size = len(agent.replay_buffer)
                    epsilon = agent.epsilon
                    logger.info(f"RL AGENT {symbol}: Buffer={buffer_size}, Epsilon={epsilon:.3f}")

                await asyncio.sleep(300)  # Monitor every 5 minutes

            except Exception as e:
                logger.exception(f"Error in monitoring loop: {e}")
                await asyncio.sleep(60)

    async def _execute_decision(self, decision):
        """Execute a trading decision (placeholder for broker integration).

        Currently this only logs the decision and sleeps 0.1 s. As a demo
        stand-in for real outcomes, decisions with confidence above 0.7 are
        counted as profitable.
        """
        # This is where we would connect to a real broker API
        # For now, we just log the decision
        logger.info(f"EXECUTING: {decision.action} {decision.symbol} @ ${decision.price:.2f}")

        # Simulate execution delay
        await asyncio.sleep(0.1)

        # Mark as profitable for demo (in real trading, this would be determined by actual outcome)
        if decision.confidence > 0.7:
            self.performance_metrics['profitable_decisions'] += 1

    async def _log_performance_metrics(self):
        """Log runtime, decision counters, CNN sessions and success rate."""
        runtime = datetime.now() - self.performance_metrics['start_time']

        logger.info("PERFORMANCE METRICS:")
        logger.info(f"Runtime: {runtime}")
        logger.info(f"Total Decisions: {self.performance_metrics['total_decisions']}")
        logger.info(f"Profitable Decisions: {self.performance_metrics['profitable_decisions']}")
        logger.info(f"Perfect Moves Marked: {self.performance_metrics['perfect_moves_marked']}")
        logger.info(f"CNN Training Sessions: {self.performance_metrics['cnn_training_sessions']}")

        # Calculate success rate (guarded against division by zero)
        if self.performance_metrics['total_decisions'] > 0:
            success_rate = self.performance_metrics['profitable_decisions'] / self.performance_metrics['total_decisions']
            logger.info(f"Success Rate: {success_rate:.1%}")

    async def shutdown(self):
        """Gracefully shut down the system.

        Clears ``running`` and requests cancellation of every stored task.
        The cancelled tasks are not awaited here. It then saves the models
        (errors are logged, not raised) and logs a final performance report.
        """
        logger.info("Shutting down Enhanced Trading System...")
        self.running = False

        # Cancel all tasks
        for task in self.tasks:
            if not task.done():
                task.cancel()

        # Save models
        try:
            self.cnn_trainer._save_model('shutdown_model.pt')
            self.rl_trainer._save_all_models()
            logger.info("Models saved successfully")
        except Exception as e:
            logger.exception(f"Error saving models: {e}")

        # Final performance report
        await self._log_performance_metrics()

        logger.info("Enhanced Trading System shutdown complete")
async def main():
    """Main entry point: parse CLI arguments, build the system and run it.

    SIGINT and SIGTERM request a graceful ``system.shutdown()``. Before
    returning, ``main`` waits for any signal-triggered shutdown to finish, so
    models are saved before ``asyncio.run`` cancels leftover tasks.
    """
    parser = argparse.ArgumentParser(description='Enhanced Multi-Modal Trading System')
    parser.add_argument('--config', type=str, help='Path to configuration file')
    # NOTE: symbols/timeframes currently come from the configuration only.
    # These options are accepted for compatibility; if given, a warning is
    # logged instead of silently ignoring them.
    parser.add_argument('--symbols', nargs='+', default=None,
                        help='Trading symbols')
    parser.add_argument('--timeframes', nargs='+', default=None,
                        help='Trading timeframes')

    args = parser.parse_args()

    # Create and start the enhanced trading system
    system = EnhancedTradingSystem(args.config)

    if args.symbols is not None or args.timeframes is not None:
        logger.warning(
            "--symbols/--timeframes are not applied yet; using symbols %s and "
            "timeframes %s from the configuration",
            system.config.symbols, system.config.timeframes,
        )

    # Setup signal handlers for graceful shutdown
    loop = asyncio.get_running_loop()
    # Strong references to shutdown tasks: the event loop only keeps weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    shutdown_tasks = []

    def request_shutdown(signum):
        """Schedule ``system.shutdown()``; always runs on the event loop."""
        logger.info(f"Received signal {signum}")
        shutdown_tasks.append(asyncio.create_task(system.shutdown()))

    def fallback_handler(signum, frame):
        # signal.signal handlers run outside the loop's control flow; hand the
        # work to the loop thread-safely so it is woken up promptly.
        loop.call_soon_threadsafe(request_shutdown, signum)

    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, request_shutdown, int(sig))
        except NotImplementedError:
            # Event loops without add_signal_handler support (e.g. on Windows)
            signal.signal(sig, fallback_handler)

    # Start the system
    await system.start()

    # Let a signal-triggered shutdown complete (models saved, final report).
    if shutdown_tasks:
        await asyncio.gather(*shutdown_tasks, return_exceptions=True)
if __name__ == "__main__":
    # Make sure the directory used for log files is present.
    log_directory = Path('logs')
    log_directory.mkdir(exist_ok=True)

    # Hand control to the asyncio event loop until main() completes.
    asyncio.run(main())