#!/usr/bin/env python3
"""
Training Integration - Handles cold start training and model learning integration

Manages:
- Cold start training triggers from trade outcomes
- Reward calculation based on P&L
- Integration with DQN, CNN, and COB RL models
- Training session management
"""

import logging
import threading
import time
from datetime import datetime
from typing import Any, Dict, Optional

import numpy as np

from utils.reward_calculator import RewardCalculator

logger = logging.getLogger(__name__)


class TrainingIntegration:
    """Manages training integration for cold start learning"""

    def __init__(self, orchestrator=None):
        self.orchestrator = orchestrator
        self.reward_calculator = RewardCalculator()
        self.training_sessions = {}
        self.min_confidence_threshold = 0.15  # Lowered from 0.3 for more aggressive training
        self.training_active = False
        self.trainer_thread = None
        self.stop_event = threading.Event()
        self.training_lock = threading.Lock()
        self.last_training_time = 0.0 if orchestrator is None else time.time()
        self.training_interval = 300  # Seconds between training sessions (5 minutes)
        self.min_data_points = 100  # Minimum data points required to trigger training

        logger.info("TrainingIntegration initialized")

    def trigger_cold_start_training(self, trade_record: Dict[str, Any], case_id: Optional[str] = None) -> bool:
        """Trigger cold start training when trades close with known outcomes"""
        try:
            if not trade_record.get('model_inputs_at_entry'):
                logger.warning("No model inputs captured for training - skipping")
                return False

            pnl = trade_record.get('pnl', 0)
            confidence = trade_record.get('confidence', 0)

            logger.info(f"Triggering cold start training for trade with P&L: ${pnl:.4f}")

            # Calculate training reward based on P&L and confidence
            reward = self._calculate_training_reward(pnl, confidence)

            # Train each model on the trade outcome; each trainer is a no-op
            # if its model is not available on the orchestrator
            dqn_success = self._train_dqn_on_trade_outcome(trade_record, reward)
            cnn_success = self._train_cnn_on_trade_outcome(trade_record, reward)
            cob_success = self._train_cob_rl_on_trade_outcome(trade_record, reward)

            # Log training results
            training_success = any([dqn_success, cnn_success, cob_success])
            if training_success:
                logger.info(f"Cold start training completed - DQN: {dqn_success}, CNN: {cnn_success}, COB: {cob_success}")
            else:
                logger.warning("Cold start training failed for all models")

            return training_success

        except Exception as e:
            logger.error(f"Error in cold start training: {e}")
            return False
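
    # Example of the trade_record layout this module reads (a sketch inferred from
    # the key accesses above and in the trainers below; field values are
    # hypothetical and only illustrate the expected types):
    #
    #     trade_record = {
    #         'side': 'BUY',                      # or 'SELL' / 'HOLD'
    #         'pnl': 1.25,                        # realized P&L of the closed trade
    #         'confidence': 0.62,                 # model confidence at entry
    #         'model_inputs_at_entry': {
    #             'dqn_state': {'state_vector': [...]},  # flat numeric state
    #             'cnn_features': [...],                 # flat features, padded/truncated to 50
    #             'cob_features': [...],                 # flat features, padded/truncated to the agent's state size
    #         },
    #     }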

    def _calculate_training_reward(self, pnl: float, confidence: float) -> float:
        """Calculate training reward based on P&L and confidence"""
        try:
            # Base reward is proportional to P&L
            base_reward = pnl

            # Adjust for confidence - penalize high-confidence wrong predictions more
            if pnl < 0 and confidence > 0.7:
                # High-confidence loss - significant negative reward
                confidence_adjustment = -confidence * 2
            elif pnl > 0 and confidence > 0.7:
                # High-confidence gain - boost reward
                confidence_adjustment = confidence * 1.5
            else:
                # Low confidence - no adjustment
                confidence_adjustment = 0

            final_reward = base_reward + confidence_adjustment

            # Normalize to [-1, 1] range for training stability
            normalized_reward = np.tanh(final_reward / 10.0)

            logger.debug(f"Training reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")

            return float(normalized_reward)

        except Exception as e:
            logger.error(f"Error calculating training reward: {e}")
            return 0.0
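
    # Worked examples for the reward mapping above (values rounded; plain tanh
    # arithmetic, no additional assumptions):
    #   pnl = +5.00, confidence = 0.8 -> adjustment = +1.2 -> tanh(6.2 / 10)  ~ +0.55
    #   pnl = -5.00, confidence = 0.8 -> adjustment = -1.6 -> tanh(-6.6 / 10) ~ -0.58
    #   pnl = +0.50, confidence = 0.4 -> adjustment =  0.0 -> tanh(0.5 / 10)  ~ +0.05
    # The tanh(x / 10) squashing keeps rewards inside (-1, 1) and only saturates
    # for trades with |P&L| well above ~$20.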

    def _train_dqn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train DQN agent on trade outcome"""
        try:
            if not self.orchestrator:
                logger.warning("No orchestrator available for DQN training")
                return False

            # Get DQN agent
            if not hasattr(self.orchestrator, 'dqn_agent') or not self.orchestrator.dqn_agent:
                logger.warning("DQN agent not available for training")
                return False

            # Extract DQN state from model inputs
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            dqn_state = model_inputs.get('dqn_state', {}).get('state_vector')

            # Explicit None/empty check (a bare truth test fails on numpy arrays)
            if dqn_state is None or len(dqn_state) == 0:
                logger.warning("No DQN state available for training")
                return False

            # Convert trade side to DQN action index
            action = trade_record.get('side', 'HOLD').upper()
            action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
            action_idx = action_map.get(action, 2)

            # Create next state (simplified - should be the market state after the trade)
            next_state = dqn_state  # Placeholder

            # Store experience in DQN memory
            dqn_agent = self.orchestrator.dqn_agent
            if hasattr(dqn_agent, 'remember'):
                dqn_agent.remember(
                    state=np.array(dqn_state),
                    action=action_idx,
                    reward=reward,
                    next_state=np.array(next_state),
                    done=True  # Trade is complete
                )

                # Trigger a training step once enough experiences are stored
                if hasattr(dqn_agent, 'replay') and len(getattr(dqn_agent, 'memory', [])) > 32:
                    dqn_agent.replay()
                    logger.info("DQN training step completed")

                return True
            else:
                logger.warning("DQN agent doesn't support experience storage")
                return False

        except Exception as e:
            logger.error(f"Error training DQN on trade outcome: {e}")
            return False
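
    # Note on the stored transition above: the experience is terminal (done=True),
    # so a standard DQN target reduces to the immediate reward,
    #     target = r + (1 - done) * gamma * max_a Q(next_state, a) = r,
    # which is why reusing the entry state as next_state is harmless here --
    # assuming the agent's replay() applies the usual (1 - done) masking.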

    def _train_cnn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train CNN on trade outcome with real implementation"""
        try:
            if not self.orchestrator:
                return False

            # Check if a CNN model is available
            cnn_model = None
            if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                cnn_model = self.orchestrator.cnn_model
            elif hasattr(self.orchestrator, 'williams_cnn') and self.orchestrator.williams_cnn:
                cnn_model = self.orchestrator.williams_cnn

            if not cnn_model:
                logger.debug("CNN not available for training")
                return False

            # Get CNN features from model inputs
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            cnn_features = model_inputs.get('cnn_features')

            # Explicit None check (a bare truth test fails on numpy arrays)
            if cnn_features is None:
                logger.debug("No CNN features available for training")
                return False

            # Determine target based on trade outcome
            pnl = trade_record.get('pnl', 0)
            action = trade_record.get('side', 'HOLD').upper()

            # Create target based on trade success
            if pnl > 0:
                if action == 'BUY':
                    target = 0  # Successful BUY
                elif action == 'SELL':
                    target = 1  # Successful SELL
                else:
                    target = 2  # HOLD
            else:
                # For unsuccessful trades, learn the opposite action
                if action == 'BUY':
                    target = 1  # Should have been SELL
                elif action == 'SELL':
                    target = 0  # Should have been BUY
                else:
                    target = 2  # HOLD

            import torch

            # Initialize an optimizer on the model if it doesn't have one yet
            if not hasattr(cnn_model, 'optimizer'):
                cnn_model.optimizer = torch.optim.Adam(cnn_model.parameters(), lr=0.001)

            # Perform actual CNN training
            try:
                # Use the device the model's parameters already live on to avoid
                # host/GPU mismatches when building the input tensors
                device = next(cnn_model.parameters()).device

                # Prepare features as a flat float32 vector
                features = np.asarray(cnn_features, dtype=np.float32)

                # Ensure features are the right size: pad with zeros or truncate to 50
                if len(features) < 50:
                    padded_features = np.zeros(50, dtype=np.float32)
                    padded_features[:len(features)] = features
                    features = padded_features
                elif len(features) > 50:
                    features = features[:50]

                # Create tensors
                features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device)
                target_tensor = torch.LongTensor([target]).to(device)

                # Training step
                cnn_model.train()
                cnn_model.optimizer.zero_grad()

                outputs = cnn_model(features_tensor)

                # Handle different output formats
                if isinstance(outputs, dict):
                    if 'main_output' in outputs:
                        logits = outputs['main_output']
                    elif 'action_logits' in outputs:
                        logits = outputs['action_logits']
                    else:
                        logits = list(outputs.values())[0]
                else:
                    logits = outputs

                # Calculate loss with reward weighting
                loss_fn = torch.nn.CrossEntropyLoss()
                loss = loss_fn(logits, target_tensor)

                # Weight loss by reward magnitude; trades with a near-zero
                # normalized reward contribute almost no gradient signal
                weighted_loss = loss * abs(reward)

                # Backward pass
                weighted_loss.backward()
                cnn_model.optimizer.step()

                logger.info(f"CNN trained on trade outcome: P&L=${pnl:.2f}, loss={loss.item():.4f}")
                return True

            except Exception as e:
                logger.error(f"Error in CNN training step: {e}")
                return False

        except Exception as e:
            logger.error(f"Error in CNN training: {e}")
            return False

    def _train_cob_rl_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train COB RL on trade outcome with real implementation"""
        try:
            if not self.orchestrator:
                return False

            # Check if a COB RL agent is available
            cob_rl_agent = None
            if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                cob_rl_agent = self.orchestrator.rl_agent
            elif hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
                cob_rl_agent = self.orchestrator.cob_rl_agent

            if not cob_rl_agent:
                logger.debug("COB RL agent not available for training")
                return False

            # Get COB features from model inputs
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            cob_features = model_inputs.get('cob_features')

            # Explicit None check (a bare truth test fails on numpy arrays)
            if cob_features is None:
                logger.debug("No COB features available for training")
                return False

            # Create state from COB features
            state_features = np.asarray(cob_features, dtype=np.float32)

            # Pad or truncate to the agent's expected state size
            if hasattr(cob_rl_agent, 'state_shape'):
                expected_size = cob_rl_agent.state_shape if isinstance(cob_rl_agent.state_shape, int) else cob_rl_agent.state_shape[0]
            else:
                expected_size = 100  # Default size

            if len(state_features) < expected_size:
                padded_features = np.zeros(expected_size, dtype=np.float32)
                padded_features[:len(state_features)] = state_features
                state_features = padded_features
            elif len(state_features) > expected_size:
                state_features = state_features[:expected_size]

            state = np.array(state_features, dtype=np.float32)

            # Determine action from trade record
            action_str = trade_record.get('side', 'HOLD').upper()
            if action_str == 'BUY':
                action = 0
            elif action_str == 'SELL':
                action = 1
            else:
                action = 2  # HOLD

            # Create next state (same as current state for simplicity)
            next_state = state.copy()

            # Use scaled P&L as the reward for this agent
            pnl = trade_record.get('pnl', 0)
            actual_reward = float(pnl * 100)  # Scale reward

            # Store experience in agent memory
            if hasattr(cob_rl_agent, 'remember'):
                cob_rl_agent.remember(state, action, actual_reward, next_state, done=True)
            elif hasattr(cob_rl_agent, 'store_experience'):
                cob_rl_agent.store_experience(state, action, actual_reward, next_state, done=True)

            # Perform a training step if the agent has a replay method
            if hasattr(cob_rl_agent, 'replay') and hasattr(cob_rl_agent, 'memory'):
                if len(cob_rl_agent.memory) > 32:  # Enough samples to train
                    loss = cob_rl_agent.replay()
                    if loss is not None:
                        logger.info(f"COB RL trained on trade outcome: P&L=${pnl:.2f}, loss={loss:.4f}")
                        return True

            logger.debug(f"COB RL experience stored: P&L=${pnl:.2f}, reward={actual_reward:.2f}")
            return True

        except Exception as e:
            logger.error(f"Error in COB RL training: {e}")
            return False
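
    # Note on reward scales: the DQN path trains on the tanh-normalized reward
    # from _calculate_training_reward, while this trainer feeds the agent raw
    # P&L scaled by 100 and ignores the `reward` argument, so the two models
    # learn from rewards on very different scales.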

    def get_training_status(self) -> Dict[str, Any]:
        """Get current training status"""
        try:
            status = {
                'active': self.training_active,
                'last_training_time': self.last_training_time,
                'training_sessions': self.training_sessions if self.training_sessions else {}
            }
            return status
        except Exception as e:
            logger.error(f"Error getting training status: {e}")
            return {}

    def start_training_session(self, session_name: str, config: Optional[Dict[str, Any]] = None) -> str:
        """Start a new training session"""
        try:
            session_id = f"{session_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            self.training_sessions[session_id] = {
                'name': session_name,
                # Stored as an ISO string so end_training_session can parse it back
                'start_time': datetime.now().isoformat(),
                'config': config if config else {},
                'trades_processed': 0,
                'training_attempts': 0,
                'successful_trainings': 0,
                'failed_trainings': 0
            }
            logger.info(f"Started training session: {session_id}")
            return session_id
        except Exception as e:
            logger.error(f"Error starting training session: {e}")
            return ""

    def end_training_session(self, session_id: str) -> Dict[str, Any]:
        """End a training session and return summary"""
        try:
            if session_id not in self.training_sessions:
                logger.warning(f"Training session not found: {session_id}")
                return {}

            session_data = self.training_sessions[session_id]
            session_data['end_time'] = datetime.now().isoformat()

            # Calculate session duration
            start_time = datetime.fromisoformat(session_data['start_time'])
            end_time = datetime.fromisoformat(session_data['end_time'])
            duration = (end_time - start_time).total_seconds()
            session_data['duration_seconds'] = duration

            # Calculate success rate
            total_attempts = session_data['successful_trainings'] + session_data['failed_trainings']
            session_data['success_rate'] = session_data['successful_trainings'] / total_attempts if total_attempts > 0 else 0

            logger.info(f"Ended training session: {session_id}")
            logger.info(f"  Duration: {duration:.1f}s")
            logger.info(f"  Trades processed: {session_data['trades_processed']}")
            logger.info(f"  Success rate: {session_data['success_rate']:.2%}")

            # Remove from active sessions
            completed_session = self.training_sessions.pop(session_id)

            return completed_session

        except Exception as e:
            logger.error(f"Error ending training session: {e}")
            return {}

    def update_session_stats(self, session_id: str, trade_processed: bool = True, training_success: bool = False):
        """Update training session statistics"""
        try:
            if session_id not in self.training_sessions:
                return

            session = self.training_sessions[session_id]

            if trade_processed:
                session['trades_processed'] += 1

            if training_success:
                session['successful_trainings'] += 1
            else:
                session['failed_trainings'] += 1

        except Exception as e:
            logger.error(f"Error updating session stats: {e}")