test cases

This commit is contained in:
Dobromir Popov
2025-06-25 14:45:37 +03:00
parent 4a1170d593
commit 4afa147bd1
5 changed files with 1039 additions and 247 deletions

View File

@ -0,0 +1,304 @@
#!/usr/bin/env python3
"""
Training Integration - Handles cold start training and model learning integration
Manages:
- Cold start training triggers from trade outcomes
- Reward calculation based on P&L
- Integration with DQN, CNN, and COB RL models
- Training session management
"""
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional
import numpy as np
logger = logging.getLogger(__name__)
class TrainingIntegration:
"""Manages training integration for cold start learning"""
def __init__(self, orchestrator=None):
self.orchestrator = orchestrator
self.training_sessions = {}
self.min_confidence_threshold = 0.3
logger.info("TrainingIntegration initialized")
def trigger_cold_start_training(self, trade_record: Dict[str, Any], case_id: str = None) -> bool:
"""Trigger cold start training when trades close with known outcomes"""
try:
if not trade_record.get('model_inputs_at_entry'):
logger.warning("No model inputs captured for training - skipping")
return False
pnl = trade_record.get('pnl', 0)
confidence = trade_record.get('confidence', 0)
logger.info(f"Triggering cold start training for trade with P&L: ${pnl:.4f}")
# Calculate training reward based on P&L and confidence
reward = self._calculate_training_reward(pnl, confidence)
# Train DQN on trade outcome
dqn_success = self._train_dqn_on_trade_outcome(trade_record, reward)
# Train CNN if available (placeholder for now)
cnn_success = self._train_cnn_on_trade_outcome(trade_record, reward)
# Train COB RL if available (placeholder for now)
cob_success = self._train_cob_rl_on_trade_outcome(trade_record, reward)
# Log training results
training_success = any([dqn_success, cnn_success, cob_success])
if training_success:
logger.info(f"Cold start training completed - DQN: {dqn_success}, CNN: {cnn_success}, COB: {cob_success}")
else:
logger.warning("Cold start training failed for all models")
return training_success
except Exception as e:
logger.error(f"Error in cold start training: {e}")
return False
def _calculate_training_reward(self, pnl: float, confidence: float) -> float:
"""Calculate training reward based on P&L and confidence"""
try:
# Base reward is proportional to P&L
base_reward = pnl
# Adjust for confidence - penalize high confidence wrong predictions more
if pnl < 0 and confidence > 0.7:
# High confidence loss - significant negative reward
confidence_adjustment = -confidence * 2
elif pnl > 0 and confidence > 0.7:
# High confidence gain - boost reward
confidence_adjustment = confidence * 1.5
else:
# Low confidence - minimal adjustment
confidence_adjustment = 0
final_reward = base_reward + confidence_adjustment
# Normalize to [-1, 1] range for training stability
normalized_reward = np.tanh(final_reward / 10.0)
logger.debug(f"Training reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")
return float(normalized_reward)
except Exception as e:
logger.error(f"Error calculating training reward: {e}")
return 0.0
def _train_dqn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
"""Train DQN agent on trade outcome"""
try:
if not self.orchestrator:
logger.warning("No orchestrator available for DQN training")
return False
# Get DQN agent
if not hasattr(self.orchestrator, 'dqn_agent') or not self.orchestrator.dqn_agent:
logger.warning("DQN agent not available for training")
return False
# Extract DQN state from model inputs
model_inputs = trade_record.get('model_inputs_at_entry', {})
dqn_state = model_inputs.get('dqn_state', {}).get('state_vector')
if not dqn_state:
logger.warning("No DQN state available for training")
return False
# Convert action to DQN action index
action = trade_record.get('side', 'HOLD').upper()
action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
action_idx = action_map.get(action, 2)
# Create next state (simplified - could be current market state)
next_state = dqn_state # Placeholder - should be state after trade
# Store experience in DQN memory
dqn_agent = self.orchestrator.dqn_agent
if hasattr(dqn_agent, 'store_experience'):
dqn_agent.store_experience(
state=np.array(dqn_state),
action=action_idx,
reward=reward,
next_state=np.array(next_state),
done=True # Trade is complete
)
# Trigger training if enough experiences
if hasattr(dqn_agent, 'replay') and len(getattr(dqn_agent, 'memory', [])) > 32:
dqn_agent.replay(batch_size=32)
logger.info("DQN training step completed")
return True
else:
logger.warning("DQN agent doesn't support experience storage")
return False
except Exception as e:
logger.error(f"Error training DQN on trade outcome: {e}")
return False
def _train_cnn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
"""Train CNN on trade outcome (placeholder)"""
try:
if not self.orchestrator:
return False
# Check if CNN is available
if not hasattr(self.orchestrator, 'williams_cnn') or not self.orchestrator.williams_cnn:
logger.debug("CNN not available for training")
return False
# Get CNN features from model inputs
model_inputs = trade_record.get('model_inputs_at_entry', {})
cnn_features = model_inputs.get('cnn_features')
cnn_predictions = model_inputs.get('cnn_predictions')
if not cnn_features or not cnn_predictions:
logger.debug("No CNN features available for training")
return False
# CNN training would go here - requires more specific implementation
# For now, just log that we could train CNN
logger.debug(f"CNN training opportunity: features={len(cnn_features)}, predictions={len(cnn_predictions)}")
return True
except Exception as e:
logger.debug(f"Error in CNN training: {e}")
return False
def _train_cob_rl_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
"""Train COB RL on trade outcome (placeholder)"""
try:
if not self.orchestrator:
return False
# Check if COB integration is available
if not hasattr(self.orchestrator, 'cob_integration') or not self.orchestrator.cob_integration:
logger.debug("COB integration not available for training")
return False
# Get COB features from model inputs
model_inputs = trade_record.get('model_inputs_at_entry', {})
cob_features = model_inputs.get('cob_features')
if not cob_features:
logger.debug("No COB features available for training")
return False
# COB RL training would go here - requires more specific implementation
# For now, just log that we could train COB RL
logger.debug(f"COB RL training opportunity: features={len(cob_features)}")
return True
except Exception as e:
logger.debug(f"Error in COB RL training: {e}")
return False
def get_training_status(self) -> Dict[str, Any]:
"""Get current training integration status"""
try:
status = {
'orchestrator_available': self.orchestrator is not None,
'training_sessions': len(self.training_sessions),
'last_update': datetime.now().isoformat()
}
if self.orchestrator:
status['dqn_available'] = hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent is not None
status['cnn_available'] = hasattr(self.orchestrator, 'williams_cnn') and self.orchestrator.williams_cnn is not None
status['cob_available'] = hasattr(self.orchestrator, 'cob_integration') and self.orchestrator.cob_integration is not None
return status
except Exception as e:
logger.error(f"Error getting training status: {e}")
return {'error': str(e)}
def start_training_session(self, session_name: str, config: Dict[str, Any] = None) -> str:
"""Start a new training session"""
try:
session_id = f"{session_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
session_data = {
'session_id': session_id,
'session_name': session_name,
'start_time': datetime.now().isoformat(),
'config': config or {},
'trades_processed': 0,
'successful_trainings': 0,
'failed_trainings': 0
}
self.training_sessions[session_id] = session_data
logger.info(f"Started training session: {session_id}")
return session_id
except Exception as e:
logger.error(f"Error starting training session: {e}")
return ""
def end_training_session(self, session_id: str) -> Dict[str, Any]:
"""End a training session and return summary"""
try:
if session_id not in self.training_sessions:
logger.warning(f"Training session not found: {session_id}")
return {}
session_data = self.training_sessions[session_id]
session_data['end_time'] = datetime.now().isoformat()
# Calculate session duration
start_time = datetime.fromisoformat(session_data['start_time'])
end_time = datetime.fromisoformat(session_data['end_time'])
duration = (end_time - start_time).total_seconds()
session_data['duration_seconds'] = duration
# Calculate success rate
total_attempts = session_data['successful_trainings'] + session_data['failed_trainings']
session_data['success_rate'] = session_data['successful_trainings'] / total_attempts if total_attempts > 0 else 0
logger.info(f"Ended training session: {session_id}")
logger.info(f" Duration: {duration:.1f}s")
logger.info(f" Trades processed: {session_data['trades_processed']}")
logger.info(f" Success rate: {session_data['success_rate']:.2%}")
# Remove from active sessions
completed_session = self.training_sessions.pop(session_id)
return completed_session
except Exception as e:
logger.error(f"Error ending training session: {e}")
return {}
def update_session_stats(self, session_id: str, trade_processed: bool = True, training_success: bool = False):
"""Update training session statistics"""
try:
if session_id not in self.training_sessions:
return
session = self.training_sessions[session_id]
if trade_processed:
session['trades_processed'] += 1
if training_success:
session['successful_trainings'] += 1
else:
session['failed_trainings'] += 1
except Exception as e:
logger.error(f"Error updating session stats: {e}")