"""
|
|
Retrospective Training System
|
|
|
|
This module implements a retrospective training system that:
|
|
1. Triggers training when trades close with known P&L outcomes
|
|
2. Uses captured model inputs from trade entry to train models
|
|
3. Optimizes for profit by learning from profitable vs unprofitable patterns
|
|
4. Supports simultaneous inference and training without weight reloading
|
|
5. Implements reinforcement learning with immediate reward feedback
|
|
"""

import logging
import queue
import threading
import time
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional

import numpy as np

logger = logging.getLogger(__name__)


@dataclass
class TrainingCase:
    """Represents a completed trade case for retrospective training"""
    case_id: str
    symbol: str
    action: str  # 'BUY' or 'SELL'
    entry_price: float
    exit_price: float
    entry_time: datetime
    exit_time: datetime
    pnl: float
    fees: float
    confidence: float
    model_inputs: Dict[str, Any]
    market_state: Dict[str, Any]
    outcome_label: int  # 1 for profit, 0 for loss, 2 for breakeven
    reward_signal: float  # Scaled reward for RL training
    leverage: float = 1.0


class RetrospectiveTrainer:
    """Retrospective training system for real-time model optimization"""

    def __init__(self, orchestrator=None, config: Optional[Dict[str, Any]] = None):
        """Initialize the retrospective trainer"""
        self.orchestrator = orchestrator
        self.config = config or {}

        # Training configuration
        self.batch_size = self.config.get('batch_size', 32)
        self.min_cases_for_training = self.config.get('min_cases_for_training', 5)
        self.profit_threshold = self.config.get('profit_threshold', 0.0)
        self.training_frequency = self.config.get('training_frequency_seconds', 120)  # 2 minutes
        self.max_training_cases = self.config.get('max_training_cases', 1000)

        # Training state
        self.training_queue = queue.Queue()
        self.completed_cases = deque(maxlen=self.max_training_cases)
        self.training_stats = {
            'total_cases': 0,
            'profitable_cases': 0,
            'loss_cases': 0,
            'breakeven_cases': 0,
            'avg_profit': 0.0,
            'last_training_time': datetime.now(),
            'training_sessions': 0,
            'model_updates': 0
        }

        # Threading
        self.training_thread = None
        self.is_training_active = False
        self.training_lock = threading.Lock()

        logger.info("RetrospectiveTrainer initialized")
        logger.info(f"Configuration: batch_size={self.batch_size}, "
                    f"min_cases={self.min_cases_for_training}, "
                    f"training_freq={self.training_frequency}s")
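
    # A minimal sketch of a config dict for this trainer, assuming the caller
    # wants to override defaults: the keys are exactly those read by the
    # self.config.get(...) calls above, and the values shown are the defaults
    # used when a key is absent. 'my_orchestrator' is a hypothetical stand-in.
    #
    #   example_config = {
    #       'batch_size': 32,
    #       'min_cases_for_training': 5,
    #       'profit_threshold': 0.0,
    #       'training_frequency_seconds': 120,
    #       'max_training_cases': 1000,
    #   }
    #   trainer = RetrospectiveTrainer(orchestrator=my_orchestrator, config=example_config)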

    def add_completed_trade(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> bool:
        """Add a completed trade for retrospective training"""
        try:
            # Create training case from trade record
            case = self._create_training_case(trade_record, model_inputs)
            if case is None:
                return False

            # Add to completed cases
            self.completed_cases.append(case)
            self.training_queue.put(case)

            # Update statistics
            self.training_stats['total_cases'] += 1
            if case.outcome_label == 1:  # Profit
                self.training_stats['profitable_cases'] += 1
            elif case.outcome_label == 0:  # Loss
                self.training_stats['loss_cases'] += 1
            else:  # Breakeven
                self.training_stats['breakeven_cases'] += 1

            # Calculate running average profit
            total_pnl = sum(c.pnl for c in self.completed_cases)
            self.training_stats['avg_profit'] = total_pnl / len(self.completed_cases)

            logger.info(f"RETROSPECTIVE: Added training case {case.case_id} "
                        f"(P&L: ${case.pnl:.3f}, Label: {case.outcome_label})")

            # Trigger training if we have enough cases
            self._maybe_trigger_training()

            return True

        except Exception as e:
            logger.error(f"Error adding completed trade for retrospective training: {e}")
            return False
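
    # The trade_record passed in above is expected to carry the keys consumed
    # by _create_training_case below: 'symbol', 'side', 'pnl', 'fees',
    # 'confidence', 'entry_price', 'exit_price', 'entry_time', 'exit_time'
    # and 'leverage'. Missing keys fall back to the defaults used in that
    # method; model_inputs should be the inputs captured at trade entry.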

    def _create_training_case(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> Optional[TrainingCase]:
        """Create a training case from trade record and model inputs"""
        try:
            # Extract trade information
            symbol = trade_record.get('symbol', 'UNKNOWN')
            side = trade_record.get('side', 'UNKNOWN')
            pnl = trade_record.get('pnl', 0.0)
            fees = trade_record.get('fees', 0.0)
            confidence = trade_record.get('confidence', 0.0)

            # Calculate net P&L after fees
            net_pnl = pnl - fees

            # Determine outcome label and reward signal
            if net_pnl > self.profit_threshold:
                outcome_label = 1  # Profitable
                # Scale reward by profit magnitude and confidence
                reward_signal = min(10.0, net_pnl * confidence * 10)  # Amplify for training
            elif net_pnl < -self.profit_threshold:
                outcome_label = 0  # Loss
                # Negative reward scaled by loss magnitude
                reward_signal = max(-10.0, net_pnl * confidence * 10)  # Negative reward
            else:
                outcome_label = 2  # Breakeven
                reward_signal = 0.0

            # Create case ID
            timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S')
            case_id = f"retro_{timestamp_str}_{symbol.replace('/', '')}_{side}_pnl_{abs(net_pnl):.3f}".replace('.', 'p')

            # Create training case
            case = TrainingCase(
                case_id=case_id,
                symbol=symbol,
                action=side,
                entry_price=trade_record.get('entry_price', 0.0),
                exit_price=trade_record.get('exit_price', 0.0),
                entry_time=trade_record.get('entry_time', datetime.now()),
                exit_time=trade_record.get('exit_time', datetime.now()),
                pnl=net_pnl,
                fees=fees,
                confidence=confidence,
                model_inputs=model_inputs,
                market_state=model_inputs.get('market_state', {}),
                outcome_label=outcome_label,
                reward_signal=reward_signal,
                leverage=trade_record.get('leverage', 1.0)
            )

            return case

        except Exception as e:
            logger.error(f"Error creating training case: {e}")
            return None
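
    # Worked example of the labelling and reward scaling above (numbers are
    # hypothetical): net_pnl = 10.0 - 1.2 = 8.8 with confidence = 0.65 gives
    # reward_signal = min(10.0, 8.8 * 0.65 * 10) = 10.0 (clipped) and
    # outcome_label = 1, while net_pnl = -0.4 with confidence = 0.5 gives
    # reward_signal = max(-10.0, -0.4 * 0.5 * 10) = -2.0 and outcome_label = 0.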

    def _maybe_trigger_training(self):
        """Check if we should trigger a training session"""
        try:
            # Check if we have enough cases
            if len(self.completed_cases) < self.min_cases_for_training:
                return

            # Check if enough time has passed since last training
            time_since_last = (datetime.now() - self.training_stats['last_training_time']).total_seconds()
            if time_since_last < self.training_frequency:
                return

            # Check if training thread is not already running
            if self.is_training_active:
                logger.debug("Training already in progress, skipping trigger")
                return

            # Start training in background thread
            self._start_training_session()

        except Exception as e:
            logger.error(f"Error checking training trigger: {e}")

    def _start_training_session(self):
        """Start a training session in background thread"""
        try:
            if self.training_thread and self.training_thread.is_alive():
                logger.debug("Training thread already running")
                return

            self.training_thread = threading.Thread(
                target=self._run_training_session,
                daemon=True,
                name="RetrospectiveTrainer"
            )
            self.training_thread.start()
            logger.info("RETROSPECTIVE: Started training session")

        except Exception as e:
            logger.error(f"Error starting training session: {e}")

    def _run_training_session(self):
        """Run a complete training session"""
        try:
            with self.training_lock:
                self.is_training_active = True
                start_time = time.time()

                logger.info(f"RETROSPECTIVE: Training with {len(self.completed_cases)} cases")

                # Train models if orchestrator available
                training_results = {}
                if self.orchestrator:
                    training_results = self._train_models()

                # Update statistics
                self.training_stats['last_training_time'] = datetime.now()
                self.training_stats['training_sessions'] += 1
                self.training_stats['model_updates'] += len(training_results)

                elapsed_time = time.time() - start_time
                logger.info(f"RETROSPECTIVE: Training completed in {elapsed_time:.2f}s - {training_results}")

        except Exception as e:
            logger.error(f"Error in retrospective training session: {e}")
            import traceback
            logger.error(traceback.format_exc())
        finally:
            self.is_training_active = False

    def _train_models(self) -> Dict[str, Any]:
        """Train available models using retrospective data"""
        results = {}

        try:
            # Prepare training data
            profitable_cases = [c for c in self.completed_cases if c.outcome_label == 1]
            loss_cases = [c for c in self.completed_cases if c.outcome_label == 0]

            if len(profitable_cases) == 0 and len(loss_cases) == 0:
                return {'error': 'No labeled cases for training'}

            logger.info(f"RETROSPECTIVE: Training data - Profitable: {len(profitable_cases)}, Loss: {len(loss_cases)}")

            # Train DQN agent if available
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                try:
                    dqn_result = self._train_dqn_retrospective()
                    results['dqn'] = dqn_result
                    logger.info(f"RETROSPECTIVE: DQN training result: {dqn_result}")
                except Exception as e:
                    logger.warning(f"DQN retrospective training failed: {e}")
                    results['dqn'] = {'error': str(e)}

            # Train other models
            if self.orchestrator and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
                try:
                    # Update extrema trainer with retrospective feedback
                    extrema_feedback = self._create_extrema_feedback()
                    if extrema_feedback:
                        results['extrema'] = {'feedback_cases': len(extrema_feedback)}
                        logger.info(f"RETROSPECTIVE: Extrema feedback provided for {len(extrema_feedback)} cases")
                except Exception as e:
                    logger.warning(f"Extrema retrospective training failed: {e}")

            return results

        except Exception as e:
            logger.error(f"Error training models retrospectively: {e}")
            return {'error': str(e)}

    def _train_dqn_retrospective(self) -> Dict[str, Any]:
        """Train the DQN agent using retrospective experience replay"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
                return {'error': 'DQN agent not available'}

            dqn_agent = self.orchestrator.rl_agent
            experiences_added = 0

            # Add retrospective experiences to the DQN replay buffer
            for case in self.completed_cases:
                try:
                    # Extract state from model inputs
                    state = self._extract_state_vector(case.model_inputs)
                    if state is None:
                        continue

                    # Action mapping: BUY=0, SELL=1
                    action = 0 if case.action == 'BUY' else 1

                    # Use reward signal as immediate reward
                    reward = case.reward_signal

                    # Retrospective cases are terminal transitions: use a zero
                    # next_state placeholder and mark the episode as done
                    next_state = np.zeros_like(state)
                    done = True

                    # Add experience to DQN replay buffer
                    if hasattr(dqn_agent, 'add_experience'):
                        dqn_agent.add_experience(state, action, reward, next_state, done)
                        experiences_added += 1

                except Exception as e:
                    logger.debug(f"Error adding DQN experience: {e}")
                    continue

            # Train DQN if we have enough experiences
            if experiences_added > 0 and hasattr(dqn_agent, 'train'):
                try:
                    # Perform multiple training steps on retrospective data
                    training_steps = min(10, experiences_added // 4)  # Conservative training
                    for _ in range(training_steps):
                        loss = dqn_agent.train()
                        if loss is None:
                            break

                    return {
                        'experiences_added': experiences_added,
                        'training_steps': training_steps,
                        'method': 'retrospective_experience_replay'
                    }
                except Exception as e:
                    logger.warning(f"DQN training step failed: {e}")
                    return {'experiences_added': experiences_added, 'training_error': str(e)}

            return {'experiences_added': experiences_added, 'training_steps': 0}

        except Exception as e:
            logger.error(f"Error in DQN retrospective training: {e}")
            return {'error': str(e)}
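
    # The method above duck-types the RL agent: it only relies on optional
    # add_experience(state, action, reward, next_state, done) and train()
    # methods, both guarded by hasattr(). A minimal sketch of an agent that
    # would satisfy this interface (an assumption for illustration, not the
    # orchestrator's actual rl_agent implementation):
    #
    #   class MinimalDQNAgent:
    #       def __init__(self):
    #           self.replay_buffer = []
    #
    #       def add_experience(self, state, action, reward, next_state, done):
    #           self.replay_buffer.append((state, action, reward, next_state, done))
    #
    #       def train(self):
    #           # Return a loss value, or None when there is nothing to train on
    #           return 0.0 if self.replay_buffer else None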

    def _extract_state_vector(self, model_inputs: Dict[str, Any]) -> Optional[np.ndarray]:
        """Extract state vector for DQN training from model inputs"""
        try:
            # Try to get pre-built RL state
            if 'dqn_state' in model_inputs:
                state = model_inputs['dqn_state']
                if isinstance(state, dict) and 'state_vector' in state:
                    return np.array(state['state_vector'])

            # Build state from market features
            market_state = model_inputs.get('market_state', {})
            features = []

            # Price features
            for key in ['current_price', 'price_sma_5', 'price_sma_20', 'price_std_20', 'price_rsi']:
                features.append(market_state.get(key, 0.0))

            # Volume features
            for key in ['volume_current', 'volume_sma_20', 'volume_ratio']:
                features.append(market_state.get(key, 0.0))

            # Technical indicators
            indicators = model_inputs.get('technical_indicators', {})
            for key in ['sma_10', 'sma_20', 'bb_upper', 'bb_lower', 'bb_position', 'macd', 'volatility']:
                features.append(indicators.get(key, 0.0))

            if len(features) < 5:  # Minimum required features
                return None

            return np.array(features, dtype=np.float32)

        except Exception as e:
            logger.debug(f"Error extracting state vector: {e}")
            return None
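
    # A minimal sketch of a model_inputs dict shaped the way the extractor
    # above expects: either a pre-built 'dqn_state' containing 'state_vector',
    # or 'market_state' / 'technical_indicators' dicts keyed as in the loops
    # above. The numbers are placeholders, not real market data.
    #
    #   model_inputs = {
    #       'market_state': {
    #           'current_price': 2510.0, 'price_sma_5': 2505.0, 'price_sma_20': 2490.0,
    #           'price_std_20': 12.0, 'price_rsi': 55.0,
    #           'volume_current': 1000.0, 'volume_sma_20': 900.0, 'volume_ratio': 1.1,
    #       },
    #       'technical_indicators': {
    #           'sma_10': 2500.0, 'sma_20': 2490.0, 'bb_upper': 2530.0, 'bb_lower': 2460.0,
    #           'bb_position': 0.6, 'macd': 1.5, 'volatility': 0.02,
    #       },
    #   }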

    def _create_extrema_feedback(self) -> List[Dict[str, Any]]:
        """Create feedback data for extrema trainer"""
        feedback = []

        try:
            for case in self.completed_cases:
                if case.outcome_label in [0, 1]:  # Only profit/loss cases
                    feedback_item = {
                        'symbol': case.symbol,
                        'action': case.action,
                        'entry_price': case.entry_price,
                        'exit_price': case.exit_price,
                        'was_profitable': case.outcome_label == 1,
                        'reward_signal': case.reward_signal,
                        'market_state': case.market_state
                    }
                    feedback.append(feedback_item)

            return feedback

        except Exception as e:
            logger.error(f"Error creating extrema feedback: {e}")
            return []

    def get_training_stats(self) -> Dict[str, Any]:
        """Get current training statistics"""
        stats = self.training_stats.copy()
        stats['total_cases_in_memory'] = len(self.completed_cases)
        stats['training_queue_size'] = self.training_queue.qsize()
        stats['is_training_active'] = self.is_training_active

        # Calculate profit metrics
        if len(self.completed_cases) > 0:
            profitable_count = sum(1 for c in self.completed_cases if c.pnl > 0)
            stats['profit_rate'] = profitable_count / len(self.completed_cases)
            stats['total_pnl'] = sum(c.pnl for c in self.completed_cases)
            stats['avg_reward'] = sum(c.reward_signal for c in self.completed_cases) / len(self.completed_cases)

        return stats

    def force_training_session(self) -> bool:
        """Force a training session regardless of timing constraints"""
        try:
            if self.is_training_active:
                logger.warning("Training already in progress")
                return False

            if len(self.completed_cases) < 1:
                logger.warning("No completed cases available for training")
                return False

            logger.info("RETROSPECTIVE: Forcing training session")
            self._start_training_session()
            return True

        except Exception as e:
            logger.error(f"Error forcing training session: {e}")
            return False

    def stop(self):
        """Stop the retrospective trainer"""
        try:
            self.is_training_active = False
            if self.training_thread and self.training_thread.is_alive():
                self.training_thread.join(timeout=10)
            logger.info("RetrospectiveTrainer stopped")
        except Exception as e:
            logger.error(f"Error stopping RetrospectiveTrainer: {e}")


def create_retrospective_trainer(orchestrator=None, config: Optional[Dict[str, Any]] = None) -> RetrospectiveTrainer:
    """Factory function to create a RetrospectiveTrainer instance"""
    return RetrospectiveTrainer(orchestrator=orchestrator, config=config)
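

# Minimal usage sketch (illustrative only): the trade_record and model_inputs
# below are hypothetical stand-ins for the values the trading orchestrator
# would capture at trade entry and exit; with no orchestrator attached, no
# model training is actually triggered.
if __name__ == '__main__':
    from datetime import timedelta

    logging.basicConfig(level=logging.INFO)

    trainer = create_retrospective_trainer(config={'min_cases_for_training': 1})

    now = datetime.now()
    trade_record = {
        'symbol': 'ETH/USDT', 'side': 'BUY',
        'entry_price': 2500.0, 'exit_price': 2510.0,
        'entry_time': now - timedelta(minutes=5), 'exit_time': now,
        'pnl': 10.0, 'fees': 1.2, 'confidence': 0.65, 'leverage': 1.0,
    }
    model_inputs = {'market_state': {'current_price': 2510.0}}

    trainer.add_completed_trade(trade_record, model_inputs)
    print(trainer.get_training_stats())
    trainer.stop()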