Files
gogo2/core/overnight_training_coordinator.py
Dobromir Popov 6d55061e86 wip training
2025-07-17 02:51:20 +03:00

710 lines
29 KiB
Python

"""
Overnight Training Coordinator
This module coordinates comprehensive training for CNN and COB RL models during overnight sessions.
It ensures that:
1. Training passes occur on each signal when predictions change
2. Trades are executed and recorded in simulation mode
3. Performance statistics are tracked and logged
4. Models learn from both successful and unsuccessful trades
"""
import logging
import time
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass, field
from collections import deque
import numpy as np
import json
import os
logger = logging.getLogger(__name__)
@dataclass
class TrainingSession:
    """Represents a training session for a model.

    Only ``model_name``, ``symbol`` and ``start_time`` are required; every
    other field starts at a neutral default and is meant to be filled in as
    the session progresses.
    """

    # Identification of the session
    model_name: str
    symbol: str
    start_time: datetime
    end_time: Optional[datetime] = None  # None until an end time is recorded

    # Training progress
    training_samples: int = 0
    initial_loss: Optional[float] = None
    final_loss: Optional[float] = None
    improvement: Optional[float] = None

    # Trading outcome accumulated during the session
    trades_executed: int = 0
    successful_trades: int = 0
    total_pnl: float = 0.0
@dataclass
class SignalTradeRecord:
    """Records a signal and its corresponding trade execution.

    The first five fields describe the signal itself and are required; the
    remaining fields describe what was done with it and default to
    "nothing happened".
    """

    # The signal as received
    timestamp: datetime
    symbol: str
    signal_action: str
    signal_confidence: float
    model_source: str

    # Trade outcome (stays at defaults when no trade was executed)
    executed: bool = False
    execution_price: Optional[float] = None
    trade_pnl: Optional[float] = None

    # Training outcome (stays at defaults when no training was triggered)
    training_triggered: bool = False
    training_loss: Optional[float] = None
class OvernightTrainingCoordinator:
    """
    Coordinates comprehensive overnight training for all models.

    A background thread polls the orchestrator's latest decision per symbol.
    Whenever a decision differs enough from the previous one it
      1. runs one training step on every model the orchestrator exposes
         (``cnn_model``, ``cob_rl_agent``, ``rl_agent``),
      2. optionally executes a rate-limited trade through the trading executor,
      3. records the outcome and updates the running performance statistics.

    Args:
        orchestrator: Object exposing ``symbols``, ``recent_decisions`` and,
            optionally, the model attributes listed above.
        data_provider: Object exposing ``get_historical_data`` and
            ``get_current_price``.
        trading_executor: Object exposing ``execute_signal``; may be falsy, in
            which case no trades are executed.
        dashboard: Optional object exposing ``latest_cob_data``.
    """

    # A signal is re-processed when its confidence moves by more than this ...
    _CONFIDENCE_DELTA_THRESHOLD = 0.1
    # ... or when the decision timestamp is more than this many seconds newer.
    _SIGNAL_REFRESH_SECONDS = 30
    # Minimum spacing between two trades on the same symbol.
    _MIN_SECONDS_BETWEEN_TRADES = 30
    # Directory that receives the periodic model checkpoints.
    _CHECKPOINT_DIR = "models"
    # Fixed input sizes the feature builders pad to.
    _COB_FEATURE_SIZE = 2000
    _DQN_STATE_SIZE = 100

    def __init__(self, orchestrator, data_provider, trading_executor, dashboard=None):
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.trading_executor = trading_executor
        self.dashboard = dashboard

        # Training configuration
        self.config = {
            'training_on_signal_change': True,   # Train when prediction changes
            'min_confidence_for_trade': 0.3,     # Minimum confidence to execute trade
            'max_trades_per_hour': 20,           # Rate limiting
            'training_batch_size': 32,           # Training batch size
            'performance_tracking_window': 100,  # Number of trades to track for performance
            'model_checkpoint_interval': 50,     # Save checkpoints every N trades
        }

        # State tracking
        self.is_running = False
        self.training_thread = None
        # {symbol: {'action': ..., 'confidence': ..., 'timestamp': ...}}
        self.last_predictions: Dict[str, Dict[str, Any]] = {}
        self.signal_trade_records: deque = deque(maxlen=1000)
        # NOTE(review): not populated by any method of this class.
        self.training_sessions: Dict[str, TrainingSession] = {}

        # Performance tracking
        self.performance_stats = {
            'total_signals': 0,
            'total_trades': 0,
            'successful_trades': 0,
            'total_pnl': 0.0,
            'training_sessions': 0,
            'models_trained': set(),
            'hourly_stats': deque(maxlen=24),  # Last 24 hours
        }

        # Rate limiting
        self.last_trade_time: Dict[str, datetime] = {}
        self.trades_this_hour: Dict[str, int] = {}
        self.hour_reset_time = datetime.now().replace(minute=0, second=0, microsecond=0)

        # Index of the last checkpoint milestone (total_trades // interval)
        # that was saved, so a milestone is written exactly once.
        self._last_checkpoint_milestone = 0

        logger.info("Overnight Training Coordinator initialized")

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def start_overnight_training(self):
        """Start the background training thread (no-op if already running)."""
        if self.is_running:
            logger.warning("Training coordinator already running")
            return

        self.is_running = True
        self.training_thread = threading.Thread(target=self._training_loop, daemon=True)
        self.training_thread.start()

        logger.info("🌙 OVERNIGHT TRAINING SESSION STARTED")
        logger.info("=" * 60)
        logger.info("Features enabled:")
        logger.info("✅ CNN training on signal changes")
        logger.info("✅ COB RL training on market microstructure")
        logger.info("✅ Trade execution and recording")
        logger.info("✅ Performance tracking and statistics")
        logger.info("✅ Model checkpointing")
        logger.info("=" * 60)

    def stop_overnight_training(self):
        """Stop the training thread, wait up to 10 s for it, and log the report."""
        self.is_running = False

        worker = self.training_thread
        # Joining the current thread raises RuntimeError, so skip the join
        # when stop is requested from inside the training loop itself.
        if worker and worker is not threading.current_thread():
            worker.join(timeout=10)

        # Generate final report
        self._generate_training_report()
        logger.info("🌅 OVERNIGHT TRAINING SESSION COMPLETED")

    def _training_loop(self):
        """Main training loop that monitors signals and triggers training."""
        while self.is_running:
            try:
                self._reset_hourly_counters()
                self._process_orchestrator_signals()
                self._check_training_opportunities()
                self._update_performance_stats()
                # Sleep briefly to avoid overwhelming the system
                time.sleep(0.5)
            except Exception as e:
                logger.error(f"Error in training loop: {e}")
                # Back off longer after a failure before retrying
                time.sleep(5)

    # ------------------------------------------------------------------
    # Signal handling
    # ------------------------------------------------------------------
    def _process_orchestrator_signals(self):
        """Inspect the latest decision of every symbol and act on new ones."""
        try:
            recent_by_symbol = getattr(self.orchestrator, 'recent_decisions', None)
            if recent_by_symbol is None:
                return

            for symbol in self.orchestrator.symbols:
                if symbol not in recent_by_symbol:
                    continue
                decisions = recent_by_symbol[symbol]
                if not decisions:
                    continue

                latest_decision = decisions[-1]
                if self._is_new_signal_requiring_action(symbol, latest_decision):
                    self._process_new_signal(symbol, latest_decision)
        except Exception as e:
            logger.error(f"Error processing orchestrator signals: {e}")

    def _is_new_signal_requiring_action(self, symbol: str, decision) -> bool:
        """Return True when ``decision`` differs enough from the last one seen.

        A decision counts as new when the action changed, the confidence moved
        by more than the threshold, or its timestamp is more than the refresh
        interval newer than the previous one (or there is no previous one).
        The stored "last prediction" is always overwritten with ``decision``.
        Returns False if the decision cannot be inspected.
        """
        try:
            previous = self.last_predictions.get(symbol, {})
            last_time = previous.get('timestamp')

            action_changed = previous.get('action') != decision.action
            confidence_changed = (
                abs(decision.confidence - previous.get('confidence', 0.0))
                > self._CONFIDENCE_DELTA_THRESHOLD
            )
            time_elapsed = (
                last_time is None
                or (decision.timestamp - last_time).total_seconds() > self._SIGNAL_REFRESH_SECONDS
            )

            # Remember this decision for the next comparison
            self.last_predictions[symbol] = {
                'action': decision.action,
                'confidence': decision.confidence,
                'timestamp': decision.timestamp,
            }
            return action_changed or confidence_changed or time_elapsed
        except Exception as e:
            logger.error(f"Error checking if signal requires action: {e}")
            return False

    def _process_new_signal(self, symbol: str, decision):
        """Train on a new signal, trade on it if allowed, and record the result."""
        try:
            # `reasoning` may be missing or None; fall back to the orchestrator.
            reasoning = getattr(decision, 'reasoning', None) or {}
            if hasattr(reasoning, 'get'):
                model_source = reasoning.get('primary_model', 'orchestrator')
            else:
                model_source = 'orchestrator'

            signal_record = SignalTradeRecord(
                timestamp=decision.timestamp,
                symbol=symbol,
                signal_action=decision.action,
                signal_confidence=decision.confidence,
                model_source=model_source,
            )
            hourly = self._current_hourly_bucket()

            # 1. Trigger training on signal change
            if self.config['training_on_signal_change']:
                signal_record.training_loss = self._trigger_model_training(symbol, decision)
                signal_record.training_triggered = True

            # 2. Execute trade if confidence is sufficient
            tradable = (
                decision.confidence >= self.config['min_confidence_for_trade']
                and decision.action in ('BUY', 'SELL')
                and self._can_execute_trade(symbol)
            )
            if tradable:
                executed, price, pnl = self._execute_signal_trade(symbol, decision)
                signal_record.executed = executed
                signal_record.execution_price = price
                signal_record.trade_pnl = pnl

                # Only trades that actually went through count as trades.
                if executed:
                    self.performance_stats['total_trades'] += 1
                    hourly['trades'] += 1
                    if pnl is not None:
                        self.performance_stats['total_pnl'] += pnl
                        hourly['pnl'] += pnl
                        if pnl > 0:
                            self.performance_stats['successful_trades'] += 1

            # 3. Record the signal
            self.signal_trade_records.append(signal_record)
            self.performance_stats['total_signals'] += 1
            hourly['signals'] += 1

            # 4. Log the action
            status = "EXECUTED" if signal_record.executed else "SIGNAL_ONLY"
            logger.info(f"[{status}] {symbol} {decision.action} "
                        f"(conf: {decision.confidence:.3f}, "
                        f"training: {'yes' if signal_record.training_triggered else 'no'}, "
                        f"pnl: {self._format_pnl(signal_record.trade_pnl)})")
        except Exception as e:
            logger.error(f"Error processing new signal for {symbol}: {e}")

    @staticmethod
    def _format_pnl(pnl: Optional[float]) -> str:
        """Format a PnL value for logging; ``None`` becomes ``'N/A'``."""
        return 'N/A' if pnl is None else f"{pnl:.2f}"

    # ------------------------------------------------------------------
    # Model training
    # ------------------------------------------------------------------
    def _trigger_model_training(self, symbol: str, decision) -> Optional[float]:
        """Run one training step on every available model.

        Returns:
            The mean of the losses reported by the models that trained, or
            None when no model produced a loss (or on error).
        """
        # (orchestrator attribute, stats label, training method)
        trainers = (
            ('cnn_model', 'CNN', self._train_cnn_model),
            ('cob_rl_agent', 'COB_RL', self._train_cob_rl_model),
            ('rl_agent', 'DQN', self._train_dqn_model),
        )
        try:
            losses = []
            for attr_name, label, train in trainers:
                if not getattr(self.orchestrator, attr_name, None):
                    continue
                loss = train(symbol, decision)
                if loss is None:
                    continue
                losses.append(loss)
                self.performance_stats['models_trained'].add(label)
                self._current_hourly_bucket()['models_trained'].add(label)

            return float(np.mean(losses)) if losses else None
        except Exception as e:
            logger.error(f"Error triggering model training: {e}")
            return None

    def _train_cnn_model(self, symbol: str, decision) -> Optional[float]:
        """Train the CNN on the last 100 one-minute candles of ``symbol``.

        Returns the loss from ``train_on_batch``, or None when there are fewer
        than 50 candles, features cannot be built, the model has no
        ``train_on_batch`` method, or an error occurs.
        """
        try:
            df = self.data_provider.get_historical_data(symbol, '1m', limit=100)
            if df is None or len(df) < 50:
                return None

            features = self._prepare_cnn_features(df)
            target = self._prepare_cnn_target(decision)
            if features is None or target is None:
                return None

            cnn = self.orchestrator.cnn_model
            if not hasattr(cnn, 'train_on_batch'):
                return None

            loss = cnn.train_on_batch(features, target)
            logger.debug(f"CNN training loss for {symbol}: {loss:.4f}")
            return loss
        except Exception as e:
            logger.error(f"Error training CNN model: {e}")
            return None

    def _train_cob_rl_model(self, symbol: str, decision) -> Optional[float]:
        """Train the COB RL agent on the dashboard's latest COB data.

        Returns the loss from the agent's ``train`` method, or None when no
        COB data exists for ``symbol``, features cannot be built, the agent
        has no ``train`` method, or an error occurs.
        """
        try:
            cob_store = getattr(self.dashboard, 'latest_cob_data', None)
            if not cob_store or symbol not in cob_store:
                return None

            features = self._prepare_cob_features(cob_store[symbol])
            if features is None:
                return None

            agent = self.orchestrator.cob_rl_agent
            if not hasattr(agent, 'train'):
                return None

            loss = agent.train(features, self._calculate_cob_reward(decision))
            if loss is not None:
                logger.debug(f"COB RL training loss for {symbol}: {loss:.4f}")
            return loss
        except Exception as e:
            logger.error(f"Error training COB RL model: {e}")
            return None

    def _train_dqn_model(self, symbol: str, decision) -> Optional[float]:
        """Store the decision as an experience and run one DQN replay step.

        The decision's confidence is used as the immediate reward and the
        current state doubles as the next state (a simplification).

        Returns the loss from ``replay``, or None when no state could be
        built, the agent lacks ``remember``/``replay``, ``replay`` returned
        None, or an error occurs.
        """
        try:
            state = self._prepare_dqn_state(symbol)
            if state is None:
                return None

            agent = self.orchestrator.rl_agent
            if not hasattr(agent, 'remember'):
                return None

            action_index = self._map_action_to_index(decision.action)
            reward = decision.confidence  # Use confidence as immediate reward
            # Dummy next_state: no real transition is available here.
            agent.remember(state, action_index, reward, state, False)

            if not hasattr(agent, 'replay'):
                return None
            loss = agent.replay()
            if loss is not None:
                logger.debug(f"DQN training loss for {symbol}: {loss:.4f}")
            return loss
        except Exception as e:
            logger.error(f"Error training DQN model: {e}")
            return None

    # ------------------------------------------------------------------
    # Trade execution and rate limiting
    # ------------------------------------------------------------------
    def _execute_signal_trade(self, symbol: str, decision) -> Tuple[bool, Optional[float], Optional[float]]:
        """Execute a trade for ``decision`` through the trading executor.

        Returns:
            ``(executed, execution_price, trade_pnl)``; ``(False, None, None)``
            when there is no executor, no current price, the executor declined
            the trade, or an error occurred.
        """
        not_executed = (False, None, None)
        try:
            if not self.trading_executor:
                return not_executed

            current_price = self.data_provider.get_current_price(symbol)
            if not current_price:
                return not_executed

            success = self.trading_executor.execute_signal(
                symbol=symbol,
                action=decision.action,
                confidence=decision.confidence,
                current_price=current_price,
            )
            if not success:
                return not_executed

            # Calculate PnL (simplified - in real implementation this would be more complex)
            trade_pnl = self._calculate_trade_pnl(symbol, decision.action, current_price)

            # Update rate limiting
            self.last_trade_time[symbol] = datetime.now()
            self.trades_this_hour[symbol] = self.trades_this_hour.get(symbol, 0) + 1
            return True, current_price, trade_pnl
        except Exception as e:
            logger.error(f"Error executing signal trade: {e}")
            return not_executed

    def _can_execute_trade(self, symbol: str) -> bool:
        """Return True when the rate limits allow another trade on ``symbol``.

        Two limits apply: the hourly cap from ``config['max_trades_per_hour']``
        and a minimum spacing between consecutive trades on the same symbol.
        """
        try:
            # Check hourly limit
            if self.trades_this_hour.get(symbol, 0) >= self.config['max_trades_per_hour']:
                return False

            # Check minimum time between trades
            last_trade = self.last_trade_time.get(symbol)
            if last_trade is not None:
                elapsed = (datetime.now() - last_trade).total_seconds()
                if elapsed < self._MIN_SECONDS_BETWEEN_TRADES:
                    return False
            return True
        except Exception as e:
            logger.error(f"Error checking if can execute trade: {e}")
            return False

    # ------------------------------------------------------------------
    # Feature / target preparation
    # ------------------------------------------------------------------
    def _prepare_cnn_features(self, df) -> Optional[np.ndarray]:
        """Build a z-score-normalized OHLCV array with a leading batch axis.

        Returns a float32 array of shape ``(1, rows, 5)``, or None on error.
        """
        try:
            ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values
            # Per-column normalization; the epsilon avoids division by zero
            # for constant columns.
            normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8)
            batched = normalized.reshape(1, normalized.shape[0], normalized.shape[1])
            return batched.astype(np.float32)
        except Exception as e:
            logger.error(f"Error preparing CNN features: {e}")
            return None

    def _prepare_cnn_target(self, decision) -> Optional[np.ndarray]:
        """One-hot encode the decision's action as [BUY, SELL, HOLD].

        Unknown actions are encoded as HOLD. Returns a float32 array of shape
        ``(1, 3)``, or None on error.
        """
        try:
            one_hot = {'BUY': [1, 0, 0], 'SELL': [0, 1, 0], 'HOLD': [0, 0, 1]}
            target = one_hot.get(decision.action, [0, 0, 1])
            return np.array([target], dtype=np.float32)
        except Exception as e:
            logger.error(f"Error preparing CNN target: {e}")
            return None

    def _prepare_cob_features(self, cob_data) -> Optional[np.ndarray]:
        """Build the fixed-size COB feature vector.

        The first four entries are imbalance, bid liquidity, ask liquidity and
        spread (bps) from ``cob_data['stats']`` (missing values default to 0);
        the rest of the vector is zero padding. Returns None on error.
        """
        try:
            stats = cob_data.get('stats') or {}
            leading = [
                stats.get('imbalance', 0),
                stats.get('bid_liquidity', 0),
                stats.get('ask_liquidity', 0),
                stats.get('spread_bps', 0),
            ]
            vector = np.zeros(self._COB_FEATURE_SIZE, dtype=np.float32)
            vector[:len(leading)] = leading
            return vector
        except Exception as e:
            logger.error(f"Error preparing COB features: {e}")
            return None

    def _calculate_cob_reward(self, decision) -> float:
        """Return the COB RL reward: confidence, scaled to 10% for non-trades."""
        try:
            if decision.action in ('BUY', 'SELL'):
                return decision.confidence
            return decision.confidence * 0.1  # Lower reward for HOLD
        except Exception as e:
            logger.error(f"Error calculating COB reward: {e}")
            return 0.0

    def _prepare_dqn_state(self, symbol: str) -> Optional[np.ndarray]:
        """Build the fixed-size DQN state vector for ``symbol``.

        Layout: last 10 closes, the 20-period mean close (or the last close
        when fewer than 20 candles exist), last 5 volumes, then zero padding.
        Returns None when fewer than 10 candles are available or on error.
        """
        try:
            df = self.data_provider.get_historical_data(symbol, '1m', limit=50)
            if df is None or len(df) < 10:
                return None

            closes = df['close'].values
            volumes = df['volume'].values
            trend = np.mean(closes[-20:]) if len(closes) >= 20 else closes[-1]

            values = [*closes[-10:], trend, *volumes[-5:]][:self._DQN_STATE_SIZE]
            state = np.zeros(self._DQN_STATE_SIZE, dtype=np.float32)
            state[:len(values)] = values
            return state
        except Exception as e:
            logger.error(f"Error preparing DQN state: {e}")
            return None

    def _map_action_to_index(self, action: str) -> int:
        """Map an action string to its index; unknown actions map to HOLD (2)."""
        return {'BUY': 0, 'SELL': 1, 'HOLD': 2}.get(action, 2)

    def _calculate_trade_pnl(self, symbol: str, action: str, price: float) -> float:
        """Calculate a simplified PnL for a trade.

        The PnL is the percentage move from the previous one-minute close to
        ``price``, signed by the trade direction. This does not track real
        positions. Returns 0.0 when the previous close is unavailable or zero,
        for non-trade actions, or on error.
        """
        try:
            df = self.data_provider.get_historical_data(symbol, '1m', limit=2)
            if df is None or len(df) < 2:
                return 0.0

            prev_price = df['close'].iloc[-2]
            if not prev_price:
                # A zero reference price makes the relative change undefined.
                return 0.0

            pct_change = (price - prev_price) / prev_price * 100
            if action == 'BUY':
                return float(pct_change)
            if action == 'SELL':
                return float(-pct_change)
            return 0.0
        except Exception as e:
            logger.error(f"Error calculating trade PnL: {e}")
            return 0.0

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------
    def _check_training_opportunities(self):
        """Save checkpoints once each time another N trades have completed.

        N is ``config['model_checkpoint_interval']``. The milestone index is
        remembered so that polling repeatedly at the same trade count does not
        write the same checkpoint again, and a multiple of N that was jumped
        over within one loop pass is still honoured.
        """
        try:
            interval = self.config['model_checkpoint_interval']
            if interval <= 0:
                return

            milestone = self.performance_stats['total_trades'] // interval
            if milestone > self._last_checkpoint_milestone:
                self._last_checkpoint_milestone = milestone
                self._save_model_checkpoints()
        except Exception as e:
            logger.error(f"Error checking training opportunities: {e}")

    def _save_model_checkpoints(self):
        """Save a timestamped checkpoint for every model that supports saving."""
        # (orchestrator attribute, save method name, file tag, log label)
        targets = (
            ('cnn_model', 'save', 'cnn', 'CNN'),
            ('cob_rl_agent', 'save_model', 'cob_rl', 'COB RL'),
            ('rl_agent', 'save', 'dqn', 'DQN'),
        )
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            os.makedirs(self._CHECKPOINT_DIR, exist_ok=True)
        except Exception as e:
            logger.error(f"Error saving model checkpoints: {e}")
            return

        for attr_name, save_name, tag, label in targets:
            # Each model is saved independently so one failure does not
            # prevent the remaining checkpoints from being written.
            try:
                model = getattr(self.orchestrator, attr_name, None)
                if not model or not hasattr(model, save_name):
                    continue
                checkpoint_path = f"{self._CHECKPOINT_DIR}/overnight_{tag}_{timestamp}.pth"
                getattr(model, save_name)(checkpoint_path)
                logger.info(f"{label} checkpoint saved: {checkpoint_path}")
            except Exception as e:
                logger.error(f"Error saving model checkpoints: {e}")

    # ------------------------------------------------------------------
    # Statistics and reporting
    # ------------------------------------------------------------------
    def _reset_hourly_counters(self):
        """Clear the per-symbol hourly trade counters when a new hour starts."""
        try:
            current_hour = datetime.now().replace(minute=0, second=0, microsecond=0)
            if current_hour > self.hour_reset_time:
                self.trades_this_hour = {}
                self.hour_reset_time = current_hour
                logger.info("Hourly trade counters reset")
        except Exception as e:
            logger.error(f"Error resetting hourly counters: {e}")

    def _current_hourly_bucket(self) -> Dict[str, Any]:
        """Return the stats bucket of the current hour, creating it if needed."""
        current_hour = datetime.now().replace(minute=0, second=0, microsecond=0)
        hourly_stats = self.performance_stats['hourly_stats']
        if not hourly_stats or hourly_stats[-1]['hour'] != current_hour:
            hourly_stats.append({
                'hour': current_hour,
                'signals': 0,
                'trades': 0,
                'pnl': 0.0,
                'models_trained': set(),
            })
        return hourly_stats[-1]

    def _update_performance_stats(self):
        """Make sure a stats bucket exists for the current hour."""
        try:
            self._current_hourly_bucket()
        except Exception as e:
            logger.error(f"Error updating performance stats: {e}")

    def _generate_training_report(self):
        """Log a summary of the session: totals, models trained, recent signals."""
        try:
            stats = self.performance_stats
            success_pct = stats['successful_trades'] / max(1, stats['total_trades']) * 100

            logger.info("=" * 80)
            logger.info("🌅 OVERNIGHT TRAINING SESSION REPORT")
            logger.info("=" * 80)

            # Overall statistics
            logger.info("📊 OVERALL STATISTICS:")
            logger.info(f" Total Signals Processed: {stats['total_signals']}")
            logger.info(f" Total Trades Executed: {stats['total_trades']}")
            logger.info(f" Successful Trades: {stats['successful_trades']}")
            logger.info(f" Success Rate: {success_pct:.1f}%")
            logger.info(f" Total P&L: ${stats['total_pnl']:.2f}")

            # Model training statistics (sorted for a stable report)
            logger.info("🧠 MODEL TRAINING:")
            logger.info(f" Models Trained: {', '.join(sorted(stats['models_trained']))}")
            logger.info(f" Training Sessions: {len(self.training_sessions)}")

            # Recent performance
            if self.signal_trade_records:
                recent_records = list(self.signal_trade_records)[-20:]  # Last 20 records
                executed_trades = [r for r in recent_records if r.executed]
                successful_trades = [r for r in executed_trades if r.trade_pnl and r.trade_pnl > 0]

                logger.info("📈 RECENT PERFORMANCE (Last 20 signals):")
                logger.info(f" Signals: {len(recent_records)}")
                logger.info(f" Executed: {len(executed_trades)}")
                logger.info(f" Successful: {len(successful_trades)}")
                if executed_trades:
                    recent_pnl = sum(r.trade_pnl for r in executed_trades if r.trade_pnl)
                    logger.info(f" Recent P&L: ${recent_pnl:.2f}")

            logger.info("=" * 80)
        except Exception as e:
            logger.error(f"Error generating training report: {e}")

    def get_performance_summary(self) -> Dict[str, Any]:
        """Return a snapshot of the running statistics.

        ``success_rate`` is a fraction in [0, 1]; ``recent_signals`` is the
        number of records currently retained. Returns ``{}`` on error.
        """
        try:
            stats = self.performance_stats
            return {
                'total_signals': stats['total_signals'],
                'total_trades': stats['total_trades'],
                'successful_trades': stats['successful_trades'],
                'success_rate': stats['successful_trades'] / max(1, stats['total_trades']),
                'total_pnl': stats['total_pnl'],
                'models_trained': list(stats['models_trained']),
                'is_running': self.is_running,
                'recent_signals': len(self.signal_trade_records),
            }
        except Exception as e:
            logger.error(f"Error getting performance summary: {e}")
            return {}