304 lines
13 KiB
Python
304 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Training Integration - Handles cold start training and model learning integration
|
|
|
|
Manages:
|
|
- Cold start training triggers from trade outcomes
|
|
- Reward calculation based on P&L
|
|
- Integration with DQN, CNN, and COB RL models
|
|
- Training session management
|
|
"""
|
|
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Dict, List, Any, Optional
|
|
import numpy as np
|
|
|
|
# Module-level logger; TrainingIntegration routes all of its logging through it.
logger = logging.getLogger(__name__)
|
|
|
|
class TrainingIntegration:
    """Manages training integration for cold start learning.

    Converts closed-trade records into training signals for the models exposed
    by an optional ``orchestrator`` (looked up via its ``dqn_agent``,
    ``williams_cnn`` and ``cob_integration`` attributes) and keeps simple
    per-session counters for training sessions.
    """

    # Trade side -> DQN action index; unknown sides fall back to HOLD.
    DQN_ACTION_MAP = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
    # Batch size for one DQN replay step; replay only runs once the agent's
    # memory holds more than this many experiences.
    DQN_BATCH_SIZE = 32
    # Confidence above which a prediction counts as "high confidence" when
    # shaping the reward.
    HIGH_CONFIDENCE_THRESHOLD = 0.7
    # P&L (plus confidence adjustment) is divided by this before tanh() maps
    # it into [-1, 1].
    REWARD_SCALE = 10.0

    def __init__(self, orchestrator=None):
        """Initialise the integration.

        Args:
            orchestrator: Optional object holding the models to train. Models
                are read lazily from its ``dqn_agent``, ``williams_cnn`` and
                ``cob_integration`` attributes, any of which may be absent.
        """
        self.orchestrator = orchestrator
        # Active training sessions keyed by session id.
        self.training_sessions: Dict[str, Dict[str, Any]] = {}
        # NOTE(review): not read by any method of this class; kept for
        # external users that may rely on the attribute.
        self.min_confidence_threshold = 0.3

        logger.info("TrainingIntegration initialized")

    def _get_component(self, name: str) -> Any:
        """Return ``orchestrator.<name>``, or None if there is no orchestrator
        or it lacks that attribute."""
        if not self.orchestrator:
            return None
        return getattr(self.orchestrator, name, None)

    @staticmethod
    def _is_empty(value: Any) -> bool:
        """Return True when *value* is None or has length zero.

        ``not value`` cannot be used for model inputs: numpy arrays with more
        than one element raise ``ValueError`` when truth-tested.
        """
        if value is None:
            return True
        try:
            return len(value) == 0
        except TypeError:
            # No len() (scalars, 0-d arrays, plain objects): use truthiness.
            return not value

    def trigger_cold_start_training(self, trade_record: Dict[str, Any], case_id: Optional[str] = None) -> bool:
        """Train the available models on a closed trade with a known outcome.

        Args:
            trade_record: Closed-trade record. Reads ``model_inputs_at_entry``
                (required, non-empty), ``pnl``, ``confidence`` and ``side``.
            case_id: Optional identifier of the trade case (currently unused).

        Returns:
            True if at least one model hook reported success, False otherwise,
            including on any unexpected error (which is logged).
        """
        try:
            if not trade_record.get('model_inputs_at_entry'):
                logger.warning("No model inputs captured for training - skipping")
                return False

            raw_pnl = trade_record.get('pnl', 0)
            if raw_pnl is None:
                # Without a realised P&L there is no outcome to learn from.
                logger.warning("No P&L recorded for trade - skipping training")
                return False
            pnl = float(raw_pnl)
            # A missing confidence means "no confidence adjustment" instead of
            # breaking the reward calculation.
            confidence = float(trade_record.get('confidence') or 0)

            logger.info(f"Triggering cold start training for trade with P&L: ${pnl:.4f}")

            reward = self._calculate_training_reward(pnl, confidence)

            dqn_success = self._train_dqn_on_trade_outcome(trade_record, reward)
            # NOTE(review): the CNN and COB hooks are placeholders that return
            # True without training whenever their inputs are present, which
            # makes this method report success even if DQN training failed.
            cnn_success = self._train_cnn_on_trade_outcome(trade_record, reward)
            cob_success = self._train_cob_rl_on_trade_outcome(trade_record, reward)

            training_success = any([dqn_success, cnn_success, cob_success])
            if training_success:
                logger.info(f"Cold start training completed - DQN: {dqn_success}, CNN: {cnn_success}, COB: {cob_success}")
            else:
                logger.warning("Cold start training failed for all models")

            return training_success

        except Exception as e:
            logger.exception(f"Error in cold start training: {e}")
            return False

    def _calculate_training_reward(self, pnl: float, confidence: float) -> float:
        """Map a trade's P&L and confidence to a reward in [-1, 1].

        ``reward = tanh((pnl + adjustment) / REWARD_SCALE)``. The adjustment is
        ``-2 * confidence`` for a loss and ``+1.5 * confidence`` for a gain when
        confidence exceeds HIGH_CONFIDENCE_THRESHOLD, and 0 otherwise. Confident
        mistakes are therefore penalised more than confident wins are rewarded.

        Returns 0.0 (and logs) if the inputs cannot be used arithmetically.
        """
        try:
            threshold = self.HIGH_CONFIDENCE_THRESHOLD
            if pnl < 0 and confidence > threshold:
                # Confident loss: extra penalty.
                confidence_adjustment = -confidence * 2
            elif pnl > 0 and confidence > threshold:
                # Confident gain: boost.
                confidence_adjustment = confidence * 1.5
            else:
                confidence_adjustment = 0

            # tanh squashes the reward into [-1, 1] for training stability.
            normalized_reward = np.tanh((pnl + confidence_adjustment) / self.REWARD_SCALE)

            logger.debug(f"Training reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")

            return float(normalized_reward)

        except Exception as e:
            logger.error(f"Error calculating training reward: {e}")
            return 0.0

    def _train_dqn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Store the trade as a terminal DQN experience and maybe run a replay step.

        The entry state comes from
        ``model_inputs_at_entry['dqn_state']['state_vector']``. The action comes
        from the trade ``side`` (BUY/SELL/HOLD, case-insensitive; anything else
        maps to HOLD). A replay step of DQN_BATCH_SIZE runs when the agent has
        a ``replay`` method and its ``memory`` is larger than that size.

        Returns:
            True if the experience was stored. False if the orchestrator, agent,
            state or ``store_experience`` support is missing, or on error.
        """
        try:
            if not self.orchestrator:
                logger.warning("No orchestrator available for DQN training")
                return False

            dqn_agent = self._get_component('dqn_agent')
            if not dqn_agent:
                logger.warning("DQN agent not available for training")
                return False

            # `or {}` also covers keys that are present but set to None.
            model_inputs = trade_record.get('model_inputs_at_entry') or {}
            dqn_state = (model_inputs.get('dqn_state') or {}).get('state_vector')

            # The state vector may be a numpy array, so avoid `not dqn_state`.
            if self._is_empty(dqn_state):
                logger.warning("No DQN state available for training")
                return False

            side = trade_record.get('side') or 'HOLD'
            action_idx = self.DQN_ACTION_MAP.get(str(side).upper(), self.DQN_ACTION_MAP['HOLD'])

            if not hasattr(dqn_agent, 'store_experience'):
                logger.warning("DQN agent doesn't support experience storage")
                return False

            dqn_agent.store_experience(
                state=np.array(dqn_state),
                action=action_idx,
                reward=reward,
                # TODO: use the market state after the trade; the entry state is
                # a placeholder (presumably ignored since done=True - confirm
                # against the agent's target computation).
                next_state=np.array(dqn_state),
                done=True,  # the trade is complete
            )

            memory = getattr(dqn_agent, 'memory', [])
            if hasattr(dqn_agent, 'replay') and len(memory) > self.DQN_BATCH_SIZE:
                dqn_agent.replay(batch_size=self.DQN_BATCH_SIZE)
                logger.info("DQN training step completed")

            return True

        except Exception as e:
            logger.exception(f"Error training DQN on trade outcome: {e}")
            return False

    def _train_cnn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Placeholder CNN training hook.

        No training happens yet. The hook only checks that a CNN
        (``williams_cnn``) plus entry-time ``cnn_features`` and
        ``cnn_predictions`` exist, then logs the opportunity. ``reward`` is
        currently unused.

        Returns:
            True when the inputs are present (even though nothing is trained),
            False otherwise or on error.
        """
        try:
            if not self.orchestrator:
                return False

            if not self._get_component('williams_cnn'):
                logger.debug("CNN not available for training")
                return False

            model_inputs = trade_record.get('model_inputs_at_entry') or {}
            cnn_features = model_inputs.get('cnn_features')
            cnn_predictions = model_inputs.get('cnn_predictions')

            # Features/predictions may be numpy arrays, so avoid `not value`.
            if self._is_empty(cnn_features) or self._is_empty(cnn_predictions):
                logger.debug("No CNN features available for training")
                return False

            # TODO: implement CNN training; for now only the opportunity is logged.
            logger.debug(f"CNN training opportunity: features={len(cnn_features)}, predictions={len(cnn_predictions)}")

            return True

        except Exception as e:
            logger.debug(f"Error in CNN training: {e}")
            return False

    def _train_cob_rl_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Placeholder COB RL training hook.

        No training happens yet. The hook only checks that ``cob_integration``
        and entry-time ``cob_features`` exist, then logs the opportunity.
        ``reward`` is currently unused.

        Returns:
            True when the inputs are present (even though nothing is trained),
            False otherwise or on error.
        """
        try:
            if not self.orchestrator:
                return False

            if not self._get_component('cob_integration'):
                logger.debug("COB integration not available for training")
                return False

            model_inputs = trade_record.get('model_inputs_at_entry') or {}
            cob_features = model_inputs.get('cob_features')

            # Features may be a numpy array, so avoid `not cob_features`.
            if self._is_empty(cob_features):
                logger.debug("No COB features available for training")
                return False

            # TODO: implement COB RL training; for now only the opportunity is logged.
            logger.debug(f"COB RL training opportunity: features={len(cob_features)}")

            return True

        except Exception as e:
            logger.debug(f"Error in COB RL training: {e}")
            return False

    def get_training_status(self) -> Dict[str, Any]:
        """Return a snapshot of the integration's state.

        Always includes ``orchestrator_available``, ``training_sessions``
        (count of active sessions) and ``last_update``. When an orchestrator is
        set, it also includes ``dqn_available``, ``cnn_available`` and
        ``cob_available`` (attribute present and not None). On error, returns
        ``{'error': <message>}``.
        """
        try:
            status = {
                'orchestrator_available': self.orchestrator is not None,
                'training_sessions': len(self.training_sessions),
                'last_update': datetime.now().isoformat(),
            }

            if self.orchestrator:
                status['dqn_available'] = self._get_component('dqn_agent') is not None
                status['cnn_available'] = self._get_component('williams_cnn') is not None
                status['cob_available'] = self._get_component('cob_integration') is not None

            return status

        except Exception as e:
            logger.error(f"Error getting training status: {e}")
            return {'error': str(e)}

    def start_training_session(self, session_name: str, config: Optional[Dict[str, Any]] = None) -> str:
        """Start a new training session and return its id.

        The id is ``<session_name>_<YYYYmmdd_HHMMSS>``. A numeric suffix is
        appended if that id is already active, because ids only have
        one-second resolution.

        Returns:
            The new session id, or ``""`` on error.
        """
        try:
            now = datetime.now()
            base_id = f"{session_name}_{now.strftime('%Y%m%d_%H%M%S')}"

            # Do not silently overwrite an active session started in the same second.
            session_id = base_id
            suffix = 1
            while session_id in self.training_sessions:
                session_id = f"{base_id}_{suffix}"
                suffix += 1

            self.training_sessions[session_id] = {
                'session_id': session_id,
                'session_name': session_name,
                'start_time': now.isoformat(),
                'config': config or {},
                'trades_processed': 0,
                'successful_trainings': 0,
                'failed_trainings': 0,
            }

            logger.info(f"Started training session: {session_id}")

            return session_id

        except Exception as e:
            logger.error(f"Error starting training session: {e}")
            return ""

    def end_training_session(self, session_id: str) -> Dict[str, Any]:
        """End an active training session and return its summary.

        Adds ``end_time``, ``duration_seconds`` and ``success_rate`` (0 when no
        training attempts were recorded) to the session record, logs a summary
        and removes the session from ``training_sessions``.

        Returns:
            The completed session record, or ``{}`` if *session_id* is unknown
            or an error occurs.
        """
        try:
            if session_id not in self.training_sessions:
                logger.warning(f"Training session not found: {session_id}")
                return {}

            session_data = self.training_sessions[session_id]

            end_time = datetime.now()
            session_data['end_time'] = end_time.isoformat()
            start_time = datetime.fromisoformat(session_data['start_time'])
            duration = (end_time - start_time).total_seconds()
            session_data['duration_seconds'] = duration

            successes = session_data['successful_trainings']
            total_attempts = successes + session_data['failed_trainings']
            session_data['success_rate'] = successes / total_attempts if total_attempts > 0 else 0

            logger.info(f"Ended training session: {session_id}")
            logger.info(f"  Duration: {duration:.1f}s")
            logger.info(f"  Trades processed: {session_data['trades_processed']}")
            logger.info(f"  Success rate: {session_data['success_rate']:.2%}")

            # Remove from active sessions and hand the record back to the caller.
            return self.training_sessions.pop(session_id)

        except Exception as e:
            logger.error(f"Error ending training session: {e}")
            return {}

    def update_session_stats(self, session_id: str, trade_processed: bool = True, training_success: bool = False):
        """Update the counters of an active training session.

        Unknown session ids are ignored. ``trades_processed`` is incremented
        when *trade_processed* is true. Exactly one of ``successful_trainings``
        or ``failed_trainings`` is incremented on every call, depending on
        *training_success* and independent of *trade_processed*.
        """
        try:
            session = self.training_sessions.get(session_id)
            if session is None:
                return

            if trade_processed:
                session['trades_processed'] += 1

            if training_success:
                session['successful_trainings'] += 1
            else:
                session['failed_trainings'] += 1

        except Exception as e:
            logger.error(f"Error updating session stats: {e}")