scalping dash also works initially
This commit is contained in:
@ -45,118 +45,92 @@ logger = logging.getLogger(__name__)
|
||||
class EnhancedTradingSystem:
|
||||
"""Main enhanced trading system coordinator"""
|
||||
|
||||
def __init__(self, config_path: str = None):
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
"""Initialize the enhanced trading system"""
|
||||
self.config = get_config(config_path)
|
||||
self.running = False
|
||||
|
||||
# Core components
|
||||
# Initialize core components
|
||||
self.data_provider = DataProvider(self.config)
|
||||
self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
|
||||
self.model_registry = get_model_registry()
|
||||
|
||||
# Training components
|
||||
# Initialize training components
|
||||
self.cnn_trainer = EnhancedCNNTrainer(self.config, self.orchestrator)
|
||||
self.rl_trainer = EnhancedRLTrainer(self.config, self.orchestrator)
|
||||
|
||||
# Models
|
||||
self.cnn_models = {}
|
||||
self.rl_agents = {}
|
||||
|
||||
# Performance tracking
|
||||
self.performance_metrics = {
|
||||
'decisions_made': 0,
|
||||
'total_decisions': 0,
|
||||
'profitable_decisions': 0,
|
||||
'perfect_moves_marked': 0,
|
||||
'rl_experiences_added': 0,
|
||||
'training_sessions': 0
|
||||
'cnn_training_sessions': 0,
|
||||
'rl_training_steps': 0,
|
||||
'start_time': datetime.now()
|
||||
}
|
||||
|
||||
# System state
|
||||
self.running = False
|
||||
self.tasks = []
|
||||
|
||||
logger.info("Enhanced Trading System initialized")
|
||||
logger.info(f"Symbols: {self.config.symbols}")
|
||||
logger.info(f"Timeframes: {self.config.timeframes}")
|
||||
logger.info("LEARNING SYSTEMS ACTIVE:")
|
||||
logger.info("- RL agents learning from every trading decision")
|
||||
logger.info("- CNN training on perfect moves with known outcomes")
|
||||
logger.info("- Continuous pattern recognition and adaptation")
|
||||
|
||||
async def initialize_models(self, load_existing: bool = True):
|
||||
"""Initialize and register all models"""
|
||||
logger.info("Initializing models...")
|
||||
async def start(self):
|
||||
"""Start the enhanced trading system"""
|
||||
logger.info("Starting Enhanced Multi-Modal Trading System...")
|
||||
self.running = True
|
||||
|
||||
# Initialize CNN models
|
||||
if load_existing:
|
||||
# Try to load existing CNN model
|
||||
if self.cnn_trainer.load_model('best_model.pt'):
|
||||
logger.info("Loaded existing CNN model")
|
||||
self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()
|
||||
else:
|
||||
logger.info("No existing CNN model found, using fresh model")
|
||||
self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()
|
||||
else:
|
||||
logger.info("Creating fresh CNN model")
|
||||
self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()
|
||||
|
||||
# Initialize RL agents
|
||||
if load_existing:
|
||||
# Try to load existing RL agents
|
||||
if self.rl_trainer.load_models():
|
||||
logger.info("Loaded existing RL models")
|
||||
else:
|
||||
logger.info("No existing RL models found, using fresh agents")
|
||||
|
||||
self.rl_agents = self.rl_trainer.get_agents()
|
||||
|
||||
# Register models with the orchestrator
|
||||
for model_name, model in self.cnn_models.items():
|
||||
if self.model_registry.register_model(model):
|
||||
logger.info(f"Registered CNN model: {model_name}")
|
||||
|
||||
for symbol, agent in self.rl_agents.items():
|
||||
if self.model_registry.register_model(agent):
|
||||
logger.info(f"Registered RL agent for {symbol}")
|
||||
|
||||
# Display memory usage
|
||||
memory_stats = self.model_registry.get_memory_stats()
|
||||
logger.info(f"Total memory usage: {memory_stats['total_used_mb']:.1f}MB / "
|
||||
f"{memory_stats['total_limit_mb']:.1f}MB "
|
||||
f"({memory_stats['utilization_percent']:.1f}%)")
|
||||
try:
|
||||
# Start all system components
|
||||
trading_task = asyncio.create_task(self.start_trading_loop())
|
||||
training_tasks = await self.start_training_loops()
|
||||
monitoring_task = asyncio.create_task(self.start_monitoring_loop())
|
||||
|
||||
# Store tasks for cleanup
|
||||
self.tasks = [trading_task, monitoring_task] + list(training_tasks)
|
||||
|
||||
# Wait for all tasks
|
||||
await asyncio.gather(*self.tasks)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Shutdown signal received...")
|
||||
await self.shutdown()
|
||||
except Exception as e:
|
||||
logger.error(f"System error: {e}")
|
||||
await self.shutdown()
|
||||
|
||||
async def start_trading_loop(self):
|
||||
"""Start the main trading decision loop"""
|
||||
logger.info("Starting enhanced trading loop...")
|
||||
self.running = True
|
||||
|
||||
logger.info("Starting enhanced trading decision loop...")
|
||||
decision_count = 0
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Make coordinated decisions for all symbols
|
||||
# Get coordinated decisions for all symbols
|
||||
decisions = await self.orchestrator.make_coordinated_decisions()
|
||||
|
||||
# Process decisions
|
||||
for symbol, decision in decisions.items():
|
||||
if decision:
|
||||
decision_count += 1
|
||||
self.performance_metrics['decisions_made'] += 1
|
||||
|
||||
logger.info(f"Trading Decision #{decision_count}")
|
||||
logger.info(f"Symbol: {symbol}")
|
||||
logger.info(f"Action: {decision.action}")
|
||||
logger.info(f"Confidence: {decision.confidence:.3f}")
|
||||
logger.info(f"Price: ${decision.price:.2f}")
|
||||
logger.info(f"Quantity: {decision.quantity:.6f}")
|
||||
|
||||
# Log timeframe analysis
|
||||
for tf_pred in decision.timeframe_analysis:
|
||||
logger.info(f" {tf_pred.timeframe}: {tf_pred.action} "
|
||||
f"(conf: {tf_pred.confidence:.3f})")
|
||||
|
||||
# Here you would integrate with actual trading execution
|
||||
# For now, we just log the decision
|
||||
|
||||
# Evaluate past actions with RL
|
||||
await self.orchestrator.evaluate_actions_with_rl()
|
||||
for decision in decisions:
|
||||
decision_count += 1
|
||||
self.performance_metrics['total_decisions'] = decision_count
|
||||
|
||||
logger.info(f"DECISION #{decision_count}: {decision.action} {decision.symbol} "
|
||||
f"@ ${decision.price:.2f} (Confidence: {decision.confidence:.1%})")
|
||||
|
||||
# Execute decision (this would connect to broker in live trading)
|
||||
await self._execute_decision(decision)
|
||||
|
||||
# Add to RL evaluation queue for future learning
|
||||
await self.orchestrator.queue_action_for_evaluation(decision)
|
||||
|
||||
# Check for perfect moves to mark
|
||||
perfect_moves = self.orchestrator.get_perfect_moves_for_training(limit=10)
|
||||
# Check for perfect moves to train CNN
|
||||
perfect_moves = self.orchestrator.get_recent_perfect_moves()
|
||||
if perfect_moves:
|
||||
self.performance_metrics['perfect_moves_marked'] = len(perfect_moves)
|
||||
logger.info(f"CNN LEARNING: {len(perfect_moves)} perfect moves identified for training")
|
||||
|
||||
# Log performance metrics every 10 decisions
|
||||
if decision_count % 10 == 0 and decision_count > 0:
|
||||
@ -171,200 +145,164 @@ class EnhancedTradingSystem:
|
||||
|
||||
async def start_training_loops(self):
|
||||
"""Start continuous training loops"""
|
||||
logger.info("Starting continuous training loops...")
|
||||
logger.info("Starting continuous learning systems...")
|
||||
|
||||
# Start RL continuous learning
|
||||
logger.info("STARTING RL CONTINUOUS LEARNING:")
|
||||
logger.info("- Learning from every trading decision outcome")
|
||||
logger.info("- Adapting to market regime changes")
|
||||
logger.info("- Prioritized experience replay")
|
||||
rl_task = asyncio.create_task(self.rl_trainer.continuous_learning_loop())
|
||||
|
||||
# Start periodic CNN training
|
||||
logger.info("STARTING CNN PATTERN LEARNING:")
|
||||
logger.info("- Training on perfect moves with known outcomes")
|
||||
logger.info("- Multi-timeframe pattern recognition")
|
||||
logger.info("- Retrospective learning from market data")
|
||||
cnn_task = asyncio.create_task(self._periodic_cnn_training())
|
||||
|
||||
return rl_task, cnn_task
|
||||
|
||||
async def _periodic_cnn_training(self):
|
||||
"""Periodic CNN training on accumulated perfect moves"""
|
||||
"""Periodically train CNN on perfect moves"""
|
||||
training_interval = self.config.training.get('cnn_training_interval', 21600) # 6 hours
|
||||
min_perfect_moves = self.config.training.get('min_perfect_moves', 200)
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Wait for 6 hours between training sessions
|
||||
await asyncio.sleep(6 * 3600)
|
||||
|
||||
# Check if we have enough perfect moves for training
|
||||
perfect_moves = []
|
||||
for symbol in self.config.symbols:
|
||||
symbol_moves = self.orchestrator.get_perfect_moves_for_training(symbol=symbol)
|
||||
perfect_moves.extend(symbol_moves)
|
||||
perfect_moves = self.orchestrator.get_perfect_moves_for_training()
|
||||
|
||||
if len(perfect_moves) >= 200: # Minimum 200 perfect moves
|
||||
logger.info(f"Starting CNN training on {len(perfect_moves)} perfect moves")
|
||||
if len(perfect_moves) >= min_perfect_moves:
|
||||
logger.info(f"CNN TRAINING: Starting with {len(perfect_moves)} perfect moves")
|
||||
|
||||
# Train the CNN model
|
||||
training_report = self.cnn_trainer.train_on_perfect_moves(min_samples=200)
|
||||
# Train CNN on perfect moves
|
||||
training_results = self.cnn_trainer.train_on_perfect_moves(min_samples=min_perfect_moves)
|
||||
|
||||
if training_report.get('training_completed'):
|
||||
self.performance_metrics['training_sessions'] += 1
|
||||
logger.info("CNN training completed successfully")
|
||||
logger.info(f"Final validation accuracy: "
|
||||
f"{training_report['final_metrics']['val_accuracy']:.4f}")
|
||||
|
||||
# Update the registered model
|
||||
updated_model = self.cnn_trainer.get_model()
|
||||
self.model_registry.unregister_model('enhanced_cnn')
|
||||
self.model_registry.register_model(updated_model)
|
||||
|
||||
if 'error' not in training_results:
|
||||
self.performance_metrics['cnn_training_sessions'] += 1
|
||||
logger.info(f"CNN TRAINING COMPLETED: Session #{self.performance_metrics['cnn_training_sessions']}")
|
||||
logger.info(f"Training accuracy: {training_results.get('final_accuracy', 'N/A')}")
|
||||
logger.info(f"Confidence accuracy: {training_results.get('confidence_accuracy', 'N/A')}")
|
||||
else:
|
||||
logger.warning(f"CNN training failed: {training_report}")
|
||||
logger.warning(f"CNN training failed: {training_results['error']}")
|
||||
else:
|
||||
logger.info(f"Not enough perfect moves for training: {len(perfect_moves)} < 200")
|
||||
logger.info(f"CNN WAITING: Need {min_perfect_moves - len(perfect_moves)} more perfect moves for training")
|
||||
|
||||
# Wait for next training cycle
|
||||
await asyncio.sleep(training_interval)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in periodic CNN training: {e}")
|
||||
logger.error(f"Error in CNN training loop: {e}")
|
||||
await asyncio.sleep(3600) # Wait 1 hour on error
|
||||
|
||||
async def start_monitoring_loop(self):
|
||||
"""Monitor system performance and health"""
|
||||
while self.running:
|
||||
try:
|
||||
# Monitor memory usage
|
||||
if torch.cuda.is_available():
|
||||
gpu_memory = torch.cuda.memory_allocated() / (1024**3) # GB
|
||||
logger.info(f"SYSTEM HEALTH: GPU Memory: {gpu_memory:.2f}GB")
|
||||
|
||||
# Monitor model performance
|
||||
model_registry = get_model_registry()
|
||||
for model_name, model in model_registry.models.items():
|
||||
if hasattr(model, 'get_memory_usage'):
|
||||
memory_mb = model.get_memory_usage()
|
||||
logger.info(f"MODEL MEMORY: {model_name}: {memory_mb}MB")
|
||||
|
||||
# Monitor RL training progress
|
||||
for symbol, agent in self.rl_trainer.agents.items():
|
||||
buffer_size = len(agent.replay_buffer)
|
||||
epsilon = agent.epsilon
|
||||
logger.info(f"RL AGENT {symbol}: Buffer={buffer_size}, Epsilon={epsilon:.3f}")
|
||||
|
||||
await asyncio.sleep(300) # Monitor every 5 minutes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in monitoring loop: {e}")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def _execute_decision(self, decision):
|
||||
"""Execute trading decision (placeholder for broker integration)"""
|
||||
# This is where we would connect to a real broker API
|
||||
# For now, we just log the decision
|
||||
logger.info(f"EXECUTING: {decision.action} {decision.symbol} @ ${decision.price:.2f}")
|
||||
|
||||
# Simulate execution delay
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Mark as profitable for demo (in real trading, this would be determined by actual outcome)
|
||||
if decision.confidence > 0.7:
|
||||
self.performance_metrics['profitable_decisions'] += 1
|
||||
|
||||
async def _log_performance_metrics(self):
|
||||
"""Log system performance metrics"""
|
||||
logger.info("=== SYSTEM PERFORMANCE METRICS ===")
|
||||
logger.info(f"Decisions made: {self.performance_metrics['decisions_made']}")
|
||||
logger.info(f"Perfect moves marked: {self.performance_metrics['perfect_moves_marked']}")
|
||||
logger.info(f"Training sessions: {self.performance_metrics['training_sessions']}")
|
||||
"""Log comprehensive performance metrics"""
|
||||
runtime = datetime.now() - self.performance_metrics['start_time']
|
||||
|
||||
# Model registry stats
|
||||
memory_stats = self.model_registry.get_memory_stats()
|
||||
logger.info(f"Memory usage: {memory_stats['total_used_mb']:.1f}MB / "
|
||||
f"{memory_stats['total_limit_mb']:.1f}MB")
|
||||
logger.info("PERFORMANCE METRICS:")
|
||||
logger.info(f"Runtime: {runtime}")
|
||||
logger.info(f"Total Decisions: {self.performance_metrics['total_decisions']}")
|
||||
logger.info(f"Profitable Decisions: {self.performance_metrics['profitable_decisions']}")
|
||||
logger.info(f"Perfect Moves Marked: {self.performance_metrics['perfect_moves_marked']}")
|
||||
logger.info(f"CNN Training Sessions: {self.performance_metrics['cnn_training_sessions']}")
|
||||
|
||||
# RL performance
|
||||
rl_report = self.rl_trainer.get_performance_report()
|
||||
for symbol, agent_data in rl_report['agents'].items():
|
||||
logger.info(f"{symbol} RL: Epsilon={agent_data['epsilon']:.3f}, "
|
||||
f"Experiences={agent_data['experiences_stored']}, "
|
||||
f"Avg Reward={agent_data['avg_recent_reward']:.4f}")
|
||||
|
||||
# CNN model info
|
||||
for model_name, model in self.cnn_models.items():
|
||||
logger.info(f"{model_name}: Memory={model.get_memory_usage()}MB, "
|
||||
f"Device={model.device}")
|
||||
# Calculate success rate
|
||||
if self.performance_metrics['total_decisions'] > 0:
|
||||
success_rate = self.performance_metrics['profitable_decisions'] / self.performance_metrics['total_decisions']
|
||||
logger.info(f"Success Rate: {success_rate:.1%}")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Graceful shutdown of the system"""
|
||||
"""Gracefully shutdown the system"""
|
||||
logger.info("Shutting down Enhanced Trading System...")
|
||||
self.running = False
|
||||
|
||||
# Cancel all tasks
|
||||
for task in self.tasks:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
|
||||
# Save models
|
||||
logger.info("Saving models...")
|
||||
self.cnn_trainer._save_model('shutdown_model.pt')
|
||||
self.rl_trainer._save_all_models()
|
||||
|
||||
# Clean up memory
|
||||
self.model_registry.cleanup_all_models()
|
||||
|
||||
# Generate final reports
|
||||
logger.info("Generating final reports...")
|
||||
|
||||
# CNN training plots
|
||||
if self.cnn_trainer.training_history['train_loss']:
|
||||
self.cnn_trainer._plot_training_history()
|
||||
|
||||
# RL training plots
|
||||
self.rl_trainer.plot_training_metrics()
|
||||
try:
|
||||
self.cnn_trainer._save_model('shutdown_model.pt')
|
||||
self.rl_trainer._save_all_models()
|
||||
logger.info("Models saved successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving models: {e}")
|
||||
|
||||
# Final performance report
|
||||
await self._log_performance_metrics()
|
||||
logger.info("Enhanced Trading System shutdown complete")
|
||||
|
||||
def setup_signal_handlers(trading_system: EnhancedTradingSystem):
|
||||
"""Setup signal handlers for graceful shutdown"""
|
||||
def signal_handler(signum, frame):
|
||||
logger.info(f"Received signal {signum}, initiating shutdown...")
|
||||
asyncio.create_task(trading_system.shutdown())
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
async def main():
|
||||
"""Main application entry point"""
|
||||
"""Main entry point"""
|
||||
parser = argparse.ArgumentParser(description='Enhanced Multi-Modal Trading System')
|
||||
parser.add_argument('--config', type=str, help='Configuration file path')
|
||||
parser.add_argument('--mode', type=str, choices=['trade', 'train', 'backtest'],
|
||||
default='trade', help='Operation mode')
|
||||
parser.add_argument('--load-models', action='store_true', default=True,
|
||||
help='Load existing models')
|
||||
parser.add_argument('--no-load-models', action='store_false', dest='load_models',
|
||||
help="Don't load existing models")
|
||||
parser.add_argument('--config', type=str, help='Path to configuration file')
|
||||
parser.add_argument('--symbols', nargs='+', default=['ETH/USDT', 'BTC/USDT'],
|
||||
help='Trading symbols')
|
||||
parser.add_argument('--timeframes', nargs='+', default=['1s', '1m', '1h', '1d'],
|
||||
help='Trading timeframes')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create logs directory
|
||||
Path('logs').mkdir(exist_ok=True)
|
||||
# Create and start the enhanced trading system
|
||||
system = EnhancedTradingSystem(args.config)
|
||||
|
||||
logger.info("=== ENHANCED MULTI-MODAL TRADING SYSTEM ===")
|
||||
logger.info(f"Mode: {args.mode}")
|
||||
logger.info(f"Load existing models: {args.load_models}")
|
||||
logger.info(f"PyTorch version: {torch.__version__}")
|
||||
logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
||||
# Setup signal handlers for graceful shutdown
|
||||
def signal_handler(signum, frame):
|
||||
logger.info(f"Received signal {signum}")
|
||||
asyncio.create_task(system.shutdown())
|
||||
|
||||
# Initialize trading system
|
||||
trading_system = EnhancedTradingSystem(args.config)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Setup signal handlers
|
||||
setup_signal_handlers(trading_system)
|
||||
|
||||
try:
|
||||
# Initialize models
|
||||
await trading_system.initialize_models(load_existing=args.load_models)
|
||||
|
||||
if args.mode == 'trade':
|
||||
# Start training loops
|
||||
rl_task, cnn_task = await trading_system.start_training_loops()
|
||||
|
||||
# Start main trading loop
|
||||
trading_task = asyncio.create_task(trading_system.start_trading_loop())
|
||||
|
||||
# Wait for any task to complete (or error)
|
||||
done, pending = await asyncio.wait(
|
||||
[trading_task, rl_task, cnn_task],
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
# Cancel remaining tasks
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
elif args.mode == 'train':
|
||||
# Training-only mode
|
||||
logger.info("Running in training-only mode...")
|
||||
|
||||
# Train CNN if we have perfect moves
|
||||
perfect_moves = []
|
||||
for symbol in trading_system.config.symbols:
|
||||
symbol_moves = trading_system.orchestrator.get_perfect_moves_for_training(symbol=symbol)
|
||||
perfect_moves.extend(symbol_moves)
|
||||
|
||||
if len(perfect_moves) >= 100:
|
||||
logger.info(f"Training CNN on {len(perfect_moves)} perfect moves")
|
||||
training_report = trading_system.cnn_trainer.train_on_perfect_moves(min_samples=100)
|
||||
logger.info(f"CNN training report: {training_report}")
|
||||
else:
|
||||
logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)}")
|
||||
|
||||
# Train RL agents if they have experiences
|
||||
await trading_system.rl_trainer._train_all_agents()
|
||||
|
||||
elif args.mode == 'backtest':
|
||||
# Backtesting mode
|
||||
logger.info("Backtesting mode not implemented yet")
|
||||
return
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received keyboard interrupt")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}", exc_info=True)
|
||||
finally:
|
||||
await trading_system.shutdown()
|
||||
# Start the system
|
||||
await system.start()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the main application
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Application terminated by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
# Ensure logs directory exists
|
||||
Path('logs').mkdir(exist_ok=True)
|
||||
|
||||
# Run the enhanced trading system
|
||||
asyncio.run(main())
|
Reference in New Issue
Block a user