main cleanup

Dobromir Popov
2025-09-30 23:56:36 +03:00
parent 468a2c2a66
commit 608da8233f
52 changed files with 5308 additions and 9985 deletions


@@ -0,0 +1,536 @@
#!/usr/bin/env python3
"""
Multi-Horizon Trainer
This module trains models using stored prediction snapshots when outcomes are known.
It handles training for different time horizons and model types.
"""
import logging
import threading
import time
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple
import numpy as np
import torch
from collections import defaultdict
from .prediction_snapshot_storage import PredictionSnapshotStorage
from .multi_horizon_prediction_manager import PredictionSnapshot
logger = logging.getLogger(__name__)
class MultiHorizonTrainer:
"""Trainer for multi-horizon predictions using stored snapshots"""
def __init__(self, orchestrator=None, snapshot_storage: Optional[PredictionSnapshotStorage] = None):
"""Initialize the multi-horizon trainer"""
self.orchestrator = orchestrator
self.snapshot_storage = snapshot_storage or PredictionSnapshotStorage()
# Training configuration
self.batch_size = 32
self.min_batch_size = 10
self.training_interval_seconds = 300 # 5 minutes
self.max_training_age_hours = 24 # Don't train on predictions older than 24 hours
# Model training settings
self.learning_rate = 0.001
self.epochs_per_batch = 5
self.validation_split = 0.2
# Training state
self.training_active = False
self.training_thread = None
self.last_training_time = 0.0
# Performance tracking
self.training_stats = {
'total_training_sessions': 0,
'models_trained': defaultdict(int),
'training_accuracy': defaultdict(list),
'loss_history': defaultdict(list),
'last_training_time': None
}
logger.info("MultiHorizonTrainer initialized")
def start(self):
"""Start the training system"""
if self.training_active:
logger.warning("Training system already active")
return
self.training_active = True
self.training_thread = threading.Thread(
target=self._training_loop,
daemon=True,
name="MultiHorizonTrainer"
)
self.training_thread.start()
logger.info("MultiHorizonTrainer started")
def stop(self):
"""Stop the training system"""
self.training_active = False
if self.training_thread and self.training_thread.is_alive():
self.training_thread.join(timeout=10)
logger.info("MultiHorizonTrainer stopped")
def _training_loop(self):
"""Main training loop"""
while self.training_active:
try:
current_time = time.time()
# Check if it's time for training
if current_time - self.last_training_time >= self.training_interval_seconds:
self._run_training_session()
self.last_training_time = current_time
# Sleep before next check
time.sleep(60) # Check every minute
except Exception as e:
logger.error(f"Error in training loop: {e}")
time.sleep(300) # Longer sleep on error
def _run_training_session(self):
"""Run a complete training session"""
try:
logger.info("Starting multi-horizon training session")
training_results = {}
# Train each horizon separately
            horizons = [1, 5, 15, 60]  # horizon lengths in minutes
            symbols = ['ETH/USDT', 'BTC/USDT']
for horizon in horizons:
for symbol in symbols:
try:
horizon_results = self._train_horizon_models(horizon, symbol)
if horizon_results:
training_results[f"{horizon}m_{symbol}"] = horizon_results
except Exception as e:
logger.error(f"Error training {horizon}m models for {symbol}: {e}")
# Update statistics
self.training_stats['total_training_sessions'] += 1
self.training_stats['last_training_time'] = datetime.now()
if training_results:
logger.info(f"Training session completed: {len(training_results)} model updates")
for key, results in training_results.items():
logger.info(f" {key}: {results}")
else:
logger.debug("No models were trained in this session")
except Exception as e:
logger.error(f"Error in training session: {e}")
def _train_horizon_models(self, horizon_minutes: int, symbol: str) -> Dict[str, Any]:
"""Train models for a specific horizon and symbol"""
results = {}
# Get training batch
snapshots = self.snapshot_storage.get_training_batch(
horizon_minutes=horizon_minutes,
symbol=symbol,
batch_size=self.batch_size,
min_confidence=0.3
)
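        # min_confidence=0.3 keeps low-conviction snapshots out of the batch so
        # training isn't dominated by noisy predictions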
if len(snapshots) < self.min_batch_size:
logger.debug(f"Insufficient training data for {horizon_minutes}m {symbol}: {len(snapshots)} snapshots")
return results
logger.info(f"Training {horizon_minutes}m models for {symbol} with {len(snapshots)} snapshots")
# Train CNN model
if self.orchestrator and hasattr(self.orchestrator, 'cnn_model'):
try:
cnn_results = self._train_cnn_model(snapshots, horizon_minutes, symbol)
if cnn_results:
results['cnn'] = cnn_results
self.training_stats['models_trained']['cnn'] += 1
except Exception as e:
logger.error(f"CNN training failed for {horizon_minutes}m {symbol}: {e}")
# Train RL model
if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
try:
rl_results = self._train_rl_model(snapshots, horizon_minutes, symbol)
if rl_results:
results['rl'] = rl_results
self.training_stats['models_trained']['rl'] += 1
except Exception as e:
logger.error(f"RL training failed for {horizon_minutes}m {symbol}: {e}")
return results
    def _train_cnn_model(self, snapshots: List[PredictionSnapshot],
                         horizon_minutes: int, symbol: str) -> Optional[Dict[str, Any]]:
"""Train CNN model using prediction snapshots"""
try:
if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model'):
return None
cnn_model = self.orchestrator.cnn_model
# Prepare training data
features_list = []
targets_list = []
for snapshot in snapshots:
# Extract CNN features
features = snapshot.model_inputs.get('cnn_features')
if features is None:
continue
                # Create a binary target from prediction accuracy:
                # 1 if the predicted and actual ranges overlapped enough, 0 otherwise
                if snapshot.actual_min_price is not None and snapshot.actual_max_price is not None:
                    range_overlap = self._calculate_range_overlap(
                        (snapshot.predicted_min_price, snapshot.predicted_max_price),
                        (snapshot.actual_min_price, snapshot.actual_max_price)
                    )
                    target = 1 if range_overlap > 0.3 else 0  # 30% overlap threshold
                    features_list.append(features)
                    targets_list.append(target)
if len(features_list) < self.min_batch_size:
return {'error': 'Insufficient training data'}
# Convert to tensors
features_array = np.array(features_list, dtype=np.float32)
targets_array = np.array(targets_list, dtype=np.float32)
# Split into train/validation
split_idx = int(len(features_array) * (1 - self.validation_split))
train_features = features_array[:split_idx]
train_targets = targets_array[:split_idx]
val_features = features_array[split_idx:]
val_targets = targets_array[split_idx:]
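            # Positional split: assumes get_training_batch returns snapshots in a
            # stable (roughly chronological) order, so the tail acts as a holdout set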
# Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn_model.to(device)
if not hasattr(cnn_model, 'optimizer'):
cnn_model.optimizer = torch.optim.Adam(cnn_model.parameters(), lr=self.learning_rate)
criterion = torch.nn.BCELoss() # Binary classification
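            # Note: torch.nn.BCEWithLogitsLoss on the raw logits would be more
            # numerically stable than sigmoid followed by BCELoss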
train_losses = []
val_accuracies = []
for epoch in range(self.epochs_per_batch):
# Training step
cnn_model.train()
cnn_model.optimizer.zero_grad()
# Forward pass
inputs = torch.FloatTensor(train_features).to(device)
targets = torch.FloatTensor(train_targets).to(device)
# Handle different model outputs
outputs = cnn_model(inputs)
if isinstance(outputs, dict):
if 'main_output' in outputs:
logits = outputs['main_output']
else:
logits = list(outputs.values())[0]
else:
logits = outputs
                # Apply sigmoid for binary classification; squeeze only the last dim
                # so a batch of size 1 keeps its batch axis
                predictions = torch.sigmoid(logits.squeeze(-1))
                loss = criterion(predictions, targets)
loss.backward()
cnn_model.optimizer.step()
train_losses.append(loss.item())
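                # Each epoch runs the entire training split as a single batch; with
                # batch_size capped at 32 snapshots this stays cheap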
# Validation step
if len(val_features) > 0:
cnn_model.eval()
with torch.no_grad():
val_inputs = torch.FloatTensor(val_features).to(device)
val_targets_tensor = torch.FloatTensor(val_targets).to(device)
val_outputs = cnn_model(val_inputs)
if isinstance(val_outputs, dict):
if 'main_output' in val_outputs:
val_logits = val_outputs['main_output']
else:
val_logits = list(val_outputs.values())[0]
else:
val_logits = val_outputs
                        val_predictions = torch.sigmoid(val_logits.squeeze(-1))
val_binary_preds = (val_predictions > 0.5).float()
val_accuracy = (val_binary_preds == val_targets_tensor).float().mean().item()
val_accuracies.append(val_accuracy)
# Calculate final metrics
avg_train_loss = np.mean(train_losses)
final_val_accuracy = val_accuracies[-1] if val_accuracies else 0.0
self.training_stats['loss_history']['cnn'].append(avg_train_loss)
self.training_stats['training_accuracy']['cnn'].append(final_val_accuracy)
results = {
'epochs': self.epochs_per_batch,
'final_loss': avg_train_loss,
'validation_accuracy': final_val_accuracy,
'samples_used': len(features_list)
}
logger.info(f"CNN training completed: loss={avg_train_loss:.4f}, val_acc={final_val_accuracy:.2f}")
return results
except Exception as e:
logger.error(f"Error training CNN model: {e}")
return {'error': str(e)}
    def _train_rl_model(self, snapshots: List[PredictionSnapshot],
                        horizon_minutes: int, symbol: str) -> Optional[Dict[str, Any]]:
"""Train RL model using prediction snapshots"""
try:
if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent'):
return None
rl_agent = self.orchestrator.rl_agent
# Prepare RL training data
experiences = []
for snapshot in snapshots:
# Extract RL state
state = snapshot.model_inputs.get('rl_state')
if state is None:
continue
                # Derive an action from the predicted direction: if the center of the
                # predicted range sits above the current price, BUY; below, SELL; else HOLD
                current_price = snapshot.current_price
                range_center = (snapshot.predicted_min_price + snapshot.predicted_max_price) / 2
if range_center > current_price * 1.002: # 0.2% threshold
action = 0 # BUY
elif range_center < current_price * 0.998:
action = 1 # SELL
else:
action = 2 # HOLD
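                # e.g. with current_price = 3000 the BUY threshold is 3006.0 (0.2% above)
                # and the SELL threshold is 2994.0, so a range center of 3005 maps to HOLD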
                # Calculate reward from prediction accuracy; skip snapshots without outcomes
                # so `reward` is never used unbound
                if snapshot.actual_min_price is None or snapshot.actual_max_price is None:
                    continue
                actual_center = (snapshot.actual_min_price + snapshot.actual_max_price) / 2
                # Reward the predicted direction of the move, scaled by confidence
                predicted_direction = 1 if range_center > current_price else -1 if range_center < current_price else 0
                actual_direction = 1 if actual_center > current_price else -1 if actual_center < current_price else 0
                if predicted_direction == actual_direction:
                    reward = snapshot.confidence  # positive reward scaled by confidence
                else:
                    reward = -snapshot.confidence  # negative reward scaled by confidence
                # Bonus for range accuracy (overlap is intersection-over-union)
                range_overlap = self._calculate_range_overlap(
                    (snapshot.predicted_min_price, snapshot.predicted_max_price),
                    (snapshot.actual_min_price, snapshot.actual_max_price)
                )
                reward += range_overlap * 0.5
                # Create next state (simplified: each prediction is a terminal one-step episode)
                next_state = state.copy()
                experiences.append((state, action, reward, next_state, True))  # done=True
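                # Net reward therefore lies in [-confidence, confidence + 0.5]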
if len(experiences) < self.min_batch_size:
return {'error': 'Insufficient training data'}
# Add experiences to RL agent memory
experiences_added = 0
for state, action, reward, next_state, done in experiences:
try:
if hasattr(rl_agent, 'store_experience'):
rl_agent.store_experience(
state=np.array(state),
action=action,
reward=reward,
next_state=np.array(next_state),
done=done
)
experiences_added += 1
elif hasattr(rl_agent, 'remember'):
rl_agent.remember(np.array(state), action, reward, np.array(next_state), done)
experiences_added += 1
except Exception as e:
logger.debug(f"Error adding RL experience: {e}")
# Perform training steps
training_losses = []
if hasattr(rl_agent, 'replay') and experiences_added > 0:
try:
for _ in range(min(5, experiences_added // 8)): # Conservative training
loss = rl_agent.replay(batch_size=min(32, experiences_added))
if loss is not None:
training_losses.append(loss)
except Exception as e:
logger.debug(f"RL training step failed: {e}")
avg_loss = np.mean(training_losses) if training_losses else 0.0
results = {
'experiences_added': experiences_added,
'training_steps': len(training_losses),
'avg_loss': avg_loss,
'samples_used': len(experiences)
}
logger.info(f"RL training completed: {experiences_added} experiences, avg_loss={avg_loss:.4f}")
return results
except Exception as e:
logger.error(f"Error training RL model: {e}")
return {'error': str(e)}
def _calculate_range_overlap(self, range1: Tuple[float, float], range2: Tuple[float, float]) -> float:
"""Calculate overlap between two price ranges (0.0 to 1.0)"""
try:
min1, max1 = range1
min2, max2 = range2
# Find overlap
overlap_min = max(min1, min2)
overlap_max = min(max1, max2)
if overlap_max <= overlap_min:
return 0.0
overlap_size = overlap_max - overlap_min
union_size = max(max1, max2) - min(min1, min2)
return overlap_size / union_size if union_size > 0 else 0.0
except Exception:
return 0.0
def force_training_session(self, horizon_minutes: Optional[int] = None,
symbol: Optional[str] = None) -> Dict[str, Any]:
"""Force a training session for specific parameters"""
try:
logger.info(f"Forcing training session: horizon={horizon_minutes}, symbol={symbol}")
results = {}
horizons = [horizon_minutes] if horizon_minutes else [1, 5, 15, 60]
symbols = [symbol] if symbol else ['ETH/USDT', 'BTC/USDT']
for h in horizons:
for s in symbols:
try:
horizon_results = self._train_horizon_models(h, s)
if horizon_results:
results[f"{h}m_{s}"] = horizon_results
except Exception as e:
logger.error(f"Error in forced training for {h}m {s}: {e}")
return results
except Exception as e:
logger.error(f"Error in forced training session: {e}")
return {'error': str(e)}
def get_training_stats(self) -> Dict[str, Any]:
"""Get training statistics"""
stats = dict(self.training_stats)
stats['is_training_active'] = self.training_active
# Calculate averages
for model_type in ['cnn', 'rl']:
if stats['training_accuracy'][model_type]:
stats[f'{model_type}_avg_accuracy'] = np.mean(stats['training_accuracy'][model_type])
else:
stats[f'{model_type}_avg_accuracy'] = 0.0
if stats['loss_history'][model_type]:
stats[f'{model_type}_avg_loss'] = np.mean(stats['loss_history'][model_type])
else:
stats[f'{model_type}_avg_loss'] = 0.0
return stats
def validate_recent_predictions(self):
"""Validate predictions that should have outcomes available"""
try:
# Get pending snapshots
pending_snapshots = self.snapshot_storage.get_pending_validation_snapshots()
if not pending_snapshots:
return
logger.info(f"Validating {len(pending_snapshots)} pending predictions")
# Group by symbol for efficient data access
by_symbol = defaultdict(list)
for snapshot in pending_snapshots:
by_symbol[snapshot.symbol].append(snapshot)
# Validate each symbol
for symbol, snapshots in by_symbol.items():
try:
self._validate_symbol_predictions(symbol, snapshots)
except Exception as e:
logger.error(f"Error validating predictions for {symbol}: {e}")
except Exception as e:
logger.error(f"Error validating recent predictions: {e}")
def _validate_symbol_predictions(self, symbol: str, snapshots: List[PredictionSnapshot]):
"""Validate predictions for a specific symbol"""
try:
            # Simplified validation pass: see the per-snapshot placeholder below
for snapshot in snapshots:
try:
                    # Placeholder validation: a real implementation would query historical
                    # data over the prediction horizon and compute the actual min/max prices.
                    # Here the price recorded at prediction time stands in for both.
                    current_time = datetime.now()
                    current_price = snapshot.current_price  # placeholder outcome
# Update snapshot with "outcome"
self.snapshot_storage.update_snapshot_outcome(
snapshot.prediction_id,
current_price, # actual_min
current_price, # actual_max
current_time
)
logger.debug(f"Validated prediction {snapshot.prediction_id}")
except Exception as e:
logger.error(f"Error validating snapshot {snapshot.prediction_id}: {e}")
except Exception as e:
logger.error(f"Error validating symbol predictions for {symbol}: {e}")