491 lines
20 KiB
Python
491 lines
20 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Continuous Full Training System (RL + CNN)
|
|
|
|
This system runs continuous training for both RL and CNN models using the enhanced
|
|
DataProvider for consistent data streaming to both models and the dashboard.
|
|
|
|
Features:
|
|
- Single DataProvider instance for all data needs
|
|
- Continuous RL training with real-time market data
|
|
- CNN training with perfect move detection
|
|
- Real-time performance monitoring
|
|
- Automatic model checkpointing
|
|
- Integration with live trading dashboard
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
import signal
|
|
import sys
|
|
from datetime import datetime, timedelta
|
|
from threading import Thread, Event
|
|
from typing import Dict, Any
|
|
|
|
# Setup logging
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
handlers=[
|
|
logging.FileHandler('logs/continuous_training.log'),
|
|
logging.StreamHandler()
|
|
]
|
|
)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Import our components
|
|
from core.config import get_config
|
|
from core.data_provider import DataProvider, MarketTick
|
|
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
|
from web.scalping_dashboard import RealTimeScalpingDashboard
|
|
|
|
class ContinuousTrainingSystem:
|
|
"""Comprehensive continuous training system for RL + CNN models"""
|
|
|
|
def __init__(self):
|
|
"""Initialize the continuous training system"""
|
|
self.config = get_config()
|
|
|
|
# Single DataProvider instance for all data needs
|
|
self.data_provider = DataProvider(
|
|
symbols=['ETH/USDT', 'BTC/USDT'],
|
|
timeframes=['1s', '1m', '1h', '1d']
|
|
)
|
|
|
|
# Enhanced orchestrator for AI trading
|
|
self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
|
|
|
|
# Dashboard for monitoring
|
|
self.dashboard = None
|
|
|
|
# Training control
|
|
self.running = False
|
|
self.shutdown_event = Event()
|
|
|
|
# Performance tracking
|
|
self.training_stats = {
|
|
'start_time': None,
|
|
'rl_training_cycles': 0,
|
|
'cnn_training_cycles': 0,
|
|
'perfect_moves_detected': 0,
|
|
'total_ticks_processed': 0,
|
|
'models_saved': 0,
|
|
'last_checkpoint': None
|
|
}
|
|
|
|
# Training intervals
|
|
self.rl_training_interval = 300 # 5 minutes
|
|
self.cnn_training_interval = 600 # 10 minutes
|
|
self.checkpoint_interval = 1800 # 30 minutes
|
|
|
|
logger.info("Continuous Training System initialized")
|
|
logger.info(f"RL training interval: {self.rl_training_interval}s")
|
|
logger.info(f"CNN training interval: {self.cnn_training_interval}s")
|
|
logger.info(f"Checkpoint interval: {self.checkpoint_interval}s")
|
|
|
|
async def start(self, run_dashboard: bool = True):
|
|
"""Start the continuous training system"""
|
|
logger.info("Starting Continuous Training System...")
|
|
self.running = True
|
|
self.training_stats['start_time'] = datetime.now()
|
|
|
|
try:
|
|
# Start DataProvider streaming
|
|
logger.info("Starting DataProvider real-time streaming...")
|
|
await self.data_provider.start_real_time_streaming()
|
|
|
|
# Subscribe to tick data for training
|
|
subscriber_id = self.data_provider.subscribe_to_ticks(
|
|
callback=self._handle_training_tick,
|
|
symbols=['ETH/USDT', 'BTC/USDT'],
|
|
subscriber_name="ContinuousTraining"
|
|
)
|
|
logger.info(f"Subscribed to training tick stream: {subscriber_id}")
|
|
|
|
# Start training threads
|
|
training_tasks = [
|
|
asyncio.create_task(self._rl_training_loop()),
|
|
asyncio.create_task(self._cnn_training_loop()),
|
|
asyncio.create_task(self._checkpoint_loop()),
|
|
asyncio.create_task(self._monitoring_loop())
|
|
]
|
|
|
|
# Start dashboard if requested
|
|
if run_dashboard:
|
|
dashboard_task = asyncio.create_task(self._run_dashboard())
|
|
training_tasks.append(dashboard_task)
|
|
|
|
logger.info("All training components started successfully")
|
|
|
|
# Wait for shutdown signal
|
|
await self._wait_for_shutdown()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in continuous training system: {e}")
|
|
raise
|
|
finally:
|
|
await self.stop()
|
|
|
|
def _handle_training_tick(self, tick: MarketTick):
|
|
"""Handle incoming tick data for training"""
|
|
try:
|
|
self.training_stats['total_ticks_processed'] += 1
|
|
|
|
# Process tick through orchestrator for RL training
|
|
if self.orchestrator and hasattr(self.orchestrator, 'process_tick'):
|
|
self.orchestrator.process_tick(tick)
|
|
|
|
# Log every 1000 ticks
|
|
if self.training_stats['total_ticks_processed'] % 1000 == 0:
|
|
logger.info(f"Processed {self.training_stats['total_ticks_processed']} training ticks")
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Error processing training tick: {e}")
|
|
|
|
async def _rl_training_loop(self):
|
|
"""Continuous RL training loop"""
|
|
logger.info("Starting RL training loop...")
|
|
|
|
while self.running:
|
|
try:
|
|
start_time = time.time()
|
|
|
|
# Perform RL training cycle
|
|
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
|
logger.info("Starting RL training cycle...")
|
|
|
|
# Get recent market data for training
|
|
training_data = self._prepare_rl_training_data()
|
|
|
|
if training_data is not None:
|
|
# Train RL agent
|
|
training_results = await self._train_rl_agent(training_data)
|
|
|
|
if training_results:
|
|
self.training_stats['rl_training_cycles'] += 1
|
|
logger.info(f"RL training cycle {self.training_stats['rl_training_cycles']} completed")
|
|
logger.info(f"Training results: {training_results}")
|
|
else:
|
|
logger.warning("No training data available for RL agent")
|
|
|
|
# Wait for next training cycle
|
|
elapsed = time.time() - start_time
|
|
sleep_time = max(0, self.rl_training_interval - elapsed)
|
|
await asyncio.sleep(sleep_time)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in RL training loop: {e}")
|
|
await asyncio.sleep(60) # Wait before retrying
|
|
|
|
async def _cnn_training_loop(self):
|
|
"""Continuous CNN training loop"""
|
|
logger.info("Starting CNN training loop...")
|
|
|
|
while self.running:
|
|
try:
|
|
start_time = time.time()
|
|
|
|
# Perform CNN training cycle
|
|
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
|
logger.info("Starting CNN training cycle...")
|
|
|
|
# Detect perfect moves for CNN training
|
|
perfect_moves = self._detect_perfect_moves()
|
|
|
|
if perfect_moves:
|
|
self.training_stats['perfect_moves_detected'] += len(perfect_moves)
|
|
|
|
# Train CNN with perfect moves
|
|
training_results = await self._train_cnn_model(perfect_moves)
|
|
|
|
if training_results:
|
|
self.training_stats['cnn_training_cycles'] += 1
|
|
logger.info(f"CNN training cycle {self.training_stats['cnn_training_cycles']} completed")
|
|
logger.info(f"Perfect moves processed: {len(perfect_moves)}")
|
|
else:
|
|
logger.info("No perfect moves detected for CNN training")
|
|
|
|
# Wait for next training cycle
|
|
elapsed = time.time() - start_time
|
|
sleep_time = max(0, self.cnn_training_interval - elapsed)
|
|
await asyncio.sleep(sleep_time)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in CNN training loop: {e}")
|
|
await asyncio.sleep(60) # Wait before retrying
|
|
|
|
async def _checkpoint_loop(self):
|
|
"""Automatic model checkpointing loop"""
|
|
logger.info("Starting checkpoint loop...")
|
|
|
|
while self.running:
|
|
try:
|
|
await asyncio.sleep(self.checkpoint_interval)
|
|
|
|
logger.info("Creating model checkpoints...")
|
|
|
|
# Save RL model
|
|
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
|
rl_checkpoint = await self._save_rl_checkpoint()
|
|
if rl_checkpoint:
|
|
logger.info(f"RL checkpoint saved: {rl_checkpoint}")
|
|
|
|
# Save CNN model
|
|
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
|
cnn_checkpoint = await self._save_cnn_checkpoint()
|
|
if cnn_checkpoint:
|
|
logger.info(f"CNN checkpoint saved: {cnn_checkpoint}")
|
|
|
|
self.training_stats['models_saved'] += 1
|
|
self.training_stats['last_checkpoint'] = datetime.now()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in checkpoint loop: {e}")
|
|
|
|
async def _monitoring_loop(self):
|
|
"""System monitoring and performance tracking loop"""
|
|
logger.info("Starting monitoring loop...")
|
|
|
|
while self.running:
|
|
try:
|
|
await asyncio.sleep(300) # Monitor every 5 minutes
|
|
|
|
# Log system statistics
|
|
uptime = datetime.now() - self.training_stats['start_time']
|
|
|
|
logger.info("=== CONTINUOUS TRAINING SYSTEM STATUS ===")
|
|
logger.info(f"Uptime: {uptime}")
|
|
logger.info(f"RL training cycles: {self.training_stats['rl_training_cycles']}")
|
|
logger.info(f"CNN training cycles: {self.training_stats['cnn_training_cycles']}")
|
|
logger.info(f"Perfect moves detected: {self.training_stats['perfect_moves_detected']}")
|
|
logger.info(f"Total ticks processed: {self.training_stats['total_ticks_processed']}")
|
|
logger.info(f"Models saved: {self.training_stats['models_saved']}")
|
|
|
|
# DataProvider statistics
|
|
if hasattr(self.data_provider, 'get_subscriber_stats'):
|
|
subscriber_stats = self.data_provider.get_subscriber_stats()
|
|
logger.info(f"Active subscribers: {subscriber_stats.get('active_subscribers', 0)}")
|
|
logger.info(f"Total ticks distributed: {subscriber_stats.get('distribution_stats', {}).get('total_ticks_distributed', 0)}")
|
|
|
|
# Orchestrator performance
|
|
if hasattr(self.orchestrator, 'get_performance_metrics'):
|
|
perf_metrics = self.orchestrator.get_performance_metrics()
|
|
logger.info(f"Orchestrator performance: {perf_metrics}")
|
|
|
|
logger.info("==========================================")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in monitoring loop: {e}")
|
|
|
|
async def _run_dashboard(self):
|
|
"""Run the dashboard in a separate thread"""
|
|
try:
|
|
logger.info("Starting live trading dashboard...")
|
|
|
|
def run_dashboard():
|
|
self.dashboard = RealTimeScalpingDashboard(
|
|
data_provider=self.data_provider,
|
|
orchestrator=self.orchestrator
|
|
)
|
|
self.dashboard.run(host='127.0.0.1', port=8051, debug=False)
|
|
|
|
dashboard_thread = Thread(target=run_dashboard, daemon=True)
|
|
dashboard_thread.start()
|
|
|
|
logger.info("Dashboard started at http://127.0.0.1:8051")
|
|
|
|
# Keep dashboard thread alive
|
|
while self.running:
|
|
await asyncio.sleep(10)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error running dashboard: {e}")
|
|
|
|
def _prepare_rl_training_data(self) -> Dict[str, Any]:
|
|
"""Prepare training data for RL agent"""
|
|
try:
|
|
# Get recent market data from DataProvider
|
|
eth_data = self.data_provider.get_latest_candles('ETH/USDT', '1m', limit=1000)
|
|
btc_data = self.data_provider.get_latest_candles('BTC/USDT', '1m', limit=1000)
|
|
|
|
if eth_data is not None and not eth_data.empty:
|
|
return {
|
|
'eth_data': eth_data,
|
|
'btc_data': btc_data,
|
|
'timestamp': datetime.now()
|
|
}
|
|
|
|
return None
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error preparing RL training data: {e}")
|
|
return None
|
|
|
|
def _detect_perfect_moves(self) -> list:
|
|
"""Detect perfect trading moves for CNN training"""
|
|
try:
|
|
# Get recent tick data
|
|
recent_ticks = self.data_provider.get_recent_ticks('ETHUSDT', count=500)
|
|
|
|
if not recent_ticks:
|
|
return []
|
|
|
|
# Simple perfect move detection (can be enhanced)
|
|
perfect_moves = []
|
|
|
|
for i in range(1, len(recent_ticks) - 1):
|
|
prev_tick = recent_ticks[i-1]
|
|
curr_tick = recent_ticks[i]
|
|
next_tick = recent_ticks[i+1]
|
|
|
|
# Detect significant price movements
|
|
price_change = (next_tick.price - curr_tick.price) / curr_tick.price
|
|
|
|
if abs(price_change) > 0.001: # 0.1% movement
|
|
perfect_moves.append({
|
|
'timestamp': curr_tick.timestamp,
|
|
'price': curr_tick.price,
|
|
'action': 'BUY' if price_change > 0 else 'SELL',
|
|
'confidence': min(abs(price_change) * 100, 1.0)
|
|
})
|
|
|
|
return perfect_moves[-10:] # Return last 10 perfect moves
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error detecting perfect moves: {e}")
|
|
return []
|
|
|
|
async def _train_rl_agent(self, training_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Train the RL agent with market data"""
|
|
try:
|
|
# Placeholder for RL training logic
|
|
# This would integrate with the actual RL agent
|
|
|
|
logger.info("Training RL agent with market data...")
|
|
|
|
# Simulate training time
|
|
await asyncio.sleep(1)
|
|
|
|
return {
|
|
'loss': 0.05,
|
|
'reward': 0.75,
|
|
'episodes': 100
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error training RL agent: {e}")
|
|
return None
|
|
|
|
async def _train_cnn_model(self, perfect_moves: list) -> Dict[str, Any]:
|
|
"""Train the CNN model with perfect moves"""
|
|
try:
|
|
# Placeholder for CNN training logic
|
|
# This would integrate with the actual CNN model
|
|
|
|
logger.info(f"Training CNN model with {len(perfect_moves)} perfect moves...")
|
|
|
|
# Simulate training time
|
|
await asyncio.sleep(2)
|
|
|
|
return {
|
|
'accuracy': 0.92,
|
|
'loss': 0.08,
|
|
'perfect_moves_processed': len(perfect_moves)
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error training CNN model: {e}")
|
|
return None
|
|
|
|
async def _save_rl_checkpoint(self) -> str:
|
|
"""Save RL model checkpoint"""
|
|
try:
|
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
|
checkpoint_path = f"models/rl/checkpoint_rl_{timestamp}.pt"
|
|
|
|
# Placeholder for actual model saving
|
|
logger.info(f"Saving RL checkpoint to {checkpoint_path}")
|
|
|
|
return checkpoint_path
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error saving RL checkpoint: {e}")
|
|
return None
|
|
|
|
async def _save_cnn_checkpoint(self) -> str:
|
|
"""Save CNN model checkpoint"""
|
|
try:
|
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
|
checkpoint_path = f"models/cnn/checkpoint_cnn_{timestamp}.pt"
|
|
|
|
# Placeholder for actual model saving
|
|
logger.info(f"Saving CNN checkpoint to {checkpoint_path}")
|
|
|
|
return checkpoint_path
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error saving CNN checkpoint: {e}")
|
|
return None
|
|
|
|
async def _wait_for_shutdown(self):
|
|
"""Wait for shutdown signal"""
|
|
def signal_handler(signum, frame):
|
|
logger.info(f"Received signal {signum}, shutting down...")
|
|
self.shutdown_event.set()
|
|
|
|
signal.signal(signal.SIGINT, signal_handler)
|
|
signal.signal(signal.SIGTERM, signal_handler)
|
|
|
|
# Wait for shutdown event
|
|
while not self.shutdown_event.is_set():
|
|
await asyncio.sleep(1)
|
|
|
|
async def stop(self):
|
|
"""Stop the continuous training system"""
|
|
logger.info("Stopping Continuous Training System...")
|
|
self.running = False
|
|
|
|
try:
|
|
# Stop DataProvider streaming
|
|
if self.data_provider:
|
|
await self.data_provider.stop_real_time_streaming()
|
|
|
|
# Final checkpoint
|
|
logger.info("Creating final checkpoints...")
|
|
await self._save_rl_checkpoint()
|
|
await self._save_cnn_checkpoint()
|
|
|
|
# Log final statistics
|
|
uptime = datetime.now() - self.training_stats['start_time']
|
|
logger.info("=== FINAL TRAINING STATISTICS ===")
|
|
logger.info(f"Total uptime: {uptime}")
|
|
logger.info(f"RL training cycles: {self.training_stats['rl_training_cycles']}")
|
|
logger.info(f"CNN training cycles: {self.training_stats['cnn_training_cycles']}")
|
|
logger.info(f"Perfect moves detected: {self.training_stats['perfect_moves_detected']}")
|
|
logger.info(f"Total ticks processed: {self.training_stats['total_ticks_processed']}")
|
|
logger.info(f"Models saved: {self.training_stats['models_saved']}")
|
|
logger.info("=================================")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error during shutdown: {e}")
|
|
|
|
logger.info("Continuous Training System stopped")
|
|
|
|
async def main():
|
|
"""Main entry point"""
|
|
logger.info("Starting Continuous Full Training System (RL + CNN)")
|
|
|
|
# Create and start the training system
|
|
training_system = ContinuousTrainingSystem()
|
|
|
|
try:
|
|
await training_system.start(run_dashboard=True)
|
|
except KeyboardInterrupt:
|
|
logger.info("Interrupted by user")
|
|
except Exception as e:
|
|
logger.error(f"Fatal error: {e}")
|
|
sys.exit(1)
|
|
|
|
if __name__ == "__main__":
|
|
asyncio.run(main()) |