#!/usr/bin/env python3
"""
Multi-Horizon Trainer

This module trains models on stored prediction snapshots once their outcomes
are known. It handles training across multiple time horizons (1m, 5m, 15m,
60m) and model types (CNN and RL).
"""

import logging
import threading
import time
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch

from .multi_horizon_prediction_manager import PredictionSnapshot
from .prediction_snapshot_storage import PredictionSnapshotStorage

logger = logging.getLogger(__name__)

class MultiHorizonTrainer:
    """Trainer for multi-horizon predictions using stored snapshots"""

    def __init__(self, orchestrator=None, snapshot_storage: Optional[PredictionSnapshotStorage] = None):
        """Initialize the multi-horizon trainer"""
        self.orchestrator = orchestrator
        self.snapshot_storage = snapshot_storage or PredictionSnapshotStorage()

        # Training configuration
        self.batch_size = 32
        self.min_batch_size = 10
        self.training_interval_seconds = 300  # 5 minutes
        self.max_training_age_hours = 24  # Don't train on predictions older than 24 hours

        # Model training settings
        self.learning_rate = 0.001
        self.epochs_per_batch = 5
        self.validation_split = 0.2

        # Training state
        self.training_active = False
        self.training_thread = None
        self.last_training_time = 0.0

        # Performance tracking
        self.training_stats = {
            'total_training_sessions': 0,
            'models_trained': defaultdict(int),
            'training_accuracy': defaultdict(list),
            'loss_history': defaultdict(list),
            'last_training_time': None
        }

        logger.info("MultiHorizonTrainer initialized")

    def start(self):
        """Start the training system"""
        if self.training_active:
            logger.warning("Training system already active")
            return

        self.training_active = True
        self.training_thread = threading.Thread(
            target=self._training_loop,
            daemon=True,
            name="MultiHorizonTrainer"
        )
        self.training_thread.start()
        logger.info("MultiHorizonTrainer started")

    def stop(self):
        """Stop the training system"""
        self.training_active = False
        if self.training_thread and self.training_thread.is_alive():
            self.training_thread.join(timeout=10)
        logger.info("MultiHorizonTrainer stopped")

    def _training_loop(self):
        """Main training loop"""
        while self.training_active:
            try:
                current_time = time.time()

                # Check if it's time for a training session
                if current_time - self.last_training_time >= self.training_interval_seconds:
                    self._run_training_session()
                    self.last_training_time = current_time

                # Sleep before next check
                time.sleep(60)  # Check every minute

            except Exception as e:
                logger.error(f"Error in training loop: {e}")
                time.sleep(300)  # Back off longer on error

    def _run_training_session(self):
        """Run a complete training session"""
        try:
            logger.info("Starting multi-horizon training session")

            training_results = {}

            # Train each horizon/symbol combination separately
            horizons = [1, 5, 15, 60]
            symbols = ['ETH/USDT', 'BTC/USDT']

            for horizon in horizons:
                for symbol in symbols:
                    try:
                        horizon_results = self._train_horizon_models(horizon, symbol)
                        if horizon_results:
                            training_results[f"{horizon}m_{symbol}"] = horizon_results
                    except Exception as e:
                        logger.error(f"Error training {horizon}m models for {symbol}: {e}")

            # Update statistics
            self.training_stats['total_training_sessions'] += 1
            self.training_stats['last_training_time'] = datetime.now()

            if training_results:
                logger.info(f"Training session completed: {len(training_results)} model updates")
                for key, results in training_results.items():
                    logger.info(f"  {key}: {results}")
            else:
                logger.debug("No models were trained in this session")

        except Exception as e:
            logger.error(f"Error in training session: {e}")

    def _train_horizon_models(self, horizon_minutes: int, symbol: str) -> Dict[str, Any]:
        """Train models for a specific horizon and symbol"""
        results = {}

        # Get training batch
        snapshots = self.snapshot_storage.get_training_batch(
            horizon_minutes=horizon_minutes,
            symbol=symbol,
            batch_size=self.batch_size,
            min_confidence=0.3
        )

        if len(snapshots) < self.min_batch_size:
            logger.debug(f"Insufficient training data for {horizon_minutes}m {symbol}: {len(snapshots)} snapshots")
            return results

        logger.info(f"Training {horizon_minutes}m models for {symbol} with {len(snapshots)} snapshots")

        # Train CNN model
        if self.orchestrator and hasattr(self.orchestrator, 'cnn_model'):
            try:
                cnn_results = self._train_cnn_model(snapshots, horizon_minutes, symbol)
                if cnn_results:
                    results['cnn'] = cnn_results
                    # Only count sessions that actually trained (not error results)
                    if 'error' not in cnn_results:
                        self.training_stats['models_trained']['cnn'] += 1
            except Exception as e:
                logger.error(f"CNN training failed for {horizon_minutes}m {symbol}: {e}")

        # Train RL model
        if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
            try:
                rl_results = self._train_rl_model(snapshots, horizon_minutes, symbol)
                if rl_results:
                    results['rl'] = rl_results
                    if 'error' not in rl_results:
                        self.training_stats['models_trained']['rl'] += 1
            except Exception as e:
                logger.error(f"RL training failed for {horizon_minutes}m {symbol}: {e}")

        return results

    def _train_cnn_model(self, snapshots: List[PredictionSnapshot],
                         horizon_minutes: int, symbol: str) -> Optional[Dict[str, Any]]:
        """Train the CNN model using prediction snapshots"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model'):
                return None

            cnn_model = self.orchestrator.cnn_model

            # Prepare training data
            features_list = []
            targets_list = []

            for snapshot in snapshots:
                # Extract CNN features
                features = snapshot.model_inputs.get('cnn_features')
                if features is None:
                    continue

                # Create the target from prediction accuracy; skip snapshots
                # whose outcomes are not yet known
                if snapshot.actual_min_price is not None and snapshot.actual_max_price is not None:
                    # Simple binary target: 1 if the predicted range was
                    # reasonably accurate, 0 otherwise
                    range_overlap = self._calculate_range_overlap(
                        (snapshot.predicted_min_price, snapshot.predicted_max_price),
                        (snapshot.actual_min_price, snapshot.actual_max_price)
                    )
                    target = 1 if range_overlap > 0.3 else 0  # 30% overlap threshold

                    features_list.append(features)
                    targets_list.append(target)

            if len(features_list) < self.min_batch_size:
                return {'error': 'Insufficient training data'}

            # Convert to arrays
            features_array = np.array(features_list, dtype=np.float32)
            targets_array = np.array(targets_list, dtype=np.float32)

            # Split into train/validation (time-ordered, no shuffling)
            split_idx = int(len(features_array) * (1 - self.validation_split))
            train_features = features_array[:split_idx]
            train_targets = targets_array[:split_idx]
            val_features = features_array[split_idx:]
            val_targets = targets_array[split_idx:]

            # Training setup
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            cnn_model.to(device)

            if not hasattr(cnn_model, 'optimizer'):
                cnn_model.optimizer = torch.optim.Adam(cnn_model.parameters(), lr=self.learning_rate)

            criterion = torch.nn.BCELoss()  # Binary classification

            # Build tensors once, outside the epoch loop
            inputs = torch.from_numpy(train_features).to(device)
            targets = torch.from_numpy(train_targets).to(device)
            val_inputs = torch.from_numpy(val_features).to(device) if len(val_features) > 0 else None
            val_targets_tensor = torch.from_numpy(val_targets).to(device) if len(val_targets) > 0 else None

            train_losses = []
            val_accuracies = []

            for epoch in range(self.epochs_per_batch):
                # Training step
                cnn_model.train()
                cnn_model.optimizer.zero_grad()

                # Forward pass; handle models that return a dict of outputs
                outputs = cnn_model(inputs)
                if isinstance(outputs, dict):
                    if 'main_output' in outputs:
                        logits = outputs['main_output']
                    else:
                        logits = next(iter(outputs.values()))
                else:
                    logits = outputs

                # Apply sigmoid for binary classification; squeeze only the
                # trailing dimension so a batch of size 1 keeps its batch dim
                predictions = torch.sigmoid(logits.squeeze(-1))

                loss = criterion(predictions, targets)
                loss.backward()
                cnn_model.optimizer.step()

                train_losses.append(loss.item())

                # Validation step
                if val_inputs is not None:
                    cnn_model.eval()
                    with torch.no_grad():
                        val_outputs = cnn_model(val_inputs)
                        if isinstance(val_outputs, dict):
                            if 'main_output' in val_outputs:
                                val_logits = val_outputs['main_output']
                            else:
                                val_logits = next(iter(val_outputs.values()))
                        else:
                            val_logits = val_outputs

                        val_predictions = torch.sigmoid(val_logits.squeeze(-1))
                        val_binary_preds = (val_predictions > 0.5).float()
                        val_accuracy = (val_binary_preds == val_targets_tensor).float().mean().item()
                        val_accuracies.append(val_accuracy)

            # Calculate final metrics
            avg_train_loss = float(np.mean(train_losses))
            final_val_accuracy = val_accuracies[-1] if val_accuracies else 0.0

            self.training_stats['loss_history']['cnn'].append(avg_train_loss)
            self.training_stats['training_accuracy']['cnn'].append(final_val_accuracy)

            results = {
                'epochs': self.epochs_per_batch,
                'final_loss': avg_train_loss,
                'validation_accuracy': final_val_accuracy,
                'samples_used': len(features_list)
            }

            logger.info(f"CNN training completed: loss={avg_train_loss:.4f}, val_acc={final_val_accuracy:.2f}")
            return results

        except Exception as e:
            logger.error(f"Error training CNN model: {e}")
            return {'error': str(e)}

    def _train_rl_model(self, snapshots: List[PredictionSnapshot],
                        horizon_minutes: int, symbol: str) -> Optional[Dict[str, Any]]:
        """Train the RL model using prediction snapshots"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent'):
                return None

            rl_agent = self.orchestrator.rl_agent

            # Prepare RL training data
            experiences = []

            for snapshot in snapshots:
                # Extract RL state
                state = snapshot.model_inputs.get('rl_state')
                if state is None:
                    continue

                # Derive an action from the min/max prediction: if the predicted
                # range sits mostly above the current price, BUY; mostly below,
                # SELL; otherwise HOLD
                current_price = snapshot.current_price
                range_center = (snapshot.predicted_min_price + snapshot.predicted_max_price) / 2

                if range_center > current_price * 1.002:  # 0.2% threshold
                    action = 0  # BUY
                elif range_center < current_price * 0.998:
                    action = 1  # SELL
                else:
                    action = 2  # HOLD

                # Calculate reward from prediction accuracy; skip snapshots
                # whose outcomes are not yet known
                if snapshot.actual_min_price is None or snapshot.actual_max_price is None:
                    continue

                actual_center = (snapshot.actual_min_price + snapshot.actual_max_price) / 2

                # Reward based on how well we predicted the direction of the move
                predicted_direction = 1 if range_center > current_price else -1 if range_center < current_price else 0
                actual_direction = 1 if actual_center > current_price else -1 if actual_center < current_price else 0

                if predicted_direction == actual_direction:
                    reward = snapshot.confidence  # Positive reward scaled by confidence
                else:
                    reward = -snapshot.confidence  # Negative reward scaled by confidence

                # Additional reward based on range accuracy
                range_overlap = self._calculate_range_overlap(
                    (snapshot.predicted_min_price, snapshot.predicted_max_price),
                    (snapshot.actual_min_price, snapshot.actual_max_price)
                )
                reward += range_overlap * 0.5  # Bonus for an accurate range prediction

                # Create next state (simplified: each snapshot is a one-step episode)
                next_state = state.copy()

                experiences.append((state, action, reward, next_state, True))  # done=True

            if len(experiences) < self.min_batch_size:
                return {'error': 'Insufficient training data'}

            # Add experiences to the RL agent's memory
            experiences_added = 0
            for state, action, reward, next_state, done in experiences:
                try:
                    if hasattr(rl_agent, 'store_experience'):
                        rl_agent.store_experience(
                            state=np.array(state),
                            action=action,
                            reward=reward,
                            next_state=np.array(next_state),
                            done=done
                        )
                        experiences_added += 1
                    elif hasattr(rl_agent, 'remember'):
                        rl_agent.remember(np.array(state), action, reward, np.array(next_state), done)
                        experiences_added += 1
                except Exception as e:
                    logger.debug(f"Error adding RL experience: {e}")

            # Perform training steps
            training_losses = []
            if hasattr(rl_agent, 'replay') and experiences_added > 0:
                try:
                    for _ in range(min(5, experiences_added // 8)):  # Conservative training
                        loss = rl_agent.replay(batch_size=min(32, experiences_added))
                        if loss is not None:
                            training_losses.append(loss)
                except Exception as e:
                    logger.debug(f"RL training step failed: {e}")

            avg_loss = float(np.mean(training_losses)) if training_losses else 0.0

            results = {
                'experiences_added': experiences_added,
                'training_steps': len(training_losses),
                'avg_loss': avg_loss,
                'samples_used': len(experiences)
            }

            logger.info(f"RL training completed: {experiences_added} experiences, avg_loss={avg_loss:.4f}")
            return results

        except Exception as e:
            logger.error(f"Error training RL model: {e}")
            return {'error': str(e)}
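
    # Worked example of the derivation above (illustrative numbers, not from
    # real data): with current_price = 3000.0 and a predicted range of
    # (3010.0, 3030.0), range_center = 3020.0 > 3000.0 * 1.002 = 3006.0, so
    # the action is BUY. If the actual range comes in at (3005.0, 3025.0),
    # actual_center = 3015.0 > 3000.0, the directions match, and with a
    # snapshot confidence of 0.6 the reward is 0.6 plus a range-overlap bonus
    # of ((3025 - 3010) / (3030 - 3005)) * 0.5 = 0.3, for a total of 0.9.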

    def _calculate_range_overlap(self, range1: Tuple[float, float], range2: Tuple[float, float]) -> float:
        """Calculate the intersection-over-union of two price ranges (0.0 to 1.0)"""
        try:
            min1, max1 = range1
            min2, max2 = range2

            # Find the intersection
            overlap_min = max(min1, min2)
            overlap_max = min(max1, max2)

            if overlap_max <= overlap_min:
                return 0.0

            overlap_size = overlap_max - overlap_min
            union_size = max(max1, max2) - min(min1, min2)

            return overlap_size / union_size if union_size > 0 else 0.0

        except Exception:
            return 0.0
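
    # Worked example (illustrative): a predicted range of (100.0, 110.0)
    # against an actual range of (105.0, 115.0) intersects over
    # 110 - 105 = 5 with a union of 115 - 100 = 15, scoring 5 / 15 = 0.33,
    # just above the 0.3 threshold used for the CNN's binary training target.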

    def force_training_session(self, horizon_minutes: Optional[int] = None,
                               symbol: Optional[str] = None) -> Dict[str, Any]:
        """Force a training session for specific parameters"""
        try:
            logger.info(f"Forcing training session: horizon={horizon_minutes}, symbol={symbol}")

            results = {}

            horizons = [horizon_minutes] if horizon_minutes else [1, 5, 15, 60]
            symbols = [symbol] if symbol else ['ETH/USDT', 'BTC/USDT']

            for h in horizons:
                for s in symbols:
                    try:
                        horizon_results = self._train_horizon_models(h, s)
                        if horizon_results:
                            results[f"{h}m_{s}"] = horizon_results
                    except Exception as e:
                        logger.error(f"Error in forced training for {h}m {s}: {e}")

            return results

        except Exception as e:
            logger.error(f"Error in forced training session: {e}")
            return {'error': str(e)}

    def get_training_stats(self) -> Dict[str, Any]:
        """Get training statistics"""
        stats = dict(self.training_stats)
        stats['is_training_active'] = self.training_active

        # Calculate per-model averages
        for model_type in ['cnn', 'rl']:
            if stats['training_accuracy'][model_type]:
                stats[f'{model_type}_avg_accuracy'] = float(np.mean(stats['training_accuracy'][model_type]))
            else:
                stats[f'{model_type}_avg_accuracy'] = 0.0

            if stats['loss_history'][model_type]:
                stats[f'{model_type}_avg_loss'] = float(np.mean(stats['loss_history'][model_type]))
            else:
                stats[f'{model_type}_avg_loss'] = 0.0

        return stats

    def validate_recent_predictions(self):
        """Validate predictions that should have outcomes available"""
        try:
            # Get pending snapshots
            pending_snapshots = self.snapshot_storage.get_pending_validation_snapshots()

            if not pending_snapshots:
                return

            logger.info(f"Validating {len(pending_snapshots)} pending predictions")

            # Group by symbol for efficient data access
            by_symbol = defaultdict(list)
            for snapshot in pending_snapshots:
                by_symbol[snapshot.symbol].append(snapshot)

            # Validate each symbol
            for symbol, snapshots in by_symbol.items():
                try:
                    self._validate_symbol_predictions(symbol, snapshots)
                except Exception as e:
                    logger.error(f"Error validating predictions for {symbol}: {e}")

        except Exception as e:
            logger.error(f"Error validating recent predictions: {e}")

    def _validate_symbol_predictions(self, symbol: str, snapshots: List[PredictionSnapshot]):
        """Validate predictions for a specific symbol"""
        try:
            # Simplified placeholder: a real implementation would query
            # historical data for each prediction horizon and compute the
            # actual min/max prices traded during that window (see the
            # _fetch_actual_range sketch after this class)
            for snapshot in snapshots:
                try:
                    current_time = datetime.now()
                    # Placeholder: reuse the snapshot's entry price as both the
                    # actual min and max (not accurate, but functional)
                    current_price = snapshot.current_price

                    # Update the snapshot with this provisional "outcome"
                    self.snapshot_storage.update_snapshot_outcome(
                        snapshot.prediction_id,
                        current_price,  # actual_min
                        current_price,  # actual_max
                        current_time
                    )

                    logger.debug(f"Validated prediction {snapshot.prediction_id}")

                except Exception as e:
                    logger.error(f"Error validating snapshot {snapshot.prediction_id}: {e}")

        except Exception as e:
            logger.error(f"Error validating symbol predictions for {symbol}: {e}")
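
# The validation above is a placeholder. A fuller implementation would look up
# the realized price extremes over each prediction window. A minimal sketch,
# assuming a data provider that exposes get_ohlcv(symbol, start, end) returning
# candle dicts with 'low' and 'high' keys; that provider and its signature are
# assumptions, not part of this codebase:
def _fetch_actual_range(data_provider, symbol: str,
                        start: datetime, end: datetime) -> Optional[Tuple[float, float]]:
    """Return the (min, max) price traded between start and end, if available."""
    try:
        candles = data_provider.get_ohlcv(symbol, start, end)
        if not candles:
            return None
        return (
            min(candle['low'] for candle in candles),
            max(candle['high'] for candle in candles),
        )
    except Exception:
        return None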
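
# Minimal usage sketch (illustrative): the trainer runs on a daemon thread, so
# a host process only needs to construct, start, and eventually stop it.
# orchestrator=None is valid here; without cnn_model/rl_agent attributes the
# session simply skips model training.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    trainer = MultiHorizonTrainer(orchestrator=None)
    trainer.start()
    try:
        # Trigger one session immediately instead of waiting for the interval
        trainer.force_training_session(horizon_minutes=1, symbol='ETH/USDT')
        print(trainer.get_training_stats())
    finally:
        trainer.stop()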