training fixes and enhancements wip

@@ -3123,9 +3123,13 @@ class CleanTradingDashboard:
            if len(self.recent_decisions) > 200:
                self.recent_decisions = self.recent_decisions[-200:]

            # Train ALL models on the signal (if executed)
            # Train ALL models on EVERY prediction result (not just executed ones)
            # This ensures models learn from all predictions, not just successful trades
            self._train_all_models_on_prediction(signal)

            # Additional training weight for executed signals
            if signal['executed']:
                self._train_all_models_on_signal(signal)
                self._train_all_models_on_executed_signal(signal)
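                # Net effect: an executed signal is trained on twice - once above like
                # every prediction, and again here with the enhanced executed-trade
                # weighting applied in _train_all_models_on_executed_signal below.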

            # Log signal processing
            status = "EXECUTED" if signal['executed'] else ("BLOCKED" if signal['blocked'] else "PENDING")
@@ -3135,33 +3139,118 @@ class CleanTradingDashboard:
        except Exception as e:
            logger.error(f"Error processing dashboard signal: {e}")

    def _train_all_models_on_signal(self, signal: Dict):
        """Train ALL models on executed trade signal - Comprehensive training system"""
    def _train_all_models_on_prediction(self, signal: Dict):
        """Train ALL models on EVERY prediction result - Comprehensive learning system"""
        try:
            # Get prediction outcome based on immediate price movement
            prediction_outcome = self._get_prediction_outcome_for_training(signal)
            if not prediction_outcome:
                return

            # 1. Train DQN model on prediction outcome
            self._train_dqn_on_prediction(signal, prediction_outcome)

            # 2. Train CNN model on prediction outcome
            self._train_cnn_on_prediction(signal, prediction_outcome)

            # 3. Train Transformer model on prediction outcome
            self._train_transformer_on_prediction(signal, prediction_outcome)

            # 4. Train COB RL model on prediction outcome
            self._train_cob_rl_on_prediction(signal, prediction_outcome)

            # 5. Train Decision Fusion model on prediction outcome
            self._train_decision_fusion_on_prediction(signal, prediction_outcome)

            logger.debug(f"Trained all models on {signal['action']} prediction with outcome: {prediction_outcome['accuracy']:.2f}")

        except Exception as e:
            logger.debug(f"Error training models on prediction: {e}")

    def _train_all_models_on_executed_signal(self, signal: Dict):
        """Train ALL models on executed trade signal with enhanced weight - Comprehensive training system"""
        try:
            # Get trade outcome for training
            trade_outcome = self._get_trade_outcome_for_training(signal)
            if not trade_outcome:
                return

            # 1. Train DQN model
            self._train_dqn_on_signal(signal, trade_outcome)
            # Enhanced training weight for executed signals (10x more important)
            enhanced_outcome = trade_outcome.copy()
            enhanced_outcome['training_weight'] = 10.0 # 10x weight for executed trades

            # 2. Train CNN model
            self._train_cnn_on_signal(signal, trade_outcome)
            # 1. Train DQN model with enhanced weight
            self._train_dqn_on_executed_signal(signal, enhanced_outcome)

            # 3. Train Transformer model
            self._train_transformer_on_signal(signal, trade_outcome)
            # 2. Train CNN model with enhanced weight
            self._train_cnn_on_executed_signal(signal, enhanced_outcome)

            # 4. Train COB RL model
            self._train_cob_rl_on_signal(signal, trade_outcome)
            # 3. Train Transformer model with enhanced weight
            self._train_transformer_on_executed_signal(signal, enhanced_outcome)

            # 5. Train Decision Fusion model
            self._train_decision_fusion_on_signal(signal, trade_outcome)
            # 4. Train COB RL model with enhanced weight
            self._train_cob_rl_on_executed_signal(signal, enhanced_outcome)

            logger.debug(f"Trained all models on {signal['action']} signal with outcome: {trade_outcome['pnl']:.2f}")
            # 5. Train Decision Fusion model with enhanced weight
            self._train_decision_fusion_on_executed_signal(signal, enhanced_outcome)

            logger.info(f"Enhanced training completed on {signal['action']} executed signal with outcome: {trade_outcome['pnl']:.2f}")

        except Exception as e:
            logger.debug(f"Error training models on signal: {e}")
            logger.debug(f"Error training models on executed signal: {e}")

    def _train_all_models_on_signal(self, signal: Dict):
        """Legacy method - now redirects to new training system"""
        self._train_all_models_on_prediction(signal)

    def _get_prediction_outcome_for_training(self, signal: Dict) -> Optional[Dict]:
        """Get prediction outcome based on immediate price movement validation"""
        try:
            symbol = signal.get('symbol', 'ETH/USDT')
            action = signal.get('action', 'HOLD')
            confidence = signal.get('confidence', 0.0)
            prediction_time = signal.get('timestamp', datetime.now())

            # Get current price to validate prediction
            current_price = self._get_current_price(symbol)
            if not current_price:
                return None

            # Get price at prediction time (or recent price if not available)
            prediction_price = signal.get('price', current_price)

            # Calculate immediate price movement (within 1-5 minutes)
            price_change = ((current_price - prediction_price) / prediction_price) * 100

            # Determine if prediction was accurate based on action and price movement
            prediction_accurate = False
            if action == 'BUY' and price_change > 0.1: # 0.1% positive movement
                prediction_accurate = True
            elif action == 'SELL' and price_change < -0.1: # 0.1% negative movement
                prediction_accurate = True
            elif action == 'HOLD' and abs(price_change) < 0.2: # Stable price
                prediction_accurate = True

            # Calculate accuracy score (0.0 to 1.0)
            accuracy_score = 0.5 # Base neutral score
            if prediction_accurate:
                accuracy_score = min(1.0, 0.5 + (confidence * 0.5)) # Higher confidence = higher score
            else:
                accuracy_score = max(0.0, 0.5 - (confidence * 0.5)) # Higher confidence = lower score for wrong predictions
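            # Worked example (illustrative): a BUY prediction with confidence 0.8 followed by a
            # +0.3% move is counted accurate, giving accuracy_score = 0.5 + 0.8 * 0.5 = 0.9;
            # the same prediction followed by a -0.2% move is counted wrong, giving
            # accuracy_score = 0.5 - 0.8 * 0.5 = 0.1.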

            return {
                'accuracy': accuracy_score,
                'price_change': price_change,
                'prediction_accurate': prediction_accurate,
                'confidence': confidence,
                'action': action,
                'prediction_time': prediction_time,
                'validation_time': datetime.now()
            }

        except Exception as e:
            logger.debug(f"Error getting prediction outcome: {e}")
            return None

    def _get_trade_outcome_for_training(self, signal: Dict) -> Optional[Dict]:
        """Get trade outcome for training - either from completed trade or position change"""
@@ -3213,8 +3302,8 @@ class CleanTradingDashboard:
            logger.debug(f"Error getting trade outcome: {e}")
            return None

    def _train_dqn_on_signal(self, signal: Dict, trade_outcome: Dict):
        """Train DQN agent on executed signal with trade outcome"""
    def _train_dqn_on_prediction(self, signal: Dict, prediction_outcome: Dict):
        """Train DQN agent on prediction outcome (every prediction, not just executed trades)"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
                return
@@ -3223,31 +3312,66 @@ class CleanTradingDashboard:
            state_features = self._get_dqn_state_features(signal.get('symbol', 'ETH/USDT'), signal.get('price', 0))
            action = 0 if signal['action'] == 'BUY' else 1 # 0=BUY, 1=SELL

            # Calculate reward based on trade outcome
            pnl = trade_outcome.get('pnl', 0)
            reward = pnl * 100 # Scale reward for better learning

            # Create next state (simplified)
            next_state_features = state_features.copy() # In real implementation, this would be the next market state
            # Calculate reward based on prediction accuracy
            accuracy = prediction_outcome.get('accuracy', 0.5)
            confidence = signal.get('confidence', 0.5)
            reward = (accuracy - 0.5) * 2.0 # Convert to [-1, 1] range

            # Store experience in DQN memory
            if hasattr(self.orchestrator.rl_agent, 'remember'):
                self.orchestrator.rl_agent.remember(
                    state_features, action, reward, next_state_features, done=True
                    state_features, action, reward, state_features, done=True
                )

            # Trigger training if enough samples
            if hasattr(self.orchestrator.rl_agent, 'memory') and len(self.orchestrator.rl_agent.memory) > 32:
                if hasattr(self.orchestrator.rl_agent, 'replay'):
                    loss = self.orchestrator.rl_agent.replay(batch_size=32)
                    loss = self.orchestrator.rl_agent.replay()
                    if loss is not None:
                        logger.debug(f"DQN trained on signal - loss: {loss:.4f}, reward: {reward:.2f}")
                        logger.debug(f"DQN trained on prediction - loss: {loss:.4f}, accuracy: {accuracy:.2f}")

        except Exception as e:
            logger.debug(f"Error training DQN on signal: {e}")
            logger.debug(f"Error training DQN on prediction: {e}")

    def _train_cnn_on_signal(self, signal: Dict, trade_outcome: Dict):
        """Train CNN model on executed signal with trade outcome"""
    def _train_dqn_on_executed_signal(self, signal: Dict, trade_outcome: Dict):
        """Train DQN agent on executed signal with enhanced weight"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
                return

            # Create training data for DQN
            state_features = self._get_dqn_state_features(signal.get('symbol', 'ETH/USDT'), signal.get('price', 0))
            action = 0 if signal['action'] == 'BUY' else 1 # 0=BUY, 1=SELL

            # Calculate enhanced reward based on trade outcome
            pnl = trade_outcome.get('pnl', 0)
            training_weight = trade_outcome.get('training_weight', 1.0)
            reward = pnl * 100 * training_weight # Enhanced reward for executed trades

            # Store experience in DQN memory with multiple entries for enhanced learning
            if hasattr(self.orchestrator.rl_agent, 'remember'):
                # Store multiple copies for enhanced learning
                for _ in range(int(training_weight)):
                    self.orchestrator.rl_agent.remember(
                        state_features, action, reward, state_features, done=True
                    )

            # Trigger training if enough samples
            if hasattr(self.orchestrator.rl_agent, 'memory') and len(self.orchestrator.rl_agent.memory) > 32:
                if hasattr(self.orchestrator.rl_agent, 'replay'):
                    loss = self.orchestrator.rl_agent.replay()
                    if loss is not None:
                        logger.info(f"DQN enhanced training on executed signal - loss: {loss:.4f}, reward: {reward:.2f}")

        except Exception as e:
            logger.debug(f"Error training DQN on executed signal: {e}")

    def _train_dqn_on_signal(self, signal: Dict, trade_outcome: Dict):
        """Legacy method - redirects to new training system"""
        self._train_dqn_on_prediction(signal, trade_outcome)

    def _train_cnn_on_prediction(self, signal: Dict, prediction_outcome: Dict):
        """Train CNN model on prediction outcome (every prediction, not just executed trades)"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
                return
@@ -3261,25 +3385,64 @@ class CleanTradingDashboard:
            if not market_features:
                return

            # Create target based on trade outcome
            pnl = trade_outcome.get('pnl', 0)
            target = 1.0 if pnl > 0 else 0.0 # Binary classification: profitable vs not
            # Create target based on prediction accuracy
            accuracy = prediction_outcome.get('accuracy', 0.5)
            target = accuracy # Use accuracy as target (0.0 to 1.0)

            # Prepare training data
            features = market_features.get('features', [])
            if features:
                # Convert to tensor format (simplified)
                import numpy as np
                feature_tensor = np.array(features, dtype=np.float32)
                target_tensor = np.array([target], dtype=np.float32)

                # Train CNN model (if it has training method)
                # Train CNN model
                if hasattr(self.orchestrator.cnn_model, 'train_on_batch'):
                    loss = self.orchestrator.cnn_model.train_on_batch(feature_tensor, target_tensor)
                    logger.debug(f"CNN trained on signal - loss: {loss:.4f}, target: {target}")
                    logger.debug(f"CNN trained on prediction - loss: {loss:.4f}, accuracy: {accuracy:.2f}")

        except Exception as e:
            logger.debug(f"Error training CNN on signal: {e}")
            logger.debug(f"Error training CNN on prediction: {e}")

    def _train_cnn_on_executed_signal(self, signal: Dict, trade_outcome: Dict):
        """Train CNN model on executed signal with enhanced weight"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
                return

            # Create training data for CNN
            symbol = signal.get('symbol', 'ETH/USDT')
            current_price = signal.get('price', 0)

            # Get market features
            market_features = self._get_cnn_features_and_predictions(symbol)
            if not market_features:
                return

            # Create target based on trade outcome with enhanced weight
            pnl = trade_outcome.get('pnl', 0)
            training_weight = trade_outcome.get('training_weight', 1.0)
            target = 1.0 if pnl > 0 else 0.0

            # Prepare training data
            features = market_features.get('features', [])
            if features:
                import numpy as np
                feature_tensor = np.array(features, dtype=np.float32)
                target_tensor = np.array([target], dtype=np.float32)

                # Train CNN model with multiple passes for enhanced learning
                if hasattr(self.orchestrator.cnn_model, 'train_on_batch'):
                    for _ in range(int(training_weight)):
                        loss = self.orchestrator.cnn_model.train_on_batch(feature_tensor, target_tensor)
                    logger.info(f"CNN enhanced training on executed signal - loss: {loss:.4f}, pnl: {pnl:.2f}")

        except Exception as e:
            logger.debug(f"Error training CNN on executed signal: {e}")

    def _train_cnn_on_signal(self, signal: Dict, trade_outcome: Dict):
        """Legacy method - redirects to new training system"""
        self._train_cnn_on_prediction(signal, trade_outcome)

    def _train_transformer_on_signal(self, signal: Dict, trade_outcome: Dict):
        """Train Transformer model on executed signal with trade outcome"""
@@ -3342,7 +3505,7 @@ class CleanTradingDashboard:
            # Trigger training if enough samples
            if hasattr(self.orchestrator.cob_rl_agent, 'memory') and len(self.orchestrator.cob_rl_agent.memory) > 32:
                if hasattr(self.orchestrator.cob_rl_agent, 'replay'):
                    loss = self.orchestrator.cob_rl_agent.replay(batch_size=32)
                    loss = self.orchestrator.cob_rl_agent.replay()
                    if loss is not None:
                        logger.debug(f"COB RL trained on signal - loss: {loss:.4f}, reward: {reward:.2f}")
@@ -3999,7 +4162,7 @@ class CleanTradingDashboard:
    # Cold start training moved to core.training_integration.TrainingIntegration

    def _clear_session(self):
        """Clear session data"""
        """Clear session data and persistent files"""
        try:
            # Reset session metrics
            self.session_pnl = 0.0
@@ -4016,11 +4179,96 @@ class CleanTradingDashboard:
            self.current_position = None
            self.pending_trade_case_id = None # Clear pending trade tracking

            logger.info("Session data cleared")
            # Clear persistent trade log files
            self._clear_trade_logs()

            # Clear orchestrator state if available
            if hasattr(self, 'orchestrator') and self.orchestrator:
                self._clear_orchestrator_state()

            logger.info("Session data and trade logs cleared")

        except Exception as e:
            logger.error(f"Error clearing session: {e}")

    def _clear_trade_logs(self):
        """Clear all trade log files"""
        try:
            import os
            import glob

            # Clear trade_logs directory
            trade_logs_dir = "trade_logs"
            if os.path.exists(trade_logs_dir):
                # Remove all CSV files in trade_logs
                csv_files = glob.glob(os.path.join(trade_logs_dir, "*.csv"))
                for file in csv_files:
                    try:
                        os.remove(file)
                        logger.info(f"Deleted trade log: {file}")
                    except Exception as e:
                        logger.warning(f"Failed to delete {file}: {e}")

                # Remove any .log files in trade_logs
                log_files = glob.glob(os.path.join(trade_logs_dir, "*.log"))
                for file in log_files:
                    try:
                        os.remove(file)
                        logger.info(f"Deleted trade log: {file}")
                    except Exception as e:
                        logger.warning(f"Failed to delete {file}: {e}")

            # Clear recent log files in logs directory
            logs_dir = "logs"
            if os.path.exists(logs_dir):
                # Remove recent trading logs (keep older system logs)
                recent_logs = [
                    "enhanced_trading.log",
                    "realtime_rl_cob_trader.log",
                    "simple_cob_dashboard.log",
                    "integrated_rl_cob_system.log",
                    "optimized_cob_system.log"
                ]

                for log_file in recent_logs:
                    log_path = os.path.join(logs_dir, log_file)
                    if os.path.exists(log_path):
                        try:
                            # Truncate the file instead of deleting to preserve file handles
                            with open(log_path, 'w') as f:
                                f.write("") # Clear file content
                            logger.info(f"Cleared log file: {log_path}")
                        except Exception as e:
                            logger.warning(f"Failed to clear {log_path}: {e}")

            logger.info("Trade logs cleared successfully")

        except Exception as e:
            logger.error(f"Error clearing trade logs: {e}")

    def _clear_orchestrator_state(self):
        """Clear orchestrator state and recent predictions"""
        try:
            if hasattr(self.orchestrator, 'recent_decisions'):
                self.orchestrator.recent_decisions = {}

            if hasattr(self.orchestrator, 'recent_dqn_predictions'):
                for symbol in self.orchestrator.recent_dqn_predictions:
                    self.orchestrator.recent_dqn_predictions[symbol].clear()

            if hasattr(self.orchestrator, 'recent_cnn_predictions'):
                for symbol in self.orchestrator.recent_cnn_predictions:
                    self.orchestrator.recent_cnn_predictions[symbol].clear()

            if hasattr(self.orchestrator, 'prediction_accuracy_history'):
                for symbol in self.orchestrator.prediction_accuracy_history:
                    self.orchestrator.prediction_accuracy_history[symbol].clear()

            logger.info("Orchestrator state cleared")

        except Exception as e:
            logger.error(f"Error clearing orchestrator state: {e}")

    def _store_all_models(self) -> bool:
        """Store all current models to persistent storage"""
        try:
@@ -6112,7 +6360,7 @@ class CleanTradingDashboard:
            # Perform training step if agent has replay method
            if hasattr(cob_rl_agent, 'replay') and hasattr(cob_rl_agent, 'memory'):
                if len(cob_rl_agent.memory) > 32: # Enough samples to train
                    loss = cob_rl_agent.replay(batch_size=min(32, len(cob_rl_agent.memory)))
                    loss = cob_rl_agent.replay()
                    if loss is not None:
                        total_loss += loss
                        loss_count += 1