gogo2/core/training_integration.py
#!/usr/bin/env python3
"""
Training Integration - Handles cold start training and model learning integration
Manages:
- Cold start training triggers from trade outcomes
- Reward calculation based on P&L
- Integration with DQN, CNN, and COB RL models
- Training session management
"""
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional
import numpy as np
from utils.reward_calculator import RewardCalculator
import threading
import time
logger = logging.getLogger(__name__)
class TrainingIntegration:
    """Manages training integration for cold start learning"""

    def __init__(self, orchestrator=None):
        self.orchestrator = orchestrator
        self.reward_calculator = RewardCalculator()
        self.training_sessions = {}
        self.min_confidence_threshold = 0.15  # Lowered from 0.3 for more aggressive training
        self.training_active = False
        self.trainer_thread = None
        self.stop_event = threading.Event()
        self.training_lock = threading.Lock()
        self.last_training_time = 0.0 if orchestrator is None else time.time()
        self.training_interval = 300  # 5 minutes between training sessions
        self.min_data_points = 100  # Minimum data points required to trigger training
        logger.info("TrainingIntegration initialized")
    def trigger_cold_start_training(self, trade_record: Dict[str, Any], case_id: Optional[str] = None) -> bool:
        """Trigger cold start training when trades close with known outcomes"""
        try:
            if not trade_record.get('model_inputs_at_entry'):
                logger.warning("No model inputs captured for training - skipping")
                return False

            pnl = trade_record.get('pnl', 0)
            confidence = trade_record.get('confidence', 0)
            logger.info(f"Triggering cold start training for trade with P&L: ${pnl:.4f}")

            # Calculate training reward based on P&L and confidence
            reward = self._calculate_training_reward(pnl, confidence)

            # Train DQN on trade outcome
            dqn_success = self._train_dqn_on_trade_outcome(trade_record, reward)

            # Train CNN if available (placeholder for now)
            cnn_success = self._train_cnn_on_trade_outcome(trade_record, reward)

            # Train COB RL if available (placeholder for now)
            cob_success = self._train_cob_rl_on_trade_outcome(trade_record, reward)

            # Log training results
            training_success = any([dqn_success, cnn_success, cob_success])
            if training_success:
                logger.info(f"Cold start training completed - DQN: {dqn_success}, CNN: {cnn_success}, COB: {cob_success}")
            else:
                logger.warning("Cold start training failed for all models")

            return training_success
        except Exception as e:
            logger.error(f"Error in cold start training: {e}")
            return False
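
    # Expected trade_record layout, inferred from the accessors used in this class.
    # Only the fields read below are assumed; the values are purely illustrative:
    #
    #   trade_record = {
    #       'pnl': 1.25,
    #       'confidence': 0.62,
    #       'side': 'BUY',
    #       'model_inputs_at_entry': {
    #           'dqn_state': {'state_vector': [...]},
    #           'cnn_features': [...],
    #           'cob_features': [...],
    #       },
    #   }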
    def _calculate_training_reward(self, pnl: float, confidence: float) -> float:
        """Calculate training reward based on P&L and confidence"""
        try:
            # Base reward is proportional to P&L
            base_reward = pnl

            # Adjust for confidence - penalize high-confidence wrong predictions more
            if pnl < 0 and confidence > 0.7:
                # High confidence loss - significant negative reward
                confidence_adjustment = -confidence * 2
            elif pnl > 0 and confidence > 0.7:
                # High confidence gain - boost reward
                confidence_adjustment = confidence * 1.5
            else:
                # Low confidence - minimal adjustment
                confidence_adjustment = 0

            final_reward = base_reward + confidence_adjustment

            # Normalize to [-1, 1] range for training stability
            normalized_reward = np.tanh(final_reward / 10.0)
            logger.debug(f"Training reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")
            return float(normalized_reward)
        except Exception as e:
            logger.error(f"Error calculating training reward: {e}")
            return 0.0
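
    # Worked examples of the reward mapping above (hypothetical inputs):
    #   pnl=+5.0, confidence=0.8 -> adjustment=+1.2, tanh(6.2/10)  ~ +0.551
    #   pnl=-5.0, confidence=0.8 -> adjustment=-1.6, tanh(-6.6/10) ~ -0.578
    #   pnl=+0.5, confidence=0.4 -> adjustment= 0.0, tanh(0.5/10)  ~ +0.050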
    def _train_dqn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train DQN agent on trade outcome"""
        try:
            if not self.orchestrator:
                logger.warning("No orchestrator available for DQN training")
                return False

            # Get DQN agent
            if not hasattr(self.orchestrator, 'dqn_agent') or not self.orchestrator.dqn_agent:
                logger.warning("DQN agent not available for training")
                return False

            # Extract DQN state from model inputs
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            dqn_state = model_inputs.get('dqn_state', {}).get('state_vector')
            if dqn_state is None or len(dqn_state) == 0:
                logger.warning("No DQN state available for training")
                return False

            # Convert action to DQN action index
            action = trade_record.get('side', 'HOLD').upper()
            action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
            action_idx = action_map.get(action, 2)

            # Create next state (simplified - could be current market state)
            next_state = dqn_state  # Placeholder - should be the state after the trade

            # Store experience in DQN memory
            dqn_agent = self.orchestrator.dqn_agent
            if hasattr(dqn_agent, 'store_experience'):
                dqn_agent.store_experience(
                    state=np.array(dqn_state),
                    action=action_idx,
                    reward=reward,
                    next_state=np.array(next_state),
                    done=True  # Trade is complete
                )

                # Trigger training if enough experiences
                if hasattr(dqn_agent, 'replay') and len(getattr(dqn_agent, 'memory', [])) > 32:
                    dqn_agent.replay(batch_size=32)
                    logger.info("DQN training step completed")

                return True
            else:
                logger.warning("DQN agent doesn't support experience storage")
                return False
        except Exception as e:
            logger.error(f"Error training DQN on trade outcome: {e}")
            return False
    def _train_cnn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train CNN on trade outcome with real implementation"""
        try:
            if not self.orchestrator:
                return False

            # Check if CNN is available
            cnn_model = None
            if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                cnn_model = self.orchestrator.cnn_model
            elif hasattr(self.orchestrator, 'williams_cnn') and self.orchestrator.williams_cnn:
                cnn_model = self.orchestrator.williams_cnn
            if not cnn_model:
                logger.debug("CNN not available for training")
                return False

            # Get CNN features from model inputs
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            cnn_features = model_inputs.get('cnn_features')
            if not cnn_features:
                logger.debug("No CNN features available for training")
                return False

            # Determine target based on trade outcome
            pnl = trade_record.get('pnl', 0)
            action = trade_record.get('side', 'HOLD').upper()

            # Create target based on trade success
            if pnl > 0:
                if action == 'BUY':
                    target = 0  # Successful BUY
                elif action == 'SELL':
                    target = 1  # Successful SELL
                else:
                    target = 2  # HOLD
            else:
                # For unsuccessful trades, learn the opposite
                if action == 'BUY':
                    target = 1  # Should have been SELL
                elif action == 'SELL':
                    target = 0  # Should have been BUY
                else:
                    target = 2  # HOLD

            import torch

            # Initialize model optimizer if needed
            if not hasattr(cnn_model, 'optimizer'):
                cnn_model.optimizer = torch.optim.Adam(cnn_model.parameters(), lr=0.001)

            # Perform actual CNN training
            try:
                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

                # Prepare features (handles both lists and arrays)
                features = np.array(cnn_features, dtype=np.float32)

                # Ensure features are the right size
                if len(features) < 50:
                    # Pad with zeros
                    padded_features = np.zeros(50, dtype=np.float32)
                    padded_features[:len(features)] = features
                    features = padded_features
                elif len(features) > 50:
                    # Truncate
                    features = features[:50]

                # Create tensors
                features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device)
                target_tensor = torch.LongTensor([target]).to(device)

                # Training step
                cnn_model.train()
                cnn_model.optimizer.zero_grad()
                outputs = cnn_model(features_tensor)

                # Handle different output formats
                if isinstance(outputs, dict):
                    if 'main_output' in outputs:
                        logits = outputs['main_output']
                    elif 'action_logits' in outputs:
                        logits = outputs['action_logits']
                    else:
                        logits = list(outputs.values())[0]
                else:
                    logits = outputs

                # Calculate loss with reward weighting
                loss_fn = torch.nn.CrossEntropyLoss()
                loss = loss_fn(logits, target_tensor)

                # Weight loss by reward magnitude
                weighted_loss = loss * abs(reward)

                # Backward pass
                weighted_loss.backward()
                cnn_model.optimizer.step()

                logger.info(f"CNN trained on trade outcome: P&L=${pnl:.2f}, loss={loss.item():.4f}")
                return True
            except Exception as e:
                logger.error(f"Error in CNN training step: {e}")
                return False
        except Exception as e:
            logger.error(f"Error in CNN training: {e}")
            return False
    def _train_cob_rl_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train COB RL on trade outcome with real implementation"""
        try:
            if not self.orchestrator:
                return False

            # Check if COB RL agent is available
            cob_rl_agent = None
            if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                cob_rl_agent = self.orchestrator.rl_agent
            elif hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
                cob_rl_agent = self.orchestrator.cob_rl_agent
            if not cob_rl_agent:
                logger.debug("COB RL agent not available for training")
                return False

            # Get COB features from model inputs
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            cob_features = model_inputs.get('cob_features')
            if not cob_features:
                logger.debug("No COB features available for training")
                return False

            # Create state from COB features (handles both lists and arrays)
            state_features = np.array(cob_features, dtype=np.float32)

            # Pad or truncate to the agent's expected state size
            if hasattr(cob_rl_agent, 'state_shape'):
                expected_size = cob_rl_agent.state_shape if isinstance(cob_rl_agent.state_shape, int) else cob_rl_agent.state_shape[0]
            else:
                expected_size = 100  # Default size

            if len(state_features) < expected_size:
                # Pad with zeros
                padded_features = np.zeros(expected_size, dtype=np.float32)
                padded_features[:len(state_features)] = state_features
                state_features = padded_features
            elif len(state_features) > expected_size:
                # Truncate
                state_features = state_features[:expected_size]

            state = np.array(state_features, dtype=np.float32)

            # Determine action from trade record
            action_str = trade_record.get('side', 'HOLD').upper()
            if action_str == 'BUY':
                action = 0
            elif action_str == 'SELL':
                action = 1
            else:
                action = 2  # HOLD

            # Create next state (similar to current state for simplicity)
            next_state = state.copy()

            # Use P&L as reward
            pnl = trade_record.get('pnl', 0)
            actual_reward = float(pnl * 100)  # Scale reward

            # Store experience in agent memory
            if hasattr(cob_rl_agent, 'remember'):
                cob_rl_agent.remember(state, action, actual_reward, next_state, done=True)
            elif hasattr(cob_rl_agent, 'store_experience'):
                cob_rl_agent.store_experience(state, action, actual_reward, next_state, done=True)

            # Perform training step if agent has replay method
            if hasattr(cob_rl_agent, 'replay') and hasattr(cob_rl_agent, 'memory'):
                if len(cob_rl_agent.memory) > 32:  # Enough samples to train
                    loss = cob_rl_agent.replay(batch_size=min(32, len(cob_rl_agent.memory)))
                    if loss is not None:
                        logger.info(f"COB RL trained on trade outcome: P&L=${pnl:.2f}, loss={loss:.4f}")
                        return True

            logger.debug(f"COB RL experience stored: P&L=${pnl:.2f}, reward={actual_reward:.2f}")
            return True
        except Exception as e:
            logger.error(f"Error in COB RL training: {e}")
            return False
    def get_training_status(self) -> Dict[str, Any]:
        """Get current training status"""
        try:
            status = {
                'active': self.training_active,
                'last_training_time': self.last_training_time,
                'training_sessions': self.training_sessions if self.training_sessions else {}
            }
            return status
        except Exception as e:
            logger.error(f"Error getting training status: {e}")
            return {}
    def start_training_session(self, session_name: str, config: Optional[Dict[str, Any]] = None) -> str:
        """Start a new training session"""
        try:
            session_id = f"{session_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            self.training_sessions[session_id] = {
                'name': session_name,
                'start_time': datetime.now().isoformat(),  # ISO string so end_training_session can parse it
                'config': config if config else {},
                'trades_processed': 0,
                'training_attempts': 0,
                'successful_trainings': 0,
                'failed_trainings': 0  # Required by end_training_session and update_session_stats
            }
            logger.info(f"Started training session: {session_id}")
            return session_id
        except Exception as e:
            logger.error(f"Error starting training session: {e}")
            return ""
    def end_training_session(self, session_id: str) -> Dict[str, Any]:
        """End a training session and return summary"""
        try:
            if session_id not in self.training_sessions:
                logger.warning(f"Training session not found: {session_id}")
                return {}

            session_data = self.training_sessions[session_id]
            session_data['end_time'] = datetime.now().isoformat()

            # Calculate session duration
            start_time = datetime.fromisoformat(session_data['start_time'])
            end_time = datetime.fromisoformat(session_data['end_time'])
            duration = (end_time - start_time).total_seconds()
            session_data['duration_seconds'] = duration

            # Calculate success rate
            total_attempts = session_data['successful_trainings'] + session_data['failed_trainings']
            session_data['success_rate'] = session_data['successful_trainings'] / total_attempts if total_attempts > 0 else 0

            logger.info(f"Ended training session: {session_id}")
            logger.info(f"  Duration: {duration:.1f}s")
            logger.info(f"  Trades processed: {session_data['trades_processed']}")
            logger.info(f"  Success rate: {session_data['success_rate']:.2%}")

            # Remove from active sessions
            completed_session = self.training_sessions.pop(session_id)
            return completed_session
        except Exception as e:
            logger.error(f"Error ending training session: {e}")
            return {}
    def update_session_stats(self, session_id: str, trade_processed: bool = True, training_success: bool = False):
        """Update training session statistics"""
        try:
            if session_id not in self.training_sessions:
                return

            session = self.training_sessions[session_id]
            if trade_processed:
                session['trades_processed'] += 1
            if training_success:
                session['successful_trainings'] += 1
            else:
                session['failed_trainings'] += 1
        except Exception as e:
            logger.error(f"Error updating session stats: {e}")
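
# Usage sketch (illustrative only - `build_orchestrator` and the trade_record
# contents are assumptions, not part of this module):
#
#     from core.training_integration import TrainingIntegration
#
#     integration = TrainingIntegration(orchestrator=build_orchestrator())
#     session_id = integration.start_training_session('cold_start')
#
#     # Whenever a trade closes with a known outcome:
#     trained = integration.trigger_cold_start_training(trade_record)
#     integration.update_session_stats(session_id, trade_processed=True, training_success=trained)
#
#     summary = integration.end_training_session(session_id)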