main cleanup
This commit is contained in:
485
training_runner.py
Normal file
485
training_runner.py
Normal file
@@ -0,0 +1,485 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unified Training Runner
|
||||
|
||||
CRITICAL POLICY: NO SYNTHETIC DATA ALLOWED
|
||||
This module MUST ONLY use real market data from exchanges.
|
||||
NEVER use np.random.*, mock/fake/synthetic data, or placeholder values.
|
||||
If data is unavailable: return None/0/empty, log errors, raise exceptions.
|
||||
See: reports/REAL_MARKET_DATA_POLICY.md
|
||||
|
||||
Consolidated training system supporting both realtime and backtesting modes.
|
||||
|
||||
Modes:
|
||||
1. REALTIME: Live market data training with continuous learning
|
||||
2. BACKTEST: Historical data with sliding window simulation for fast training
|
||||
|
||||
Features:
|
||||
- Multi-horizon predictions (1m, 5m, 15m, 60m)
|
||||
- CNN, DQN, and COB RL model training
|
||||
- Checkpoint management with model rotation
|
||||
- Performance tracking and reporting
|
||||
- Resumable training sessions
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
import argparse
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any, Optional
|
||||
from collections import deque
|
||||
import asyncio
|
||||
|
||||
# Core components
|
||||
from core.data_provider import DataProvider
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.multi_horizon_backtester import MultiHorizonBacktester
|
||||
from core.multi_horizon_prediction_manager import MultiHorizonPredictionManager
|
||||
from core.prediction_snapshot_storage import PredictionSnapshotStorage
|
||||
from core.multi_horizon_trainer import MultiHorizonTrainer
|
||||
|
||||
# Model management
|
||||
from NN.training.model_manager import create_model_manager
|
||||
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem
|
||||
|
||||
# Setup logging
|
||||
# Setup logging.
# The FileHandler opens 'logs/training.log' as soon as it is constructed, so
# the directory must exist first; otherwise importing this module fails with
# FileNotFoundError on a fresh checkout.
Path('logs').mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/training.log'),
        logging.StreamHandler()
    ]
)

# Module-level logger used by every function and method in this file.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnifiedTrainingRunner:
    """Unified training system supporting both realtime and backtesting modes.

    The runner wires together a data provider, an orchestrator, a multi-horizon
    backtester / prediction manager / trainer and a checkpoint manager, and
    exposes two entry points:

    * ``run_realtime_training`` - polling loop that periodically checkpoints,
      validates via backtest and logs metrics.
    * ``run_backtest_training`` - slides a fixed-size window over a historical
      date range and trains on each window.

    Both entry points write a JSON report to the current working directory
    when they finish.
    """

    def __init__(self, mode: str = "realtime", symbol: str = "ETH/USDT"):
        """
        Initialize the unified training runner

        Args:
            mode: "realtime" for live training or "backtest" for historical training
            symbol: Trading symbol to train on
        """
        self.mode = mode
        self.symbol = symbol
        # Wall-clock reference used for all "elapsed hours" figures and the
        # final report; note that it is taken at construction time, not when
        # a run_* method is called.
        self.start_time = datetime.now()

        logger.info(f"Initializing Unified Training Runner - Mode: {mode.upper()}")

        # Initialize core components
        self.data_provider = DataProvider()
        self.orchestrator = TradingOrchestrator(
            data_provider=self.data_provider,
            enhanced_rl_training=True
        )

        # Initialize training components
        self.backtester = MultiHorizonBacktester(self.data_provider)
        self.prediction_manager = MultiHorizonPredictionManager(
            data_provider=self.data_provider
        )
        self.snapshot_storage = PredictionSnapshotStorage()
        self.trainer = MultiHorizonTrainer(
            orchestrator=self.orchestrator,
            snapshot_storage=self.snapshot_storage
        )

        # Enhanced real-time training is optional: it is only picked up when
        # the orchestrator exposes an 'enhanced_training_system' attribute.
        self.enhanced_training = None
        if hasattr(self.orchestrator, 'enhanced_training_system'):
            self.enhanced_training = self.orchestrator.enhanced_training_system

        # Model checkpoint manager
        self.checkpoint_manager = create_model_manager()

        # Training configuration (intervals are wall-clock for realtime mode
        # and simulated hours for backtest mode).
        self.config = {
            'realtime': {
                'checkpoint_interval_minutes': 30,
                'backtest_interval_minutes': 60,
                'performance_check_minutes': 15
            },
            'backtest': {
                'window_size_hours': 24,
                'step_size_hours': 1,
                'batch_size': 64,
                'save_interval_hours': 2
            }
        }

        # Performance tracking
        self.metrics = {
            'training_sessions': [],
            'backtest_results': [],
            'model_checkpoints': [],
            'prediction_accuracy': deque(maxlen=1000),
            'training_losses': {'cnn': [], 'dqn': [], 'cob_rl': []}
        }

        # Training state
        self.is_running = False
        # NOTE(review): progress_file is assigned here but not read or written
        # anywhere in this class.
        self.progress_file = Path('training_progress.json')

        logger.info(f"Unified Training Runner initialized for {symbol}")
        logger.info(f"Mode: {mode}, Enhanced Training: {self.enhanced_training is not None}")

    def run_realtime_training(self, duration_hours: Optional[float] = None):
        """
        Run continuous real-time training on live market data

        The loop wakes up at most every 60 seconds and, depending on the
        configured intervals, saves a checkpoint, runs a validation backtest
        and logs performance metrics. Cleanup and the final report always run,
        including after Ctrl-C.

        Args:
            duration_hours: How long to train (None = indefinite). A value of
                0 stops after the startup phase instead of running forever.
        """
        logger.info("=" * 70)
        logger.info("STARTING REALTIME TRAINING")
        logger.info("=" * 70)
        logger.info(f"Duration: {'indefinite' if duration_hours is None else f'{duration_hours} hours'}")

        self.is_running = True
        config = self.config['realtime']

        # Convert the configured intervals to seconds once, outside the loop.
        checkpoint_interval = config['checkpoint_interval_minutes'] * 60
        backtest_interval = config['backtest_interval_minutes'] * 60
        perf_check_interval = config['performance_check_minutes'] * 60
        poll_seconds = 60

        last_checkpoint = time.time()
        last_backtest = time.time()
        last_perf_check = time.time()

        try:
            # Start enhanced training if available
            if self.enhanced_training and hasattr(self.orchestrator, 'start_enhanced_training'):
                self.orchestrator.start_enhanced_training()
                logger.info("Enhanced real-time training started")

            # Start multi-horizon prediction and training
            self.prediction_manager.start()
            self.trainer.start()
            logger.info("Multi-horizon prediction and training started")

            while self.is_running:
                current_time = time.time()
                elapsed_hours = (datetime.now() - self.start_time).total_seconds() / 3600

                # Check duration limit. Compare against None explicitly so a
                # requested duration of 0 is honoured rather than being
                # treated as "indefinite".
                if duration_hours is not None and elapsed_hours >= duration_hours:
                    logger.info(f"Training duration completed: {elapsed_hours:.1f} hours")
                    break

                # Periodic checkpoint save
                if current_time - last_checkpoint > checkpoint_interval:
                    self._save_checkpoint()
                    last_checkpoint = current_time

                # Periodic backtest validation
                if current_time - last_backtest > backtest_interval:
                    accuracy = self._run_backtest_validation()
                    if accuracy is not None:
                        self.metrics['prediction_accuracy'].append(accuracy)
                        # Record the run so 'total_backtest_runs' in the final
                        # report reflects what actually happened.
                        self.metrics['backtest_results'].append({
                            'timestamp': datetime.now().isoformat(),
                            'elapsed_hours': elapsed_hours,
                            'accuracy': accuracy
                        })
                        logger.info(f"Backtest accuracy at {elapsed_hours:.1f}h: {accuracy:.3%}")
                    last_backtest = current_time

                # Performance check
                if current_time - last_perf_check > perf_check_interval:
                    self._log_performance_metrics()
                    last_perf_check = current_time

                # Sleep to reduce CPU usage, but never past the requested end
                # of the session.
                sleep_seconds = poll_seconds
                if duration_hours is not None:
                    remaining_seconds = (duration_hours - elapsed_hours) * 3600
                    sleep_seconds = max(0.0, min(poll_seconds, remaining_seconds))
                time.sleep(sleep_seconds)

        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
        finally:
            self.is_running = False
            # The report must be produced even if cleanup itself fails.
            try:
                self._cleanup_training()
            finally:
                self._generate_final_report()

    def run_backtest_training(self, start_date: datetime, end_date: datetime):
        """
        Run fast backtesting with sliding window for bulk training

        A window of ``window_size_hours`` is moved forward by
        ``step_size_hours`` until it would extend past ``end_date``. Windows
        without data are skipped. Any exception is logged and re-raised; the
        final report is written either way.

        Args:
            start_date: Start date for backtesting
            end_date: End date for backtesting
        """
        logger.info("=" * 70)
        logger.info("STARTING BACKTEST TRAINING")
        logger.info("=" * 70)
        logger.info(f"Period: {start_date} to {end_date}")

        config = self.config['backtest']
        window_hours = config['window_size_hours']
        step_hours = config['step_size_hours']

        current_date = start_date
        batch_count = 0
        total_samples = 0

        try:
            while current_date < end_date:
                window_end = current_date + timedelta(hours=window_hours)

                # Only full windows are trained on.
                if window_end > end_date:
                    break

                batch_count += 1
                logger.info(f"Batch {batch_count}: {current_date} to {window_end}")

                # Fetch historical data for window
                data = self._fetch_window_data(current_date, window_end)

                # Use the same emptiness test as _fetch_window_data. Plain
                # truthiness ('if data and ...') raises for array-like
                # payloads such as pandas DataFrames or numpy arrays.
                if data is not None and len(data) > 0:
                    # Simulate real-time data flow through sliding window
                    samples_trained = self._train_on_window(data)
                    total_samples += samples_trained

                    logger.info(f"Trained on {samples_trained} samples in window")

                # Save checkpoint periodically. This relies on elapsed_hours
                # landing exactly on a multiple of the save interval, which
                # holds for the whole-hour window/step sizes configured above.
                elapsed_hours = (window_end - start_date).total_seconds() / 3600
                if elapsed_hours % config['save_interval_hours'] == 0:
                    self._save_checkpoint()
                    logger.info(f"Checkpoint saved at {elapsed_hours:.1f}h")

                # Move window forward
                current_date += timedelta(hours=step_hours)

            logger.info(f"Backtest training complete: {batch_count} batches, {total_samples} samples")

        except Exception as e:
            logger.error(f"Error in backtest training: {e}")
            raise
        finally:
            self._generate_final_report()

    def _fetch_window_data(self, start: datetime, end: datetime) -> List[Dict]:
        """Fetch historical 1-minute data for a time window.

        Args:
            start: Start of the window.
            end: End of the window.

        Returns:
            Whatever the data provider returned for the window, or an empty
            list when nothing is available or the provider raised (the error
            is logged, never propagated).
        """
        try:
            # Fetch from data provider with real market data
            data = self.data_provider.get_historical_data(
                symbol=self.symbol,
                timeframe='1m',
                start_time=start,
                end_time=end
            )

            if data is None or len(data) == 0:
                logger.warning(f"No data available for {start} to {end}")
                return []

            return data

        except Exception as e:
            logger.error(f"Error fetching window data: {e}")
            return []

    def _train_on_window(self, data: List[Dict]) -> int:
        """
        Train models on a sliding window of data

        Every row except the last is stored as a snapshot; from the second row
        onwards the row's close price is used as the realised outcome for the
        row before it.

        Args:
            data: List of market data points (dicts read via 'timestamp',
                'close' and 'volume')

        Returns:
            Number of samples trained on
        """
        samples_trained = 0

        # Simulate real-time flow through data.
        # NOTE(review): the final row is never visited (neither stored nor
        # used as an outcome) - confirm this is intended.
        for i in range(len(data) - 1):
            current = data[i]

            # Create prediction snapshot
            snapshot = {
                'timestamp': current.get('timestamp'),
                'price': current.get('close', 0),
                'volume': current.get('volume', 0),
                'symbol': self.symbol
            }

            # Store snapshot for later training
            self.snapshot_storage.store_snapshot(snapshot)

            # When we have outcome, train the models
            if i > 0:  # Need previous snapshot for outcome
                # NOTE(review): this is the raw previous data row, not the
                # snapshot dict built above - verify what the trainer expects.
                prev_snapshot = data[i - 1]
                outcome = {
                    'actual_price': current.get('close', 0),
                    'timestamp': current.get('timestamp')
                }

                # Train via multi-horizon trainer
                self.trainer.train_on_outcome(prev_snapshot, outcome)
                samples_trained += 1

        return samples_trained

    def _run_backtest_validation(self) -> Optional[float]:
        """Run backtest on the last 24 hours to validate model performance.

        Returns:
            The 'accuracy' entry of the backtester result, or None when the
            result is empty, lacks that key, or the backtest raised (the error
            is logged).
        """
        try:
            end_date = datetime.now()
            start_date = end_date - timedelta(hours=24)

            results = self.backtester.run_backtest(
                symbol=self.symbol,
                start_date=start_date,
                end_date=end_date,
                horizons=[1, 5, 15, 60]  # minutes
            )

            if results and 'accuracy' in results:
                return results['accuracy']

            return None

        except Exception as e:
            logger.error(f"Error in backtest validation: {e}")
            return None

    def _save_checkpoint(self):
        """Save a model checkpoint together with a small metrics summary.

        Failures are logged and swallowed so that a checkpoint problem never
        aborts a training run.
        """
        try:
            now = datetime.now()
            checkpoint_data = {
                'timestamp': now.isoformat(),
                'mode': self.mode,
                'elapsed_hours': (now - self.start_time).total_seconds() / 3600,
                'metrics': {
                    # Only the ten most recent accuracy values are persisted.
                    'prediction_accuracy': list(self.metrics['prediction_accuracy'])[-10:],
                    'total_training_samples': sum(
                        len(losses) for losses in self.metrics['training_losses'].values()
                    )
                }
            }

            # Use model manager for checkpoint rotation (the original author
            # notes it keeps the best 5 - not verifiable from this file).
            self.checkpoint_manager.save_checkpoint(
                model=self.orchestrator,
                metadata=checkpoint_data
            )

            self.metrics['model_checkpoints'].append(checkpoint_data)
            logger.info("Checkpoint saved successfully")

        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")

    def _log_performance_metrics(self):
        """Log current performance metrics"""
        elapsed_hours = (datetime.now() - self.start_time).total_seconds() / 3600

        avg_accuracy = 0
        if self.metrics['prediction_accuracy']:
            avg_accuracy = sum(self.metrics['prediction_accuracy']) / len(self.metrics['prediction_accuracy'])

        logger.info("=" * 50)
        logger.info(f"Performance Metrics @ {elapsed_hours:.1f}h")
        logger.info(f"  Avg Prediction Accuracy: {avg_accuracy:.3%}")
        logger.info(f"  Total Checkpoints: {len(self.metrics['model_checkpoints'])}")
        logger.info(f"  CNN Training Samples: {len(self.metrics['training_losses']['cnn'])}")
        logger.info(f"  DQN Training Samples: {len(self.metrics['training_losses']['dqn'])}")
        logger.info("=" * 50)

    def _cleanup_training(self):
        """Clean up training resources.

        Stops the prediction manager and the trainer (when they expose a
        ``stop`` method) and saves a final checkpoint. A failure while
        stopping one component is logged and does not prevent the remaining
        steps from running.
        """
        logger.info("Cleaning up training resources...")

        # Stop prediction and training
        components = (
            ('prediction manager', self.prediction_manager),
            ('trainer', self.trainer),
        )
        for label, component in components:
            stop = getattr(component, 'stop', None)
            if stop is None:
                continue
            try:
                stop()
            except Exception:
                logger.exception(f"Error stopping {label}")

        # Save final checkpoint
        self._save_checkpoint()

        logger.info("Training cleanup complete")

    def _generate_final_report(self):
        """Generate final training report.

        Writes ``training_report_<mode>_<timestamp>.json`` to the current
        working directory and logs a summary. This method runs inside
        ``finally`` blocks, so a failure to write the file is logged instead
        of raised to avoid masking the exception that ended the run.
        """
        finished_at = datetime.now()
        accuracies = list(self.metrics['prediction_accuracy'])

        # float() keeps the report JSON-serialisable even if the backtester
        # hands back a non-builtin numeric type.
        final_accuracy = float(accuracies[-1]) if accuracies else 0
        avg_accuracy = float(sum(accuracies) / len(accuracies)) if accuracies else 0

        report = {
            'mode': self.mode,
            'symbol': self.symbol,
            'start_time': self.start_time.isoformat(),
            'end_time': finished_at.isoformat(),
            'duration_hours': (finished_at - self.start_time).total_seconds() / 3600,
            'metrics': {
                'total_checkpoints': len(self.metrics['model_checkpoints']),
                'total_backtest_runs': len(self.metrics['backtest_results']),
                'final_accuracy': final_accuracy,
                'avg_accuracy': avg_accuracy
            }
        }

        report_file = Path(f'training_report_{self.mode}_{finished_at.strftime("%Y%m%d_%H%M%S")}.json')
        report_saved = True
        try:
            with open(report_file, 'w', encoding='utf-8') as f:
                json.dump(report, f, indent=2)
        except OSError:
            report_saved = False
            logger.exception(f"Could not write report to: {report_file}")

        logger.info("=" * 70)
        logger.info("TRAINING COMPLETE")
        logger.info("=" * 70)
        logger.info(f"Mode: {self.mode}")
        logger.info(f"Duration: {report['duration_hours']:.2f} hours")
        logger.info(f"Final Accuracy: {report['metrics']['final_accuracy']:.3%}")
        logger.info(f"Avg Accuracy: {report['metrics']['avg_accuracy']:.3%}")
        if report_saved:
            logger.info(f"Report saved to: {report_file}")
        logger.info("=" * 70)
|
||||
|
||||
|
||||
def main():
    """Main entry point for training runner.

    Parses the command line, validates backtest dates, then builds a
    ``UnifiedTrainingRunner`` and dispatches to the realtime or backtest
    entry point. Argument problems are logged and the function returns
    without constructing the runner.
    """
    parser = argparse.ArgumentParser(description="Unified Training Runner")
    parser.add_argument(
        '--mode',
        type=str,
        choices=['realtime', 'backtest'],
        default='realtime',
        help='Training mode: realtime or backtest'
    )
    parser.add_argument(
        '--symbol',
        type=str,
        default='ETH/USDT',
        help='Trading symbol'
    )
    parser.add_argument(
        '--duration',
        type=float,
        default=None,
        help='Training duration in hours (realtime mode only)'
    )
    parser.add_argument(
        '--start-date',
        type=str,
        default=None,
        help='Start date for backtest (YYYY-MM-DD)'
    )
    parser.add_argument(
        '--end-date',
        type=str,
        default=None,
        help='End date for backtest (YYYY-MM-DD)'
    )

    args = parser.parse_args()

    # Validate backtest arguments before building the runner: constructing it
    # instantiates the data provider, orchestrator and model manager, which is
    # wasted work if the command line is unusable.
    start = None
    end = None
    if args.mode == 'backtest':
        if not args.start_date or not args.end_date:
            logger.error("Backtest mode requires --start-date and --end-date")
            return

        try:
            start = datetime.strptime(args.start_date, '%Y-%m-%d')
            end = datetime.strptime(args.end_date, '%Y-%m-%d')
        except ValueError as e:
            logger.error(f"Invalid backtest date (expected YYYY-MM-DD): {e}")
            return

        # An empty or inverted range would otherwise run zero batches silently.
        if start >= end:
            logger.error("--start-date must be earlier than --end-date")
            return

    # Create training runner
    runner = UnifiedTrainingRunner(mode=args.mode, symbol=args.symbol)

    if args.mode == 'realtime':
        runner.run_realtime_training(duration_hours=args.duration)
    else:  # backtest
        runner.run_backtest_training(start_date=start, end_date=end)


if __name__ == '__main__':
    main()
|
||||
Reference in New Issue
Block a user