gogo2/training_runner.py

#!/usr/bin/env python3
"""
Unified Training Runner
CRITICAL POLICY: NO SYNTHETIC DATA ALLOWED
This module MUST ONLY use real market data from exchanges.
NEVER use np.random.*, mock/fake/synthetic data, or placeholder values.
If data is unavailable: return None/0/empty, log errors, raise exceptions.
See: reports/REAL_MARKET_DATA_POLICY.md
Consolidated training system supporting both realtime and backtesting modes.
Modes:
1. REALTIME: Live market data training with continuous learning
2. BACKTEST: Historical data with sliding window simulation for fast training
Features:
- Multi-horizon predictions (1m, 5m, 15m, 60m)
- CNN, DQN, and COB RL model training
- Checkpoint management with model rotation
- Performance tracking and reporting
- Resumable training sessions
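
Usage (mirrors the CLI flags defined in main() below; dates are placeholders):
    python training_runner.py --mode realtime --duration 4
    python training_runner.py --mode backtest --start-date 2025-01-01 --end-date 2025-01-31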
"""
import logging
import time
import json
import argparse
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Any, Optional
from collections import deque

# Core components
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
from core.multi_horizon_backtester import MultiHorizonBacktester
from core.multi_horizon_prediction_manager import MultiHorizonPredictionManager
from core.prediction_snapshot_storage import PredictionSnapshotStorage
from core.multi_horizon_trainer import MultiHorizonTrainer

# Model management
from NN.training.model_manager import create_model_manager
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem

# Setup logging
Path('logs').mkdir(parents=True, exist_ok=True)  # FileHandler requires the directory to exist
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


class UnifiedTrainingRunner:
    """Unified training system supporting both realtime and backtesting modes"""

    def __init__(self, mode: str = "realtime", symbol: str = "ETH/USDT"):
        """
        Initialize the unified training runner

        Args:
            mode: "realtime" for live training or "backtest" for historical training
            symbol: Trading symbol to train on
        """
        self.mode = mode
        self.symbol = symbol
        self.start_time = datetime.now()
        logger.info(f"Initializing Unified Training Runner - Mode: {mode.upper()}")

        # Initialize core components
        self.data_provider = DataProvider()
        self.orchestrator = TradingOrchestrator(
            data_provider=self.data_provider,
            enhanced_rl_training=True
        )

        # Initialize training components
        self.backtester = MultiHorizonBacktester(self.data_provider)
        self.prediction_manager = MultiHorizonPredictionManager(
            data_provider=self.data_provider
        )
        self.snapshot_storage = PredictionSnapshotStorage()
        self.trainer = MultiHorizonTrainer(
            orchestrator=self.orchestrator,
            snapshot_storage=self.snapshot_storage
        )

        # Initialize enhanced real-time training (used in both modes)
        self.enhanced_training = None
        if hasattr(self.orchestrator, 'enhanced_training_system'):
            self.enhanced_training = self.orchestrator.enhanced_training_system

        # Model checkpoint manager
        self.checkpoint_manager = create_model_manager()

        # Training configuration
        self.config = {
            'realtime': {
                'checkpoint_interval_minutes': 30,
                'backtest_interval_minutes': 60,
                'performance_check_minutes': 15
            },
            'backtest': {
                'window_size_hours': 24,
                'step_size_hours': 1,
                'batch_size': 64,
                'save_interval_hours': 2
            }
        }

        # Performance tracking
        self.metrics = {
            'training_sessions': [],
            'backtest_results': [],
            'model_checkpoints': [],
            'prediction_accuracy': deque(maxlen=1000),
            'training_losses': {'cnn': [], 'dqn': [], 'cob_rl': []}
        }

        # Training state
        self.is_running = False
        self.progress_file = Path('training_progress.json')

        logger.info(f"Unified Training Runner initialized for {symbol}")
        logger.info(f"Mode: {mode}, Enhanced Training: {self.enhanced_training is not None}")

    def run_realtime_training(self, duration_hours: Optional[float] = None):
        """
        Run continuous real-time training on live market data

        Args:
            duration_hours: How long to train (None = indefinite)
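
        Example (a minimal sketch; trains for four hours, then stops):
            runner = UnifiedTrainingRunner(mode="realtime")
            runner.run_realtime_training(duration_hours=4.0)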
"""
logger.info("=" * 70)
logger.info("STARTING REALTIME TRAINING")
logger.info("=" * 70)
logger.info(f"Duration: {'indefinite' if duration_hours is None else f'{duration_hours} hours'}")
self.is_running = True
config = self.config['realtime']
last_checkpoint = time.time()
last_backtest = time.time()
last_perf_check = time.time()
try:
# Start enhanced training if available
if self.enhanced_training and hasattr(self.orchestrator, 'start_enhanced_training'):
self.orchestrator.start_enhanced_training()
logger.info("Enhanced real-time training started")
# Start multi-horizon prediction and training
self.prediction_manager.start()
self.trainer.start()
logger.info("Multi-horizon prediction and training started")
while self.is_running:
current_time = time.time()
elapsed_hours = (datetime.now() - self.start_time).total_seconds() / 3600
# Check duration limit
                if duration_hours is not None and elapsed_hours >= duration_hours:
logger.info(f"Training duration completed: {elapsed_hours:.1f} hours")
break
# Periodic checkpoint save
if current_time - last_checkpoint > config['checkpoint_interval_minutes'] * 60:
self._save_checkpoint()
last_checkpoint = current_time
# Periodic backtest validation
if current_time - last_backtest > config['backtest_interval_minutes'] * 60:
accuracy = self._run_backtest_validation()
if accuracy is not None:
self.metrics['prediction_accuracy'].append(accuracy)
logger.info(f"Backtest accuracy at {elapsed_hours:.1f}h: {accuracy:.3%}")
last_backtest = current_time
# Performance check
if current_time - last_perf_check > config['performance_check_minutes'] * 60:
self._log_performance_metrics()
last_perf_check = current_time
# Sleep to reduce CPU usage
time.sleep(60)
except KeyboardInterrupt:
logger.info("Training interrupted by user")
finally:
self._cleanup_training()
self._generate_final_report()

    def run_backtest_training(self, start_date: datetime, end_date: datetime):
        """
        Run fast backtesting with sliding window for bulk training

        Args:
            start_date: Start date for backtesting
            end_date: End date for backtesting
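
        Example (a minimal sketch with placeholder dates):
            runner = UnifiedTrainingRunner(mode="backtest")
            runner.run_backtest_training(
                start_date=datetime(2025, 1, 1),
                end_date=datetime(2025, 1, 8)
            )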
"""
logger.info("=" * 70)
logger.info("STARTING BACKTEST TRAINING")
logger.info("=" * 70)
logger.info(f"Period: {start_date} to {end_date}")
config = self.config['backtest']
window_hours = config['window_size_hours']
step_hours = config['step_size_hours']
current_date = start_date
batch_count = 0
total_samples = 0
try:
while current_date < end_date:
window_end = current_date + timedelta(hours=window_hours)
if window_end > end_date:
break
batch_count += 1
logger.info(f"Batch {batch_count}: {current_date} to {window_end}")
# Fetch historical data for window
data = self._fetch_window_data(current_date, window_end)
if data and len(data) > 0:
# Simulate real-time data flow through sliding window
samples_trained = self._train_on_window(data)
total_samples += samples_trained
logger.info(f"Trained on {samples_trained} samples in window")
# Save checkpoint periodically
elapsed_hours = (window_end - start_date).total_seconds() / 3600
if elapsed_hours % config['save_interval_hours'] == 0:
self._save_checkpoint()
logger.info(f"Checkpoint saved at {elapsed_hours:.1f}h")
# Move window forward
current_date += timedelta(hours=step_hours)
logger.info(f"Backtest training complete: {batch_count} batches, {total_samples} samples")
except Exception as e:
logger.error(f"Error in backtest training: {e}")
raise
finally:
self._generate_final_report()

    def _fetch_window_data(self, start: datetime, end: datetime) -> List[Dict]:
        """Fetch historical data for a time window"""
        try:
            # Fetch real market data from the data provider
            data = self.data_provider.get_historical_data(
                symbol=self.symbol,
                timeframe='1m',
                start_time=start,
                end_time=end
            )
            if data is None or len(data) == 0:
                logger.warning(f"No data available for {start} to {end}")
                return []
            return data
        except Exception as e:
            logger.error(f"Error fetching window data: {e}")
            return []

    def _train_on_window(self, data: List[Dict]) -> int:
        """
        Train models on a sliding window of data

        Args:
            data: List of market data points

        Returns:
            Number of samples trained on
        """
        samples_trained = 0

        # Simulate real-time flow through the data
        for i in range(len(data) - 1):
            current = data[i]

            # Create prediction snapshot
            snapshot = {
                'timestamp': current.get('timestamp'),
                'price': current.get('close', 0),
                'volume': current.get('volume', 0),
                'symbol': self.symbol
            }

            # Store snapshot for later training
            self.snapshot_storage.store_snapshot(snapshot)

            # Once an outcome is available, train the models
            if i > 0:  # Need a previous snapshot for the outcome
                prev_snapshot = data[i - 1]
                outcome = {
                    'actual_price': current.get('close', 0),
                    'timestamp': current.get('timestamp')
                }

                # Train via multi-horizon trainer
                self.trainer.train_on_outcome(prev_snapshot, outcome)
                samples_trained += 1

        return samples_trained

    def _run_backtest_validation(self) -> Optional[float]:
        """Run backtest on recent data to validate model performance"""
        try:
            end_date = datetime.now()
            start_date = end_date - timedelta(hours=24)
            results = self.backtester.run_backtest(
                symbol=self.symbol,
                start_date=start_date,
                end_date=end_date,
                horizons=[1, 5, 15, 60]  # minutes
            )
            if results and 'accuracy' in results:
                return results['accuracy']
            return None
        except Exception as e:
            logger.error(f"Error in backtest validation: {e}")
            return None

    def _save_checkpoint(self):
        """Save model checkpoints with rotation"""
        try:
            checkpoint_data = {
                'timestamp': datetime.now().isoformat(),
                'mode': self.mode,
                'elapsed_hours': (datetime.now() - self.start_time).total_seconds() / 3600,
                'metrics': {
                    'prediction_accuracy': list(self.metrics['prediction_accuracy'])[-10:],
                    'total_training_samples': sum(
                        len(losses) for losses in self.metrics['training_losses'].values()
                    )
                }
            }

            # Use model manager for checkpoint rotation (keeps best 5)
            self.checkpoint_manager.save_checkpoint(
                model=self.orchestrator,
                metadata=checkpoint_data
            )
            self.metrics['model_checkpoints'].append(checkpoint_data)
            logger.info("Checkpoint saved successfully")
        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")

    def _log_performance_metrics(self):
        """Log current performance metrics"""
        elapsed_hours = (datetime.now() - self.start_time).total_seconds() / 3600
        avg_accuracy = 0
        if self.metrics['prediction_accuracy']:
            avg_accuracy = sum(self.metrics['prediction_accuracy']) / len(self.metrics['prediction_accuracy'])

        logger.info("=" * 50)
        logger.info(f"Performance Metrics @ {elapsed_hours:.1f}h")
        logger.info(f"  Avg Prediction Accuracy: {avg_accuracy:.3%}")
        logger.info(f"  Total Checkpoints: {len(self.metrics['model_checkpoints'])}")
        logger.info(f"  CNN Training Samples: {len(self.metrics['training_losses']['cnn'])}")
        logger.info(f"  DQN Training Samples: {len(self.metrics['training_losses']['dqn'])}")
        logger.info("=" * 50)

    def _cleanup_training(self):
        """Clean up training resources"""
        logger.info("Cleaning up training resources...")

        # Stop prediction and training
        if hasattr(self.prediction_manager, 'stop'):
            self.prediction_manager.stop()
        if hasattr(self.trainer, 'stop'):
            self.trainer.stop()

        # Save final checkpoint
        self._save_checkpoint()
        logger.info("Training cleanup complete")

    def _generate_final_report(self):
        """Generate final training report"""
        report = {
            'mode': self.mode,
            'symbol': self.symbol,
            'start_time': self.start_time.isoformat(),
            'end_time': datetime.now().isoformat(),
            'duration_hours': (datetime.now() - self.start_time).total_seconds() / 3600,
            'metrics': {
                'total_checkpoints': len(self.metrics['model_checkpoints']),
                'total_backtest_runs': len(self.metrics['backtest_results']),
                'final_accuracy': list(self.metrics['prediction_accuracy'])[-1] if self.metrics['prediction_accuracy'] else 0,
                'avg_accuracy': sum(self.metrics['prediction_accuracy']) / len(self.metrics['prediction_accuracy']) if self.metrics['prediction_accuracy'] else 0
            }
        }

        report_file = Path(f'training_report_{self.mode}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json')
        with open(report_file, 'w') as f:
            json.dump(report, f, indent=2)

        logger.info("=" * 70)
        logger.info("TRAINING COMPLETE")
        logger.info("=" * 70)
        logger.info(f"Mode: {self.mode}")
        logger.info(f"Duration: {report['duration_hours']:.2f} hours")
        logger.info(f"Final Accuracy: {report['metrics']['final_accuracy']:.3%}")
        logger.info(f"Avg Accuracy: {report['metrics']['avg_accuracy']:.3%}")
        logger.info(f"Report saved to: {report_file}")
        logger.info("=" * 70)


def main():
    """Main entry point for training runner"""
    parser = argparse.ArgumentParser(description="Unified Training Runner")
    parser.add_argument(
        '--mode',
        type=str,
        choices=['realtime', 'backtest'],
        default='realtime',
        help='Training mode: realtime or backtest'
    )
    parser.add_argument(
        '--symbol',
        type=str,
        default='ETH/USDT',
        help='Trading symbol'
    )
    parser.add_argument(
        '--duration',
        type=float,
        default=None,
        help='Training duration in hours (realtime mode only)'
    )
    parser.add_argument(
        '--start-date',
        type=str,
        default=None,
        help='Start date for backtest (YYYY-MM-DD)'
    )
    parser.add_argument(
        '--end-date',
        type=str,
        default=None,
        help='End date for backtest (YYYY-MM-DD)'
    )
    args = parser.parse_args()

    # Create training runner
    runner = UnifiedTrainingRunner(mode=args.mode, symbol=args.symbol)

    if args.mode == 'realtime':
        runner.run_realtime_training(duration_hours=args.duration)
    else:  # backtest
        if not args.start_date or not args.end_date:
            logger.error("Backtest mode requires --start-date and --end-date")
            return
        start = datetime.strptime(args.start_date, '%Y-%m-%d')
        end = datetime.strptime(args.end_date, '%Y-%m-%d')
        runner.run_backtest_training(start_date=start, end_date=end)


if __name__ == '__main__':
    main()