486 lines
18 KiB
Python
486 lines
18 KiB
Python
#!/usr/bin/env python3
"""
Unified Training Runner

CRITICAL POLICY: NO SYNTHETIC DATA ALLOWED
This module MUST ONLY use real market data from exchanges.
NEVER use np.random.*, mock/fake/synthetic data, or placeholder values.
If data is unavailable: return None/0/empty, log errors, raise exceptions.
See: reports/REAL_MARKET_DATA_POLICY.md

Consolidated training system supporting both realtime and backtesting modes.

Modes:
1. REALTIME: Live market data training with continuous learning
2. BACKTEST: Historical data with sliding window simulation for fast training

Features:
- Multi-horizon predictions (1m, 5m, 15m, 60m)
- CNN, DQN, and COB RL model training
- Checkpoint management with model rotation
- Performance tracking and reporting
- Resumable training sessions
"""
|
|
|
|
import logging
|
|
import time
|
|
import json
|
|
import argparse
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import Dict, List, Any, Optional
|
|
from collections import deque
|
|
import asyncio
|
|
|
|
# Core components
|
|
from core.data_provider import DataProvider
|
|
from core.orchestrator import TradingOrchestrator
|
|
from core.multi_horizon_backtester import MultiHorizonBacktester
|
|
from core.multi_horizon_prediction_manager import MultiHorizonPredictionManager
|
|
from core.prediction_snapshot_storage import PredictionSnapshotStorage
|
|
from core.multi_horizon_trainer import MultiHorizonTrainer
|
|
|
|
# Model management
|
|
from NN.training.model_manager import create_model_manager
|
|
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem
|
|
|
|
# Setup logging
# logging.FileHandler opens its target file as soon as it is constructed, so
# the log directory must exist before basicConfig() runs; otherwise importing
# this module fails with FileNotFoundError on a fresh checkout.
Path('logs').mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/training.log'),  # persistent log file
        logging.StreamHandler(),                   # mirror to the console
    ],
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class UnifiedTrainingRunner:
    """Unified training system supporting both realtime and backtesting modes.

    The runner wires together the data provider, orchestrator, multi-horizon
    prediction manager/trainer/backtester and the checkpoint manager, and
    exposes two entry points:

    * ``run_realtime_training`` - a polling loop that periodically saves
      checkpoints, runs a validation backtest and logs metrics.
    * ``run_backtest_training`` - slides a fixed-size window over a historical
      date range and trains on each window.

    Both entry points write a JSON report when they finish.
    """

    def __init__(self, mode: str = "realtime", symbol: str = "ETH/USDT"):
        """
        Initialize the unified training runner.

        Args:
            mode: "realtime" for live training or "backtest" for historical training
            symbol: Trading symbol to train on
        """
        self.mode = mode
        self.symbol = symbol
        # Reference point for every "elapsed hours" figure and for the report.
        self.start_time = datetime.now()

        logger.info(f"Initializing Unified Training Runner - Mode: {mode.upper()}")

        # Initialize core components
        self.data_provider = DataProvider()
        self.orchestrator = TradingOrchestrator(
            data_provider=self.data_provider,
            enhanced_rl_training=True
        )

        # Initialize training components
        self.backtester = MultiHorizonBacktester(self.data_provider)
        self.prediction_manager = MultiHorizonPredictionManager(
            data_provider=self.data_provider
        )
        self.snapshot_storage = PredictionSnapshotStorage()
        self.trainer = MultiHorizonTrainer(
            orchestrator=self.orchestrator,
            snapshot_storage=self.snapshot_storage
        )

        # Enhanced real-time training is optional: it is only used when the
        # orchestrator exposes an `enhanced_training_system` attribute.
        self.enhanced_training = getattr(
            self.orchestrator, 'enhanced_training_system', None
        )

        # Model checkpoint manager
        self.checkpoint_manager = create_model_manager()

        # Training configuration
        self.config = {
            'realtime': {
                'checkpoint_interval_minutes': 30,
                'backtest_interval_minutes': 60,
                'performance_check_minutes': 15
            },
            'backtest': {
                'window_size_hours': 24,
                'step_size_hours': 1,
                'batch_size': 64,
                'save_interval_hours': 2
            }
        }

        # Performance tracking
        self.metrics = {
            'training_sessions': [],
            'backtest_results': [],
            'model_checkpoints': [],
            'prediction_accuracy': deque(maxlen=1000),
            'training_losses': {'cnn': [], 'dqn': [], 'cob_rl': []}
        }

        # Training state
        self.is_running = False
        self.progress_file = Path('training_progress.json')

        logger.info(f"Unified Training Runner initialized for {symbol}")
        logger.info(f"Mode: {mode}, Enhanced Training: {self.enhanced_training is not None}")

    def run_realtime_training(self, duration_hours: Optional[float] = None):
        """
        Run continuous real-time training on live market data.

        Starts the prediction manager and trainer, then polls roughly once a
        minute to save checkpoints, run validation backtests and log metrics at
        the intervals configured under ``self.config['realtime']``. Cleanup and
        the final report always run, including after Ctrl-C.

        Args:
            duration_hours: How long to train, measured from ``self.start_time``
                (None = indefinite). ``0`` stops on the first loop iteration.
        """
        duration_label = (
            'indefinite' if duration_hours is None else f'{duration_hours} hours'
        )
        logger.info("=" * 70)
        logger.info("STARTING REALTIME TRAINING")
        logger.info("=" * 70)
        logger.info(f"Duration: {duration_label}")

        self.is_running = True
        config = self.config['realtime']
        checkpoint_every = config['checkpoint_interval_minutes'] * 60
        backtest_every = config['backtest_interval_minutes'] * 60
        perf_check_every = config['performance_check_minutes'] * 60

        last_checkpoint = last_backtest = last_perf_check = time.time()

        try:
            # Start enhanced training if available
            if self.enhanced_training and hasattr(self.orchestrator, 'start_enhanced_training'):
                self.orchestrator.start_enhanced_training()
                logger.info("Enhanced real-time training started")

            # Start multi-horizon prediction and training
            self.prediction_manager.start()
            self.trainer.start()
            logger.info("Multi-horizon prediction and training started")

            while self.is_running:
                current_time = time.time()
                elapsed_hours = (datetime.now() - self.start_time).total_seconds() / 3600

                # Check duration limit. `is not None` (rather than truthiness)
                # so that an explicit 0 is honoured instead of meaning "forever".
                if duration_hours is not None and elapsed_hours >= duration_hours:
                    logger.info(f"Training duration completed: {elapsed_hours:.1f} hours")
                    break

                # Periodic checkpoint save
                if current_time - last_checkpoint > checkpoint_every:
                    self._save_checkpoint()
                    last_checkpoint = current_time

                # Periodic backtest validation
                if current_time - last_backtest > backtest_every:
                    accuracy = self._run_backtest_validation()
                    if accuracy is not None:
                        self.metrics['prediction_accuracy'].append(accuracy)
                        logger.info(f"Backtest accuracy at {elapsed_hours:.1f}h: {accuracy:.3%}")
                    last_backtest = current_time

                # Performance check
                if current_time - last_perf_check > perf_check_every:
                    self._log_performance_metrics()
                    last_perf_check = current_time

                # Sleep to reduce CPU usage, but never past the requested
                # duration so short runs do not overshoot by up to a minute.
                sleep_seconds = 60.0
                if duration_hours is not None:
                    remaining_seconds = (duration_hours - elapsed_hours) * 3600
                    sleep_seconds = min(sleep_seconds, max(remaining_seconds, 0.0))
                time.sleep(sleep_seconds)

        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
        finally:
            self.is_running = False
            try:
                self._cleanup_training()
            finally:
                # Produce the report even if cleanup itself failed.
                self._generate_final_report()

    def run_backtest_training(self, start_date: datetime, end_date: datetime):
        """
        Run fast backtesting with a sliding window for bulk training.

        The window size, step and checkpoint cadence come from
        ``self.config['backtest']``. Only windows that fit entirely inside
        ``[start_date, end_date]`` are processed. Errors are logged and
        re-raised; the final report is generated either way.

        Args:
            start_date: Start date for backtesting
            end_date: End date for backtesting

        Raises:
            ValueError: If the configured step size is not positive (the
                window would never advance).
        """
        logger.info("=" * 70)
        logger.info("STARTING BACKTEST TRAINING")
        logger.info("=" * 70)
        logger.info(f"Period: {start_date} to {end_date}")

        config = self.config['backtest']
        window = timedelta(hours=config['window_size_hours'])
        step = timedelta(hours=config['step_size_hours'])
        save_interval_hours = config['save_interval_hours']

        if step <= timedelta(0):
            raise ValueError("backtest 'step_size_hours' must be positive")

        current_date = start_date
        batch_count = 0
        total_samples = 0
        # Simulated hours (since start_date) at which the last checkpoint was
        # taken. Comparing against this instead of `elapsed % interval == 0`
        # avoids float-equality pitfalls and does not silently drop a
        # checkpoint when the aligned window happens to have no data.
        last_saved_hours = 0.0

        try:
            while current_date < end_date:
                window_end = current_date + window
                if window_end > end_date:
                    break

                batch_count += 1
                logger.info(f"Batch {batch_count}: {current_date} to {window_end}")

                # Fetch historical data for window
                data = self._fetch_window_data(current_date, window_end)

                # Same emptiness test as _fetch_window_data: avoids relying on
                # the container's truthiness.
                if data is not None and len(data) > 0:
                    # Simulate real-time data flow through sliding window
                    samples_trained = self._train_on_window(data)
                    total_samples += samples_trained
                    logger.info(f"Trained on {samples_trained} samples in window")

                    # Save checkpoint periodically
                    elapsed_hours = (window_end - start_date).total_seconds() / 3600
                    if elapsed_hours - last_saved_hours >= save_interval_hours:
                        self._save_checkpoint()
                        last_saved_hours = elapsed_hours
                        logger.info(f"Checkpoint saved at {elapsed_hours:.1f}h")

                # Move window forward
                current_date += step

            logger.info(f"Backtest training complete: {batch_count} batches, {total_samples} samples")

        except Exception as e:
            logger.error(f"Error in backtest training: {e}")
            raise
        finally:
            self._generate_final_report()

    def _fetch_window_data(self, start: datetime, end: datetime) -> List[Dict]:
        """Fetch 1-minute historical data for ``self.symbol`` in a time window.

        Returns:
            Whatever the data provider returned, or an empty list when it
            returned nothing or raised (the error is logged, not propagated).
        """
        try:
            # Fetch from data provider with real market data
            data = self.data_provider.get_historical_data(
                symbol=self.symbol,
                timeframe='1m',
                start_time=start,
                end_time=end
            )
        except Exception as e:
            logger.error(f"Error fetching window data: {e}")
            return []

        if data is None or len(data) == 0:
            logger.warning(f"No data available for {start} to {end}")
            return []

        return data

    def _train_on_window(self, data: List[Dict]) -> int:
        """
        Train models on a sliding window of data.

        For every row except the last, a snapshot is stored; from the second
        row onward the trainer is also called with the previous row and the
        current row's close price as the outcome.

        Args:
            data: List of market data points (dicts read via 'timestamp',
                'close' and 'volume')

        Returns:
            Number of samples trained on
        """
        samples_trained = 0

        # NOTE(review): the final row is never stored or used as an outcome
        # (range stops at len - 1). Kept as-is; confirm whether intentional.
        for i in range(len(data) - 1):
            current = data[i]

            # Create and store the prediction snapshot for later training
            self.snapshot_storage.store_snapshot({
                'timestamp': current.get('timestamp'),
                'price': current.get('close', 0),
                'volume': current.get('volume', 0),
                'symbol': self.symbol
            })

            if i == 0:
                # No previous row yet, so there is no outcome to train on.
                continue

            outcome = {
                'actual_price': current.get('close', 0),
                'timestamp': current.get('timestamp')
            }

            # Train via multi-horizon trainer.
            # NOTE(review): the raw previous row is passed here, not the
            # snapshot dict built for it above - confirm which the trainer expects.
            self.trainer.train_on_outcome(data[i - 1], outcome)
            samples_trained += 1

        return samples_trained

    def _run_backtest_validation(self) -> Optional[float]:
        """Run a backtest over the last 24 hours to validate model performance.

        A successful run is also recorded in ``self.metrics['backtest_results']``
        so the final report's ``total_backtest_runs`` reflects reality.

        Returns:
            The 'accuracy' entry of the backtester result, or None when the
            result has no such entry or the backtest raised (error is logged).
        """
        try:
            end_date = datetime.now()
            start_date = end_date - timedelta(hours=24)

            results = self.backtester.run_backtest(
                symbol=self.symbol,
                start_date=start_date,
                end_date=end_date,
                horizons=[1, 5, 15, 60]  # minutes
            )

            if not results or 'accuracy' not in results:
                return None

            accuracy = results['accuracy']
            self.metrics['backtest_results'].append({
                'timestamp': end_date.isoformat(),
                'accuracy': accuracy
            })
            return accuracy

        except Exception as e:
            logger.error(f"Error in backtest validation: {e}")
            return None

    def _save_checkpoint(self):
        """Save a model checkpoint via the checkpoint manager.

        Best-effort: any failure is logged and swallowed so that a checkpoint
        problem never aborts training.
        """
        try:
            now = datetime.now()
            checkpoint_data = {
                'timestamp': now.isoformat(),
                'mode': self.mode,
                'elapsed_hours': (now - self.start_time).total_seconds() / 3600,
                'metrics': {
                    # Only the 10 most recent accuracy readings are kept.
                    'prediction_accuracy': list(self.metrics['prediction_accuracy'])[-10:],
                    'total_training_samples': sum(
                        len(losses) for losses in self.metrics['training_losses'].values()
                    )
                }
            }

            # Use model manager for checkpoint rotation (keeps best 5)
            self.checkpoint_manager.save_checkpoint(
                model=self.orchestrator,
                metadata=checkpoint_data
            )

            self.metrics['model_checkpoints'].append(checkpoint_data)
            logger.info("Checkpoint saved successfully")

        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")

    def _log_performance_metrics(self):
        """Log current performance metrics (accuracy, checkpoints, sample counts)."""
        elapsed_hours = (datetime.now() - self.start_time).total_seconds() / 3600

        accuracies = self.metrics['prediction_accuracy']
        avg_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0
        losses = self.metrics['training_losses']

        logger.info("=" * 50)
        logger.info(f"Performance Metrics @ {elapsed_hours:.1f}h")
        logger.info(f"  Avg Prediction Accuracy: {avg_accuracy:.3%}")
        logger.info(f"  Total Checkpoints: {len(self.metrics['model_checkpoints'])}")
        logger.info(f"  CNN Training Samples: {len(losses['cnn'])}")
        logger.info(f"  DQN Training Samples: {len(losses['dqn'])}")
        logger.info(f"  COB RL Training Samples: {len(losses['cob_rl'])}")
        logger.info("=" * 50)

    def _cleanup_training(self):
        """Clean up training resources and save a final checkpoint.

        Each component is stopped independently so that one failing ``stop()``
        cannot prevent the others, or the final checkpoint, from running.
        """
        logger.info("Cleaning up training resources...")

        # Stop prediction and training
        components = (
            ('prediction manager', self.prediction_manager),
            ('trainer', self.trainer),
        )
        for label, component in components:
            stop = getattr(component, 'stop', None)
            if not callable(stop):
                continue
            try:
                stop()
            except Exception as e:
                logger.error(f"Error stopping {label}: {e}")

        # Save final checkpoint
        self._save_checkpoint()

        logger.info("Training cleanup complete")

    def _generate_final_report(self):
        """Write the final training report to a JSON file and log a summary.

        The file is named ``training_report_<mode>_<YYYYmmdd_HHMMSS>.json`` in
        the current working directory. A failure to write it is logged rather
        than raised, because this method runs from ``finally`` blocks and must
        not mask the error that ended training.
        """
        # One timestamp for end time, duration and file name keeps them consistent.
        finished_at = datetime.now()
        accuracies = self.metrics['prediction_accuracy']
        final_accuracy = accuracies[-1] if accuracies else 0
        avg_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0

        report = {
            'mode': self.mode,
            'symbol': self.symbol,
            'start_time': self.start_time.isoformat(),
            'end_time': finished_at.isoformat(),
            'duration_hours': (finished_at - self.start_time).total_seconds() / 3600,
            'metrics': {
                'total_checkpoints': len(self.metrics['model_checkpoints']),
                'total_backtest_runs': len(self.metrics['backtest_results']),
                'final_accuracy': final_accuracy,
                'avg_accuracy': avg_accuracy
            }
        }

        stamp = finished_at.strftime("%Y%m%d_%H%M%S")
        report_file = Path(f'training_report_{self.mode}_{stamp}.json')
        report_saved = True
        try:
            with open(report_file, 'w', encoding='utf-8') as f:
                json.dump(report, f, indent=2)
        except OSError as e:
            report_saved = False
            logger.error(f"Failed to write report {report_file}: {e}")

        logger.info("=" * 70)
        logger.info("TRAINING COMPLETE")
        logger.info("=" * 70)
        logger.info(f"Mode: {self.mode}")
        logger.info(f"Duration: {report['duration_hours']:.2f} hours")
        logger.info(f"Final Accuracy: {report['metrics']['final_accuracy']:.3%}")
        logger.info(f"Avg Accuracy: {report['metrics']['avg_accuracy']:.3%}")
        if report_saved:
            logger.info(f"Report saved to: {report_file}")
        logger.info("=" * 70)
|
|
|
|
|
|
def main():
    """Main entry point for the training runner.

    Parses the command line, validates backtest arguments, and only then
    builds the runner, so a bad invocation fails fast without initialising
    the runner's components. Invalid backtest arguments are logged and the
    function returns without running anything.
    """
    parser = argparse.ArgumentParser(description="Unified Training Runner")
    parser.add_argument(
        '--mode',
        type=str,
        choices=['realtime', 'backtest'],
        default='realtime',
        help='Training mode: realtime or backtest'
    )
    parser.add_argument(
        '--symbol',
        type=str,
        default='ETH/USDT',
        help='Trading symbol'
    )
    parser.add_argument(
        '--duration',
        type=float,
        default=None,
        help='Training duration in hours (realtime mode only)'
    )
    parser.add_argument(
        '--start-date',
        type=str,
        default=None,
        help='Start date for backtest (YYYY-MM-DD)'
    )
    parser.add_argument(
        '--end-date',
        type=str,
        default=None,
        help='End date for backtest (YYYY-MM-DD)'
    )

    args = parser.parse_args()

    start = end = None
    if args.mode == 'backtest':
        if not args.start_date or not args.end_date:
            logger.error("Backtest mode requires --start-date and --end-date")
            return

        try:
            start = datetime.strptime(args.start_date, '%Y-%m-%d')
            end = datetime.strptime(args.end_date, '%Y-%m-%d')
        except ValueError as e:
            # strptime raises ValueError on anything that is not YYYY-MM-DD.
            logger.error(f"Invalid backtest date (expected YYYY-MM-DD): {e}")
            return

        if start >= end:
            logger.error("Backtest --start-date must be earlier than --end-date")
            return

    # Create training runner (deferred until the arguments are known-good)
    runner = UnifiedTrainingRunner(mode=args.mode, symbol=args.symbol)

    if args.mode == 'realtime':
        runner.run_realtime_training(duration_hours=args.duration)
    else:  # backtest
        runner.run_backtest_training(start_date=start, end_date=end)
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|