Retrospective training progress (WIP)

Dobromir Popov
2025-06-25 17:22:45 +03:00
parent 29b3325581
commit 7d00a281ba
2 changed files with 456 additions and 3 deletions

View File

@@ -0,0 +1,453 @@
"""
Retrospective Training System
This module implements a retrospective training system that:
1. Triggers training when trades close with known P&L outcomes
2. Uses captured model inputs from trade entry to train models
3. Optimizes for profit by learning from profitable vs unprofitable patterns
4. Supports simultaneous inference and training without weight reloading
5. Implements reinforcement learning with immediate reward feedback
"""
import logging
import threading
import time
import queue
from datetime import datetime
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass
import numpy as np
from collections import deque
logger = logging.getLogger(__name__)

@dataclass
class TrainingCase:
    """Represents a completed trade case for retrospective training"""
    case_id: str
    symbol: str
    action: str  # 'BUY' or 'SELL'
    entry_price: float
    exit_price: float
    entry_time: datetime
    exit_time: datetime
    pnl: float
    fees: float
    confidence: float
    model_inputs: Dict[str, Any]
    market_state: Dict[str, Any]
    outcome_label: int  # 1 for profit, 0 for loss, 2 for breakeven
    reward_signal: float  # Scaled reward for RL training
    leverage: float = 1.0

class RetrospectiveTrainer:
    """Retrospective training system for real-time model optimization"""

    def __init__(self, orchestrator=None, config: Optional[Dict[str, Any]] = None):
        """Initialize the retrospective trainer"""
        self.orchestrator = orchestrator
        self.config = config or {}

        # Training configuration
        self.batch_size = self.config.get('batch_size', 32)
        self.min_cases_for_training = self.config.get('min_cases_for_training', 5)
        self.profit_threshold = self.config.get('profit_threshold', 0.0)
        self.training_frequency = self.config.get('training_frequency_seconds', 120)  # 2 minutes
        self.max_training_cases = self.config.get('max_training_cases', 1000)

        # Training state
        self.training_queue = queue.Queue()
        self.completed_cases = deque(maxlen=self.max_training_cases)
        self.training_stats = {
            'total_cases': 0,
            'profitable_cases': 0,
            'loss_cases': 0,
            'breakeven_cases': 0,
            'avg_profit': 0.0,
            'last_training_time': datetime.now(),
            'training_sessions': 0,
            'model_updates': 0
        }

        # Threading
        self.training_thread = None
        self.is_training_active = False
        self.training_lock = threading.Lock()

        logger.info("RetrospectiveTrainer initialized")
        logger.info(f"Configuration: batch_size={self.batch_size}, "
                    f"min_cases={self.min_cases_for_training}, "
                    f"training_freq={self.training_frequency}s")
    def add_completed_trade(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> bool:
        """Add a completed trade for retrospective training"""
        try:
            # Create training case from trade record
            case = self._create_training_case(trade_record, model_inputs)
            if case is None:
                return False

            # Add to completed cases
            self.completed_cases.append(case)
            self.training_queue.put(case)

            # Update statistics
            self.training_stats['total_cases'] += 1
            if case.outcome_label == 1:  # Profit
                self.training_stats['profitable_cases'] += 1
            elif case.outcome_label == 0:  # Loss
                self.training_stats['loss_cases'] += 1
            else:  # Breakeven
                self.training_stats['breakeven_cases'] += 1

            # Calculate running average profit
            total_pnl = sum(c.pnl for c in self.completed_cases)
            self.training_stats['avg_profit'] = total_pnl / len(self.completed_cases)

            logger.info(f"RETROSPECTIVE: Added training case {case.case_id} "
                        f"(P&L: ${case.pnl:.3f}, Label: {case.outcome_label})")

            # Trigger training if we have enough cases
            self._maybe_trigger_training()
            return True

        except Exception as e:
            logger.error(f"Error adding completed trade for retrospective training: {e}")
            return False
    def _create_training_case(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> Optional[TrainingCase]:
        """Create a training case from trade record and model inputs"""
        try:
            # Extract trade information
            symbol = trade_record.get('symbol', 'UNKNOWN')
            side = trade_record.get('side', 'UNKNOWN')
            pnl = trade_record.get('pnl', 0.0)
            fees = trade_record.get('fees', 0.0)
            confidence = trade_record.get('confidence', 0.0)

            # Calculate net P&L after fees
            net_pnl = pnl - fees

            # Determine outcome label and reward signal
            if net_pnl > self.profit_threshold:
                outcome_label = 1  # Profitable
                # Scale reward by profit magnitude and confidence
                reward_signal = min(10.0, net_pnl * confidence * 10)  # Amplify for training
            elif net_pnl < -self.profit_threshold:
                outcome_label = 0  # Loss
                # Negative reward scaled by loss magnitude
                reward_signal = max(-10.0, net_pnl * confidence * 10)  # Negative reward
            else:
                outcome_label = 2  # Breakeven
                reward_signal = 0.0
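            # Worked example (illustrative numbers, not from live data): a trade that
            # nets $0.50 after fees at confidence 0.7 yields
            # reward_signal = min(10.0, 0.50 * 0.7 * 10) = 3.5, while a -$0.30 loss at
            # the same confidence yields max(-10.0, -0.30 * 0.7 * 10) = -2.1. The clamp
            # to [-10, 10] keeps a single outsized trade from dominating training.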
            # Create case ID
            timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S')
            case_id = f"retro_{timestamp_str}_{symbol.replace('/', '')}_{side}_pnl_{abs(net_pnl):.3f}".replace('.', 'p')

            # Create training case
            case = TrainingCase(
                case_id=case_id,
                symbol=symbol,
                action=side,
                entry_price=trade_record.get('entry_price', 0.0),
                exit_price=trade_record.get('exit_price', 0.0),
                entry_time=trade_record.get('entry_time', datetime.now()),
                exit_time=trade_record.get('exit_time', datetime.now()),
                pnl=net_pnl,
                fees=fees,
                confidence=confidence,
                model_inputs=model_inputs,
                market_state=model_inputs.get('market_state', {}),
                outcome_label=outcome_label,
                reward_signal=reward_signal,
                leverage=trade_record.get('leverage', 1.0)
            )
            return case

        except Exception as e:
            logger.error(f"Error creating training case: {e}")
            return None
    def _maybe_trigger_training(self):
        """Check if we should trigger a training session"""
        try:
            # Check if we have enough cases
            if len(self.completed_cases) < self.min_cases_for_training:
                return

            # Check if enough time has passed since last training
            time_since_last = (datetime.now() - self.training_stats['last_training_time']).total_seconds()
            if time_since_last < self.training_frequency:
                return

            # Check if training thread is not already running
            if self.is_training_active:
                logger.debug("Training already in progress, skipping trigger")
                return

            # Start training in background thread
            self._start_training_session()

        except Exception as e:
            logger.error(f"Error checking training trigger: {e}")

    def _start_training_session(self):
        """Start a training session in background thread"""
        try:
            if self.training_thread and self.training_thread.is_alive():
                logger.debug("Training thread already running")
                return

            self.training_thread = threading.Thread(
                target=self._run_training_session,
                daemon=True,
                name="RetrospectiveTrainer"
            )
            self.training_thread.start()
            logger.info("RETROSPECTIVE: Started training session")

        except Exception as e:
            logger.error(f"Error starting training session: {e}")
    def _run_training_session(self):
        """Run a complete training session"""
        try:
            with self.training_lock:
                self.is_training_active = True
                start_time = time.time()
                logger.info(f"RETROSPECTIVE: Training with {len(self.completed_cases)} cases")

                # Train models if orchestrator available
                training_results = {}
                if self.orchestrator:
                    training_results = self._train_models()

                # Update statistics
                self.training_stats['last_training_time'] = datetime.now()
                self.training_stats['training_sessions'] += 1
                self.training_stats['model_updates'] += len(training_results)

                elapsed_time = time.time() - start_time
                logger.info(f"RETROSPECTIVE: Training completed in {elapsed_time:.2f}s - {training_results}")

        except Exception as e:
            logger.error(f"Error in retrospective training session: {e}")
            import traceback
            logger.error(traceback.format_exc())
        finally:
            self.is_training_active = False
    def _train_models(self) -> Dict[str, Any]:
        """Train available models using retrospective data"""
        results = {}
        try:
            # Prepare training data
            profitable_cases = [c for c in self.completed_cases if c.outcome_label == 1]
            loss_cases = [c for c in self.completed_cases if c.outcome_label == 0]
            if len(profitable_cases) == 0 and len(loss_cases) == 0:
                return {'error': 'No labeled cases for training'}

            logger.info(f"RETROSPECTIVE: Training data - Profitable: {len(profitable_cases)}, Loss: {len(loss_cases)}")

            # Train DQN agent if available
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                try:
                    dqn_result = self._train_dqn_retrospective()
                    results['dqn'] = dqn_result
                    logger.info(f"RETROSPECTIVE: DQN training result: {dqn_result}")
                except Exception as e:
                    logger.warning(f"DQN retrospective training failed: {e}")
                    results['dqn'] = {'error': str(e)}

            # Train other models
            if self.orchestrator and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
                try:
                    # Update extrema trainer with retrospective feedback
                    extrema_feedback = self._create_extrema_feedback()
                    if extrema_feedback:
                        results['extrema'] = {'feedback_cases': len(extrema_feedback)}
                        logger.info(f"RETROSPECTIVE: Extrema feedback provided for {len(extrema_feedback)} cases")
                except Exception as e:
                    logger.warning(f"Extrema retrospective training failed: {e}")

            return results

        except Exception as e:
            logger.error(f"Error training models retrospectively: {e}")
            return {'error': str(e)}
    def _train_dqn_retrospective(self) -> Dict[str, Any]:
        """Train DQN agent using retrospective experience replay"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
                return {'error': 'DQN agent not available'}

            dqn_agent = self.orchestrator.rl_agent
            experiences_added = 0

            # Add retrospective experiences to DQN replay buffer
            for case in self.completed_cases:
                try:
                    # Extract state from model inputs
                    state = self._extract_state_vector(case.model_inputs)
                    if state is None:
                        continue

                    # Action mapping: BUY=0, SELL=1
                    action = 0 if case.action == 'BUY' else 1

                    # Use reward signal as immediate reward
                    reward = case.reward_signal

                    # The trade is already closed, so the episode is terminal:
                    # use a zero vector as the next state and mark done=True
                    next_state = np.zeros_like(state)
                    done = True

                    # Add experience to DQN replay buffer
                    if hasattr(dqn_agent, 'add_experience'):
                        dqn_agent.add_experience(state, action, reward, next_state, done)
                        experiences_added += 1

                except Exception as e:
                    logger.debug(f"Error adding DQN experience: {e}")
                    continue

            # Train DQN if we have enough experiences
            if experiences_added > 0 and hasattr(dqn_agent, 'train'):
                try:
                    # Perform multiple training steps on retrospective data
                    training_steps = min(10, experiences_added // 4)  # Conservative training
                    for _ in range(training_steps):
                        loss = dqn_agent.train()
                        if loss is None:
                            break

                    return {
                        'experiences_added': experiences_added,
                        'training_steps': training_steps,
                        'method': 'retrospective_experience_replay'
                    }
                except Exception as e:
                    logger.warning(f"DQN training step failed: {e}")
                    return {'experiences_added': experiences_added, 'training_error': str(e)}

            return {'experiences_added': experiences_added, 'training_steps': 0}

        except Exception as e:
            logger.error(f"Error in DQN retrospective training: {e}")
            return {'error': str(e)}
    def _extract_state_vector(self, model_inputs: Dict[str, Any]) -> Optional[np.ndarray]:
        """Extract state vector for DQN training from model inputs"""
        try:
            # Try to get pre-built RL state
            if 'dqn_state' in model_inputs:
                state = model_inputs['dqn_state']
                if isinstance(state, dict) and 'state_vector' in state:
                    return np.array(state['state_vector'])

            # Build state from market features
            market_state = model_inputs.get('market_state', {})
            features = []

            # Price features
            for key in ['current_price', 'price_sma_5', 'price_sma_20', 'price_std_20', 'price_rsi']:
                features.append(market_state.get(key, 0.0))

            # Volume features
            for key in ['volume_current', 'volume_sma_20', 'volume_ratio']:
                features.append(market_state.get(key, 0.0))

            # Technical indicators
            indicators = model_inputs.get('technical_indicators', {})
            for key in ['sma_10', 'sma_20', 'bb_upper', 'bb_lower', 'bb_position', 'macd', 'volatility']:
                features.append(indicators.get(key, 0.0))

            if len(features) < 5:  # Minimum required features
                return None

            return np.array(features, dtype=np.float32)

        except Exception as e:
            logger.debug(f"Error extracting state vector: {e}")
            return None
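    # Illustrative shape of the model_inputs dict this method consumes (keys taken
    # from the lookups above; actual values are captured by the orchestrator at
    # trade entry):
    #   {
    #       'dqn_state': {'state_vector': [...]},      # optional pre-built RL state
    #       'market_state': {'current_price': ..., 'price_sma_5': ..., 'volume_ratio': ...},
    #       'technical_indicators': {'sma_10': ..., 'bb_position': ..., 'macd': ...},
    #   }
    # Without a pre-built state, the fallback vector has up to 15 features
    # (5 price + 3 volume + 7 indicator values), each defaulting to 0.0 when missing.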
    def _create_extrema_feedback(self) -> List[Dict[str, Any]]:
        """Create feedback data for extrema trainer"""
        feedback = []
        try:
            for case in self.completed_cases:
                if case.outcome_label in [0, 1]:  # Only profit/loss cases
                    feedback_item = {
                        'symbol': case.symbol,
                        'action': case.action,
                        'entry_price': case.entry_price,
                        'exit_price': case.exit_price,
                        'was_profitable': case.outcome_label == 1,
                        'reward_signal': case.reward_signal,
                        'market_state': case.market_state
                    }
                    feedback.append(feedback_item)
            return feedback

        except Exception as e:
            logger.error(f"Error creating extrema feedback: {e}")
            return []
    def get_training_stats(self) -> Dict[str, Any]:
        """Get current training statistics"""
        stats = self.training_stats.copy()
        stats['total_cases_in_memory'] = len(self.completed_cases)
        stats['training_queue_size'] = self.training_queue.qsize()
        stats['is_training_active'] = self.is_training_active

        # Calculate profit metrics
        if len(self.completed_cases) > 0:
            profitable_count = sum(1 for c in self.completed_cases if c.pnl > 0)
            stats['profit_rate'] = profitable_count / len(self.completed_cases)
            stats['total_pnl'] = sum(c.pnl for c in self.completed_cases)
            stats['avg_reward'] = sum(c.reward_signal for c in self.completed_cases) / len(self.completed_cases)

        return stats
    def force_training_session(self) -> bool:
        """Force a training session regardless of timing constraints"""
        try:
            if self.is_training_active:
                logger.warning("Training already in progress")
                return False
            if len(self.completed_cases) < 1:
                logger.warning("No completed cases available for training")
                return False

            logger.info("RETROSPECTIVE: Forcing training session")
            self._start_training_session()
            return True

        except Exception as e:
            logger.error(f"Error forcing training session: {e}")
            return False
    def stop(self):
        """Stop the retrospective trainer"""
        try:
            self.is_training_active = False
            if self.training_thread and self.training_thread.is_alive():
                self.training_thread.join(timeout=10)
            logger.info("RetrospectiveTrainer stopped")
        except Exception as e:
            logger.error(f"Error stopping RetrospectiveTrainer: {e}")


def create_retrospective_trainer(orchestrator=None, config: Optional[Dict[str, Any]] = None) -> RetrospectiveTrainer:
    """Factory function to create a RetrospectiveTrainer instance"""
    return RetrospectiveTrainer(orchestrator=orchestrator, config=config)
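
A minimal usage sketch (not part of the committed file): build a trade_record with the field names _create_training_case reads, pass the model inputs captured at trade entry, and poll get_training_stats(). The import path and all numeric values below are assumptions for illustration only.

    from core.retrospective_trainer import create_retrospective_trainer  # hypothetical import path

    trainer = create_retrospective_trainer(
        orchestrator=None,  # without an orchestrator, cases are collected but no models are trained
        config={'min_cases_for_training': 5, 'training_frequency_seconds': 120},
    )

    # Field names mirror what _create_training_case reads; values are made up
    trade_record = {
        'symbol': 'ETH/USDT', 'side': 'BUY',
        'entry_price': 2412.50, 'exit_price': 2418.10,
        'pnl': 0.56, 'fees': 0.06, 'confidence': 0.72, 'leverage': 1.0,
    }
    model_inputs = {
        'market_state': {'current_price': 2412.50, 'price_rsi': 41.0, 'volume_ratio': 1.3},
        'technical_indicators': {'sma_10': 2409.8, 'macd': 0.4, 'volatility': 0.012},
    }

    trainer.add_completed_trade(trade_record, model_inputs)
    print(trainer.get_training_stats())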

View File

@@ -1368,7 +1368,7 @@ class CleanTradingDashboard:
result = self.trading_executor.execute_trade(symbol, action, size)
if result:
    signal['executed'] = True
    logger.info(f"EXECUTED {action} signal: {symbol} @ ${signal.get('price', 0):.2f} "
                f"(conf: {signal['confidence']:.2f}, size: {size}) - {execution_reason}")
    # Create trade record for tracking
@@ -1436,7 +1436,7 @@ class CleanTradingDashboard:
else:
    signal['blocked'] = True
    signal['block_reason'] = "Trading executor failed"
    logger.warning(f"BLOCKED {action} signal: executor failed")
else:
    signal['blocked'] = True
    signal['block_reason'] = "No trading executor or invalid action"
@@ -1444,7 +1444,7 @@ class CleanTradingDashboard:
except Exception as e:
    signal['blocked'] = True
    signal['block_reason'] = str(e)
    logger.error(f"EXECUTION ERROR for {signal.get('action', 'UNKNOWN')}: {e}")
else:
    # Determine which threshold was not met
    if action == 'BUY':
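
The dashboard hunks above only touch logging around trade execution; wiring the dashboard into the new trainer is still WIP in this commit. A hypothetical sketch of what that hookup could look like on trade close, using only the module API shown above (the handler name, the retrospective_trainer attribute, and the captured-inputs argument are assumptions, not code from this commit):

    # Inside CleanTradingDashboard, once a position is closed (hypothetical handler)
    def _on_trade_closed(self, trade_record, entry_model_inputs):
        """Hand the closed trade and the model inputs captured at entry to the trainer."""
        if getattr(self, 'retrospective_trainer', None):
            self.retrospective_trainer.add_completed_trade(trade_record, entry_model_inputs)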