train retrospectively progress (wip)
This commit is contained in:
453
core/retrospective_trainer.py
Normal file
453
core/retrospective_trainer.py
Normal file
@ -0,0 +1,453 @@
"""
Retrospective Training System

This module implements a retrospective training system that:
1. Triggers training when trades close with known P&L outcomes
2. Uses captured model inputs from trade entry to train models
3. Optimizes for profit by learning from profitable vs unprofitable patterns
4. Supports simultaneous inference and training without weight reloading
5. Implements reinforcement learning with immediate reward feedback
"""

import logging
import queue
import threading
import time
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional

import numpy as np

# Module-level logger shared by everything defined in this file.
logger = logging.getLogger(__name__)


@dataclass
class TrainingCase:
    """Represents a completed trade case for retrospective training"""

    case_id: str
    symbol: str
    action: str  # 'BUY' or 'SELL'
    entry_price: float
    exit_price: float
    entry_time: datetime
    exit_time: datetime
    pnl: float
    fees: float
    confidence: float
    # Inputs captured for the models at trade entry.
    model_inputs: Dict[str, Any]
    # Market snapshot associated with the trade.
    market_state: Dict[str, Any]
    outcome_label: int  # 1 for profit, 0 for loss, 2 for breakeven
    reward_signal: float  # Scaled reward for RL training
    leverage: float = 1.0


class RetrospectiveTrainer:
    """Retrospective training system for real-time model optimization.

    Closed trades are turned into ``TrainingCase`` records, kept in a bounded
    in-memory window (``completed_cases``) and periodically replayed into the
    orchestrator's models from a background thread.
    """

    # Trade side -> DQN action index (BUY=0, SELL=1).  Matching is
    # case-insensitive; LONG/SHORT are accepted as aliases of BUY/SELL.
    # Any other side is skipped instead of being trained as a SELL.
    _ACTION_INDEX = {'BUY': 0, 'LONG': 0, 'SELL': 1, 'SHORT': 1}

    # Feature keys, in the exact order they are laid out in the state vector.
    _MARKET_FEATURE_KEYS = (
        'current_price', 'price_sma_5', 'price_sma_20', 'price_std_20', 'price_rsi',
        'volume_current', 'volume_sma_20', 'volume_ratio',
    )
    _INDICATOR_FEATURE_KEYS = (
        'sma_10', 'sma_20', 'bb_upper', 'bb_lower', 'bb_position', 'macd', 'volatility',
    )
    # Minimum number of features that must really be present in the inputs
    # before a state vector is built from them.
    _MIN_STATE_FEATURES = 5

    def __init__(self, orchestrator=None, config: Optional[Dict[str, Any]] = None):
        """Initialize the retrospective trainer.

        Args:
            orchestrator: Object whose optional ``rl_agent`` and
                ``extrema_trainer`` attributes are trained; may be None.
            config: Optional overrides for ``batch_size``,
                ``min_cases_for_training``, ``profit_threshold``,
                ``training_frequency_seconds`` and ``max_training_cases``.
        """
        self.orchestrator = orchestrator
        self.config = config or {}

        # Training configuration
        self.batch_size = self.config.get('batch_size', 32)
        self.min_cases_for_training = self.config.get('min_cases_for_training', 5)
        self.profit_threshold = self.config.get('profit_threshold', 0.0)
        self.training_frequency = self.config.get('training_frequency_seconds', 120)  # 2 minutes
        self.max_training_cases = self.config.get('max_training_cases', 1000)

        # Training state
        self.training_queue = queue.Queue()
        self.completed_cases = deque(maxlen=self.max_training_cases)
        self.training_stats = {
            'total_cases': 0,
            'profitable_cases': 0,
            'loss_cases': 0,
            'breakeven_cases': 0,
            'avg_profit': 0.0,
            'last_training_time': datetime.now(),
            'training_sessions': 0,
            'model_updates': 0
        }

        # Threading
        self.training_thread = None
        self.is_training_active = False
        # Held for the whole duration of a training session.
        self.training_lock = threading.Lock()
        # Short-lived lock guarding completed_cases and the per-trade counters,
        # so the training thread never iterates the deque while it is appended to.
        self._cases_lock = threading.Lock()

        logger.info("RetrospectiveTrainer initialized")
        logger.info(f"Configuration: batch_size={self.batch_size}, "
                    f"min_cases={self.min_cases_for_training}, "
                    f"training_freq={self.training_frequency}s")

    def add_completed_trade(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> bool:
        """Add a completed trade for retrospective training.

        Args:
            trade_record: Closed-trade data (symbol, side, pnl, fees, ...).
            model_inputs: Model inputs captured at trade entry.

        Returns:
            True if the trade was stored as a training case, False otherwise.
        """
        try:
            # Create training case from trade record
            case = self._create_training_case(trade_record, model_inputs)
            if case is None:
                return False

            with self._cases_lock:
                # Add to completed cases
                self.completed_cases.append(case)

                # Update statistics
                stats = self.training_stats
                stats['total_cases'] += 1
                if case.outcome_label == 1:  # Profit
                    stats['profitable_cases'] += 1
                elif case.outcome_label == 0:  # Loss
                    stats['loss_cases'] += 1
                else:  # Breakeven
                    stats['breakeven_cases'] += 1

                # Average profit over the cases currently kept in memory
                if self.completed_cases:
                    total_pnl = sum(c.pnl for c in self.completed_cases)
                    stats['avg_profit'] = total_pnl / len(self.completed_cases)

            self._enqueue_case(case)

            logger.info(f"RETROSPECTIVE: Added training case {case.case_id} "
                        f"(P&L: ${case.pnl:.3f}, Label: {case.outcome_label})")

            # Trigger training if we have enough cases
            self._maybe_trigger_training()

            return True

        except Exception as e:
            logger.error(f"Error adding completed trade for retrospective training: {e}")
            return False

    def _enqueue_case(self, case) -> None:
        """Queue a case as pending for the next session, dropping the oldest when full.

        Nothing blocks on this queue, so it is capped at the size of the
        in-memory case window to keep it from growing without bound.
        """
        limit = self.completed_cases.maxlen
        if limit is not None:
            while self.training_queue.qsize() >= max(1, limit):
                try:
                    self.training_queue.get_nowait()
                except queue.Empty:
                    break
        self.training_queue.put(case)

    def _drain_training_queue(self) -> int:
        """Empty the pending-case queue and return how many entries were removed."""
        drained = 0
        while True:
            try:
                self.training_queue.get_nowait()
            except queue.Empty:
                return drained
            drained += 1

    def _snapshot_cases(self) -> List["TrainingCase"]:
        """Return a copy of the stored cases that is safe to iterate from any thread."""
        with self._cases_lock:
            return list(self.completed_cases)

    def _create_training_case(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> Optional["TrainingCase"]:
        """Create a training case from trade record and model inputs.

        Returns:
            The new TrainingCase, or None if the record could not be converted.
        """
        try:
            # A trade without captured inputs is still kept for the statistics;
            # it simply yields no DQN state later on.
            model_inputs = model_inputs or {}

            # Extract trade information
            symbol = trade_record.get('symbol', 'UNKNOWN')
            side = trade_record.get('side', 'UNKNOWN')
            pnl = trade_record.get('pnl', 0.0)
            fees = trade_record.get('fees', 0.0)
            confidence = trade_record.get('confidence', 0.0)

            # Calculate net P&L after fees
            net_pnl = pnl - fees

            # Determine outcome label and reward signal
            if net_pnl > self.profit_threshold:
                outcome_label = 1  # Profitable
                # Scale reward by profit magnitude and confidence, capped at +10
                reward_signal = min(10.0, net_pnl * confidence * 10)  # Amplify for training
            elif net_pnl < -self.profit_threshold:
                outcome_label = 0  # Loss
                # Negative reward scaled by loss magnitude, capped at -10
                reward_signal = max(-10.0, net_pnl * confidence * 10)  # Negative reward
            else:
                outcome_label = 2  # Breakeven
                reward_signal = 0.0

            # Create case ID
            timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S')
            case_id = f"retro_{timestamp_str}_{symbol.replace('/', '')}_{side}_pnl_{abs(net_pnl):.3f}".replace('.', 'p')

            # Create training case (pnl is stored net of fees)
            return TrainingCase(
                case_id=case_id,
                symbol=symbol,
                action=side,
                entry_price=trade_record.get('entry_price', 0.0),
                exit_price=trade_record.get('exit_price', 0.0),
                entry_time=trade_record.get('entry_time', datetime.now()),
                exit_time=trade_record.get('exit_time', datetime.now()),
                pnl=net_pnl,
                fees=fees,
                confidence=confidence,
                model_inputs=model_inputs,
                market_state=model_inputs.get('market_state', {}),
                outcome_label=outcome_label,
                reward_signal=reward_signal,
                leverage=trade_record.get('leverage', 1.0)
            )

        except Exception as e:
            logger.error(f"Error creating training case: {e}")
            return None

    def _maybe_trigger_training(self):
        """Check if we should trigger a training session"""
        try:
            # Check if we have enough cases
            if len(self.completed_cases) < self.min_cases_for_training:
                return

            # Check if enough time has passed since last training
            time_since_last = (datetime.now() - self.training_stats['last_training_time']).total_seconds()
            if time_since_last < self.training_frequency:
                return

            # Check if training thread is not already running
            if self.is_training_active:
                logger.debug("Training already in progress, skipping trigger")
                return

            # Start training in background thread
            self._start_training_session()

        except Exception as e:
            logger.error(f"Error checking training trigger: {e}")

    def _start_training_session(self) -> bool:
        """Start a training session in background thread.

        Returns:
            True if a new thread was started, False if one is already running
            or the thread could not be started.
        """
        try:
            if self.training_thread and self.training_thread.is_alive():
                logger.debug("Training thread already running")
                return False

            self.training_thread = threading.Thread(
                target=self._run_training_session,
                daemon=True,
                name="RetrospectiveTrainer"
            )
            self.training_thread.start()
            logger.info("RETROSPECTIVE: Started training session")
            return True

        except Exception as e:
            logger.error(f"Error starting training session: {e}")
            return False

    def _run_training_session(self):
        """Run a complete training session (thread target)."""
        try:
            with self.training_lock:
                self.is_training_active = True
                start_time = time.time()

                # Everything queued so far is covered by this session.
                self._drain_training_queue()

                logger.info(f"RETROSPECTIVE: Training with {len(self.completed_cases)} cases")

                # Train models if orchestrator available
                training_results = {}
                if self.orchestrator:
                    training_results = self._train_models()

                # Update statistics
                self.training_stats['training_sessions'] += 1
                self.training_stats['model_updates'] += self._count_model_updates(training_results)

                elapsed_time = time.time() - start_time
                logger.info(f"RETROSPECTIVE: Training completed in {elapsed_time:.2f}s - {training_results}")

        except Exception as e:
            # exc_info attaches the traceback to the same log record
            logger.error(f"Error in retrospective training session: {e}", exc_info=True)
        finally:
            # Stamp the attempt even on failure so a broken session is throttled
            # by training_frequency instead of being retried on every new trade.
            self.training_stats['last_training_time'] = datetime.now()
            self.is_training_active = False

    @staticmethod
    def _count_model_updates(training_results: Dict[str, Any]) -> int:
        """Count per-model results that do not report an error."""
        updates = 0
        for name, outcome in training_results.items():
            if name == 'error':
                continue  # top-level failure marker, not a model
            if isinstance(outcome, dict) and ('error' in outcome or 'training_error' in outcome):
                continue
            updates += 1
        return updates

    def _train_models(self) -> Dict[str, Any]:
        """Train available models using retrospective data.

        Returns:
            Mapping of model name ('dqn', 'extrema') to its result dict, or
            ``{'error': ...}`` when nothing could be trained.
        """
        results = {}

        try:
            # Prepare training data
            cases = self._snapshot_cases()
            profitable_cases = [c for c in cases if c.outcome_label == 1]
            loss_cases = [c for c in cases if c.outcome_label == 0]

            if not profitable_cases and not loss_cases:
                return {'error': 'No labeled cases for training'}

            logger.info(f"RETROSPECTIVE: Training data - Profitable: {len(profitable_cases)}, Loss: {len(loss_cases)}")

            orchestrator = self.orchestrator

            # Train DQN agent if available
            if orchestrator and getattr(orchestrator, 'rl_agent', None):
                try:
                    dqn_result = self._train_dqn_retrospective()
                    results['dqn'] = dqn_result
                    logger.info(f"RETROSPECTIVE: DQN training result: {dqn_result}")
                except Exception as e:
                    logger.warning(f"DQN retrospective training failed: {e}")
                    results['dqn'] = {'error': str(e)}

            # Train other models
            if orchestrator and getattr(orchestrator, 'extrema_trainer', None):
                try:
                    # NOTE(review): the feedback is built and counted here but is
                    # not handed to extrema_trainer anywhere in this class.
                    extrema_feedback = self._create_extrema_feedback()
                    if extrema_feedback:
                        results['extrema'] = {'feedback_cases': len(extrema_feedback)}
                        logger.info(f"RETROSPECTIVE: Extrema feedback provided for {len(extrema_feedback)} cases")
                except Exception as e:
                    logger.warning(f"Extrema retrospective training failed: {e}")
                    results['extrema'] = {'error': str(e)}

            return results

        except Exception as e:
            logger.error(f"Error training models retrospectively: {e}")
            return {'error': str(e)}

    def _train_dqn_retrospective(self) -> Dict[str, Any]:
        """Train DQN agent using retrospective experience replay.

        Returns:
            Dict with ``experiences_added`` and ``training_steps`` (steps that
            actually ran), or an ``error`` / ``training_error`` entry.
        """
        try:
            dqn_agent = getattr(self.orchestrator, 'rl_agent', None) if self.orchestrator else None
            if not dqn_agent:
                return {'error': 'DQN agent not available'}

            add_experience = getattr(dqn_agent, 'add_experience', None)
            experiences_added = 0

            # Add retrospective experiences to DQN replay buffer.
            # NOTE(review): every session pushes all cases still in memory, so
            # older cases are inserted again each time - confirm this is intended.
            if callable(add_experience):
                for case in self._snapshot_cases():
                    try:
                        action = self._ACTION_INDEX.get(str(case.action).strip().upper())
                        if action is None:
                            logger.debug(f"Skipping DQN experience with unknown side: {case.action}")
                            continue

                        # Extract state from model inputs
                        state = self._extract_state_vector(case.model_inputs)
                        if state is None:
                            continue

                        # The closed trade is a terminal transition: the reward
                        # signal is the immediate reward and next_state is zeros.
                        add_experience(state, action, case.reward_signal, np.zeros_like(state), True)
                        experiences_added += 1

                    except Exception as e:
                        logger.debug(f"Error adding DQN experience: {e}")
                        continue

            # Train DQN if we have enough experiences
            if experiences_added > 0 and hasattr(dqn_agent, 'train'):
                try:
                    # Perform multiple training steps on retrospective data
                    planned_steps = min(10, experiences_added // 4)  # Conservative training
                    completed_steps = 0
                    for _ in range(planned_steps):
                        if dqn_agent.train() is None:
                            break  # agent produced no loss; stop early
                        completed_steps += 1

                    return {
                        'experiences_added': experiences_added,
                        'training_steps': completed_steps,
                        'method': 'retrospective_experience_replay'
                    }
                except Exception as e:
                    logger.warning(f"DQN training step failed: {e}")
                    return {'experiences_added': experiences_added, 'training_error': str(e)}

            return {'experiences_added': experiences_added, 'training_steps': 0}

        except Exception as e:
            logger.error(f"Error in DQN retrospective training: {e}")
            return {'error': str(e)}

    def _extract_state_vector(self, model_inputs: Dict[str, Any]) -> Optional[np.ndarray]:
        """Extract state vector for DQN training from model inputs.

        A pre-built ``dqn_state`` (either ``{'state_vector': [...]}`` or the
        numeric sequence itself) is used when available.  Otherwise the vector
        is assembled from ``market_state`` and ``technical_indicators``, with
        missing features set to 0.0.

        Returns:
            The state vector, or None when fewer than ``_MIN_STATE_FEATURES``
            features are present or the inputs cannot be converted.
        """
        try:
            if not model_inputs:
                return None

            # Try to get pre-built RL state
            prebuilt = model_inputs.get('dqn_state')
            if isinstance(prebuilt, dict):
                prebuilt = prebuilt.get('state_vector')
            if prebuilt is not None:
                vector = np.array(prebuilt)
                if vector.size > 0 and np.issubdtype(vector.dtype, np.number):
                    return vector

            # Build state from market features and technical indicators
            sources = (
                (model_inputs.get('market_state') or {}, self._MARKET_FEATURE_KEYS),
                (model_inputs.get('technical_indicators') or {}, self._INDICATOR_FEATURE_KEYS),
            )
            features = []
            present = 0
            for values, keys in sources:
                for key in keys:
                    value = values.get(key)
                    if value is None:
                        features.append(0.0)
                    else:
                        features.append(float(value))
                        present += 1

            # Count only features that were really supplied: the defaults would
            # otherwise always satisfy the minimum and produce all-zero states.
            if present < self._MIN_STATE_FEATURES:
                return None

            return np.array(features, dtype=np.float32)

        except Exception as e:
            logger.debug(f"Error extracting state vector: {e}")
            return None

    def _create_extrema_feedback(self) -> List[Dict[str, Any]]:
        """Create feedback data for extrema trainer.

        Returns:
            One dict per profit/loss case (breakeven cases are left out); an
            empty list on error.
        """
        try:
            return [
                {
                    'symbol': case.symbol,
                    'action': case.action,
                    'entry_price': case.entry_price,
                    'exit_price': case.exit_price,
                    'was_profitable': case.outcome_label == 1,
                    'reward_signal': case.reward_signal,
                    'market_state': case.market_state
                }
                for case in self._snapshot_cases()
                if case.outcome_label in (0, 1)  # Only profit/loss cases
            ]

        except Exception as e:
            logger.error(f"Error creating extrema feedback: {e}")
            return []

    def get_training_stats(self) -> Dict[str, Any]:
        """Get current training statistics.

        Returns:
            Copy of the counters plus queue/memory sizes and profit metrics
            (``profit_rate``, ``total_pnl``, ``avg_reward``), which are 0.0
            when no cases are stored.
        """
        with self._cases_lock:
            stats = self.training_stats.copy()
            cases = list(self.completed_cases)

        stats['total_cases_in_memory'] = len(cases)
        stats['training_queue_size'] = self.training_queue.qsize()
        stats['is_training_active'] = self.is_training_active

        # Calculate profit metrics
        if cases:
            profitable_count = sum(1 for c in cases if c.pnl > 0)
            stats['profit_rate'] = profitable_count / len(cases)
            stats['total_pnl'] = sum(c.pnl for c in cases)
            stats['avg_reward'] = sum(c.reward_signal for c in cases) / len(cases)
        else:
            stats['profit_rate'] = 0.0
            stats['total_pnl'] = 0.0
            stats['avg_reward'] = 0.0

        return stats

    def force_training_session(self) -> bool:
        """Force a training session regardless of timing constraints.

        Returns:
            True only if a training thread was actually started.
        """
        try:
            if self.is_training_active:
                logger.warning("Training already in progress")
                return False

            if len(self.completed_cases) < 1:
                logger.warning("No completed cases available for training")
                return False

            logger.info("RETROSPECTIVE: Forcing training session")
            return self._start_training_session()

        except Exception as e:
            logger.error(f"Error forcing training session: {e}")
            return False

    def stop(self):
        """Stop the retrospective trainer, waiting up to 10 seconds for a running session."""
        try:
            thread = self.training_thread
            if thread and thread.is_alive():
                thread.join(timeout=10)

            if thread and thread.is_alive():
                # The worker clears is_training_active itself when it finishes.
                logger.warning("RetrospectiveTrainer training thread still running after 10s")
            else:
                self.is_training_active = False
            logger.info("RetrospectiveTrainer stopped")
        except Exception as e:
            logger.error(f"Error stopping RetrospectiveTrainer: {e}")


def create_retrospective_trainer(orchestrator=None, config: Optional[Dict[str, Any]] = None) -> RetrospectiveTrainer:
    """Factory function to create a RetrospectiveTrainer instance.

    Args:
        orchestrator: Passed through to ``RetrospectiveTrainer``; may be None.
        config: Optional configuration dict passed through unchanged.

    Returns:
        A new ``RetrospectiveTrainer``.
    """
    return RetrospectiveTrainer(orchestrator=orchestrator, config=config)
@ -1368,7 +1368,7 @@ class CleanTradingDashboard:
|
||||
result = self.trading_executor.execute_trade(symbol, action, size)
|
||||
if result:
|
||||
signal['executed'] = True
|
||||
logger.info(f"✅ EXECUTED {action} signal: {symbol} @ ${signal.get('price', 0):.2f} "
|
||||
logger.info(f"EXECUTED {action} signal: {symbol} @ ${signal.get('price', 0):.2f} "
|
||||
f"(conf: {signal['confidence']:.2f}, size: {size}) - {execution_reason}")
|
||||
|
||||
# Create trade record for tracking
|
||||
@ -1436,7 +1436,7 @@ class CleanTradingDashboard:
|
||||
else:
|
||||
signal['blocked'] = True
|
||||
signal['block_reason'] = "Trading executor failed"
|
||||
logger.warning(f"❌ BLOCKED {action} signal: executor failed")
|
||||
logger.warning(f"BLOCKED {action} signal: executor failed")
|
||||
else:
|
||||
signal['blocked'] = True
|
||||
signal['block_reason'] = "No trading executor or invalid action"
|
||||
@ -1444,7 +1444,7 @@ class CleanTradingDashboard:
|
||||
except Exception as e:
|
||||
signal['blocked'] = True
|
||||
signal['block_reason'] = str(e)
|
||||
logger.error(f"❌ EXECUTION ERROR for {signal.get('action', 'UNKNOWN')}: {e}")
|
||||
logger.error(f"EXECUTION ERROR for {signal.get('action', 'UNKNOWN')}: {e}")
|
||||
else:
|
||||
# Determine which threshold was not met
|
||||
if action == 'BUY':
|
||||
|
Reference in New Issue
Block a user