Compare commits
3 Commits
6ca19f4536
...
kiro-v
Author | SHA1 | Date | |
---|---|---|---|
fde370fa1b | |||
14086a898e | |||
36f429a0e2 |
@ -1872,32 +1872,67 @@ class EnhancedRealtimeTrainingSystem:
|
||||
def _log_training_progress(self):
|
||||
"""Log comprehensive training progress"""
|
||||
try:
|
||||
stats = {
|
||||
'iteration': self.training_iteration,
|
||||
'experience_buffer': len(self.experience_buffer),
|
||||
'priority_buffer': len(self.priority_buffer),
|
||||
'dqn_memory': self._get_dqn_memory_size(),
|
||||
'data_streams': {
|
||||
'ohlcv_1m': len(self.real_time_data['ohlcv_1m']),
|
||||
'ticks': len(self.real_time_data['ticks']),
|
||||
'cob_snapshots': len(self.real_time_data['cob_snapshots']),
|
||||
'market_events': len(self.real_time_data['market_events'])
|
||||
}
|
||||
}
|
||||
logger.info("=" * 60)
|
||||
logger.info("ENHANCED TRAINING SYSTEM PROGRESS REPORT")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Basic training statistics
|
||||
logger.info(f"Training Iteration: {self.training_iteration}")
|
||||
logger.info(f"Experience Buffer: {len(self.experience_buffer)} samples")
|
||||
logger.info(f"Priority Buffer: {len(self.priority_buffer)} samples")
|
||||
logger.info(f"DQN Memory: {self._get_dqn_memory_size()} experiences")
|
||||
|
||||
# Data stream statistics
|
||||
logger.info("\nDATA STREAMS:")
|
||||
logger.info(f" OHLCV 1m: {len(self.real_time_data['ohlcv_1m'])} records")
|
||||
logger.info(f" Ticks: {len(self.real_time_data['ticks'])} records")
|
||||
logger.info(f" COB Snapshots: {len(self.real_time_data['cob_snapshots'])} records")
|
||||
logger.info(f" Market Events: {len(self.real_time_data['market_events'])} records")
|
||||
|
||||
# Performance metrics
|
||||
logger.info("\nPERFORMANCE METRICS:")
|
||||
if self.performance_history['dqn_losses']:
|
||||
stats['dqn_avg_loss'] = np.mean(list(self.performance_history['dqn_losses'])[-10:])
|
||||
dqn_avg_loss = np.mean(list(self.performance_history['dqn_losses'])[-10:])
|
||||
dqn_recent_loss = list(self.performance_history['dqn_losses'])[-1] if self.performance_history['dqn_losses'] else 0
|
||||
logger.info(f" DQN Average Loss (10): {dqn_avg_loss:.4f}")
|
||||
logger.info(f" DQN Recent Loss: {dqn_recent_loss:.4f}")
|
||||
|
||||
if self.performance_history['cnn_losses']:
|
||||
stats['cnn_avg_loss'] = np.mean(list(self.performance_history['cnn_losses'])[-10:])
|
||||
cnn_avg_loss = np.mean(list(self.performance_history['cnn_losses'])[-10:])
|
||||
cnn_recent_loss = list(self.performance_history['cnn_losses'])[-1] if self.performance_history['cnn_losses'] else 0
|
||||
logger.info(f" CNN Average Loss (10): {cnn_avg_loss:.4f}")
|
||||
logger.info(f" CNN Recent Loss: {cnn_recent_loss:.4f}")
|
||||
|
||||
if self.performance_history['validation_scores']:
|
||||
stats['validation_score'] = self.performance_history['validation_scores'][-1]['combined_score']
|
||||
validation_score = self.performance_history['validation_scores'][-1]['combined_score']
|
||||
logger.info(f" Validation Score: {validation_score:.3f}")
|
||||
|
||||
logger.info(f"ENHANCED TRAINING PROGRESS: {stats}")
|
||||
# Training configuration
|
||||
logger.info("\nTRAINING CONFIGURATION:")
|
||||
logger.info(f" DQN Training Interval: {self.training_config['dqn_training_interval']} iterations")
|
||||
logger.info(f" CNN Training Interval: {self.training_config['cnn_training_interval']} iterations")
|
||||
logger.info(f" COB RL Training Interval: {self.training_config['cob_rl_training_interval']} iterations")
|
||||
logger.info(f" Validation Interval: {self.training_config['validation_interval']} iterations")
|
||||
|
||||
# Prediction statistics
|
||||
if hasattr(self, 'prediction_history') and self.prediction_history:
|
||||
logger.info("\nPREDICTION STATISTICS:")
|
||||
recent_predictions = list(self.prediction_history)[-10:] if len(self.prediction_history) > 10 else list(self.prediction_history)
|
||||
logger.info(f" Recent Predictions: {len(recent_predictions)}")
|
||||
if recent_predictions:
|
||||
avg_confidence = np.mean([p.get('confidence', 0) for p in recent_predictions])
|
||||
logger.info(f" Average Confidence: {avg_confidence:.3f}")
|
||||
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Periodic comprehensive logging (every 20th iteration)
|
||||
if self.training_iteration % 20 == 0:
|
||||
logger.info("PERIODIC ENHANCED TRAINING COMPREHENSIVE LOG:")
|
||||
if hasattr(self.orchestrator, 'log_model_statistics'):
|
||||
self.orchestrator.log_model_statistics(detailed=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error logging progress: {e}")
|
||||
logger.error(f"Error logging enhanced training progress: {e}")
|
||||
|
||||
def _validation_worker(self):
|
||||
"""Background worker for continuous validation"""
|
||||
|
@ -58,6 +58,7 @@ from core.extrema_trainer import (
|
||||
from utils.inference_logger import get_inference_logger, log_model_inference
|
||||
from utils.database_manager import get_database_manager
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
from safe_logging import setup_training_logger
|
||||
|
||||
# Import COB integration for real-time market microstructure data
|
||||
try:
|
||||
@ -300,6 +301,9 @@ class TradingOrchestrator:
|
||||
|
||||
logger.info(f"Using device: {self.device}")
|
||||
|
||||
# Initialize training logger
|
||||
self.training_logger = setup_training_logger()
|
||||
|
||||
# Configuration - AGGRESSIVE for more training data
|
||||
self.confidence_threshold = self.config.orchestrator.get(
|
||||
"confidence_threshold", 0.15
|
||||
@ -536,7 +540,6 @@ class TradingOrchestrator:
|
||||
self._initialize_decision_fusion() # Initialize fusion system
|
||||
self._initialize_transformer_model() # Initialize transformer model
|
||||
self._initialize_enhanced_training_system() # Initialize real-time training
|
||||
|
||||
def _initialize_ml_models(self):
|
||||
"""Initialize ML models for enhanced trading"""
|
||||
try:
|
||||
@ -1309,7 +1312,6 @@ class TradingOrchestrator:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error clearing orchestrator session data: {e}")
|
||||
|
||||
def sync_model_states_with_dashboard(self):
|
||||
"""Sync model states with current dashboard values"""
|
||||
# Update based on the dashboard stats provided
|
||||
@ -1412,13 +1414,20 @@ class TradingOrchestrator:
|
||||
ui_state = json.load(f)
|
||||
if "model_toggle_states" in ui_state:
|
||||
self.model_toggle_states.update(ui_state["model_toggle_states"])
|
||||
# Validate and clean the loaded states
|
||||
self._validate_model_toggle_states()
|
||||
logger.info(f"UI state loaded from {self.ui_state_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading UI state: {e}")
|
||||
# If loading fails, ensure we have valid default states
|
||||
self._validate_model_toggle_states()
|
||||
|
||||
def _save_ui_state(self):
|
||||
"""Save UI state to file"""
|
||||
try:
|
||||
# Validate and clean model toggle states before saving
|
||||
self._validate_model_toggle_states()
|
||||
|
||||
os.makedirs(os.path.dirname(self.ui_state_file), exist_ok=True)
|
||||
ui_state = {
|
||||
"model_toggle_states": self.model_toggle_states,
|
||||
@ -1429,6 +1438,36 @@ class TradingOrchestrator:
|
||||
logger.debug(f"UI state saved to {self.ui_state_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving UI state: {e}")
|
||||
|
||||
def _validate_model_toggle_states(self):
|
||||
"""Validate and clean model toggle states to ensure proper boolean values"""
|
||||
try:
|
||||
for model_name, toggle_state in self.model_toggle_states.items():
|
||||
if not isinstance(toggle_state, dict):
|
||||
logger.warning(f"Invalid toggle state for {model_name}, resetting to defaults")
|
||||
self.model_toggle_states[model_name] = {"inference_enabled": True, "training_enabled": True}
|
||||
continue
|
||||
|
||||
# Ensure inference_enabled is boolean
|
||||
if "inference_enabled" in toggle_state:
|
||||
if not isinstance(toggle_state["inference_enabled"], bool):
|
||||
logger.warning(f"Invalid inference_enabled value for {model_name}: {toggle_state['inference_enabled']}, setting to True")
|
||||
toggle_state["inference_enabled"] = True
|
||||
|
||||
# Ensure training_enabled is boolean
|
||||
if "training_enabled" in toggle_state:
|
||||
if not isinstance(toggle_state["training_enabled"], bool):
|
||||
logger.warning(f"Invalid training_enabled value for {model_name}: {toggle_state['training_enabled']}, setting to True")
|
||||
toggle_state["training_enabled"] = True
|
||||
|
||||
# Ensure both keys exist
|
||||
if "inference_enabled" not in toggle_state:
|
||||
toggle_state["inference_enabled"] = True
|
||||
if "training_enabled" not in toggle_state:
|
||||
toggle_state["training_enabled"] = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating model toggle states: {e}")
|
||||
|
||||
def get_model_toggle_state(self, model_name: str) -> Dict[str, bool]:
|
||||
"""Get toggle state for a model"""
|
||||
@ -2055,7 +2094,6 @@ class TradingOrchestrator:
|
||||
f"Decision callback registered: {callback.__name__ if hasattr(callback, '__name__') else 'unnamed'}"
|
||||
)
|
||||
return True
|
||||
|
||||
async def make_trading_decision(self, symbol: str) -> Optional[TradingDecision]:
|
||||
"""
|
||||
Make a trading decision for a symbol by combining all registered model outputs
|
||||
@ -2248,7 +2286,7 @@ class TradingOrchestrator:
|
||||
predicted_price_vector = prediction.metadata['price_direction']
|
||||
|
||||
# Calculate sophisticated reward using the new PnL penalty/reward system
|
||||
sophisticated_reward, was_correct = self._calculate_sophisticated_reward(
|
||||
sophisticated_reward, was_correct, should_skip = self._calculate_sophisticated_reward(
|
||||
predicted_action=prediction.action,
|
||||
prediction_confidence=prediction.confidence,
|
||||
price_change_pct=price_change_pct,
|
||||
@ -2260,6 +2298,11 @@ class TradingOrchestrator:
|
||||
predicted_price_vector=predicted_price_vector
|
||||
)
|
||||
|
||||
# Skip training if this is a neutral action (no position + HOLD)
|
||||
if should_skip:
|
||||
logger.debug(f"Skipping training for neutral action: {prediction.action} (no position)")
|
||||
continue
|
||||
|
||||
# Create training record for the new training system
|
||||
training_record = {
|
||||
"symbol": symbol,
|
||||
@ -2668,6 +2711,59 @@ class TradingOrchestrator:
|
||||
return {}
|
||||
|
||||
def log_model_statistics(self, detailed: bool = False):
|
||||
"""Log comprehensive model statistics and performance metrics"""
|
||||
try:
|
||||
self.training_logger.info("=" * 80)
|
||||
self.training_logger.info("COMPREHENSIVE MODEL PERFORMANCE SUMMARY")
|
||||
self.training_logger.info("=" * 80)
|
||||
|
||||
# Log overall system performance
|
||||
if hasattr(self, 'model_performance'):
|
||||
self.training_logger.info("OVERALL MODEL PERFORMANCE:")
|
||||
for model_name, perf in self.model_performance.items():
|
||||
accuracy = perf.get('accuracy', 0)
|
||||
total = perf.get('total', 0)
|
||||
correct = perf.get('correct', 0)
|
||||
self.training_logger.info(f" {model_name.upper()}: {accuracy:.1%} ({correct}/{total})")
|
||||
|
||||
# Log detailed model statistics
|
||||
if hasattr(self, 'model_statistics'):
|
||||
self.training_logger.info("\nDETAILED MODEL STATISTICS:")
|
||||
for model_name, stats in self.model_statistics.items():
|
||||
self.training_logger.info(f" {model_name.upper()}:")
|
||||
self.training_logger.info(f" Inferences: {stats.total_inferences}")
|
||||
self.training_logger.info(f" Trainings: {stats.total_trainings}")
|
||||
self.training_logger.info(f" Current loss: {stats.current_loss:.4f}" if stats.current_loss else " Current loss: N/A")
|
||||
self.training_logger.info(f" Best loss: {stats.best_loss:.4f}" if stats.best_loss else " Best loss: N/A")
|
||||
self.training_logger.info(f" Average loss: {stats.average_loss:.4f}" if stats.average_loss else " Average loss: N/A")
|
||||
self.training_logger.info(f" Inference rate: {stats.inference_rate_per_minute:.1f}/min")
|
||||
self.training_logger.info(f" Training rate: {stats.training_rate_per_minute:.1f}/min")
|
||||
self.training_logger.info(f" Avg inference time: {stats.average_inference_time_ms:.1f}ms")
|
||||
self.training_logger.info(f" Avg training time: {stats.average_training_time_ms:.1f}ms")
|
||||
|
||||
# Log decision fusion performance
|
||||
if hasattr(self, 'decision_fusion_enabled') and self.decision_fusion_enabled:
|
||||
self.training_logger.info("\nDECISION FUSION PERFORMANCE:")
|
||||
self.training_logger.info(f" Mode: {getattr(self, 'decision_fusion_mode', 'unknown')}")
|
||||
self.training_logger.info(f" Decisions made: {getattr(self, 'decision_fusion_decisions_count', 0)}")
|
||||
self.training_logger.info(f" Training samples: {len(getattr(self, 'decision_fusion_training_data', []))}")
|
||||
|
||||
# Log enhanced training system status
|
||||
if hasattr(self, 'enhanced_training_system'):
|
||||
self.training_logger.info("\nENHANCED TRAINING SYSTEM:")
|
||||
if self.enhanced_training_system:
|
||||
stats = self.enhanced_training_system.get_training_statistics()
|
||||
self.training_logger.info(f" Status: {'ACTIVE' if stats.get('is_training', False) else 'INACTIVE'}")
|
||||
self.training_logger.info(f" Status: {'ACTIVE' if stats.get('is_training', False) else 'INACTIVE'}")
|
||||
self.training_logger.info(f" Iteration: {stats.get('training_iteration', 0)}")
|
||||
self.training_logger.info(f" Experience buffer: {stats.get('experience_buffer_size', 0)}")
|
||||
else:
|
||||
self.training_logger.info(" Status: NOT INITIALIZED")
|
||||
|
||||
self.training_logger.info("=" * 80)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging comprehensive statistics: {e}")
|
||||
"""Log current model statistics for monitoring"""
|
||||
try:
|
||||
if not self.model_statistics:
|
||||
@ -3283,7 +3379,7 @@ class TradingOrchestrator:
|
||||
price_outcome = f"Inference: ${inference_price:.2f} ({time_diff_seconds:.1f}s ago) -> Current: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||
else:
|
||||
# For older predictions, use a more conservative approach
|
||||
price_outcome = f"Inference: ${inference_price:.2f} ({time_diff_seconds:.1f}s ago) -> Current: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||
price_outcome = f"Inference: ${inference_price:.2f} ({time_diff_seconds/60:.1f}m ago) -> Current: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||
else:
|
||||
# Fall back to historical price comparison if no inference price
|
||||
try:
|
||||
@ -3367,7 +3463,7 @@ class TradingOrchestrator:
|
||||
if "price_direction" in prediction and prediction["price_direction"]:
|
||||
predicted_price_vector = prediction["price_direction"]
|
||||
|
||||
reward, _ = self._calculate_sophisticated_reward(
|
||||
reward, _, should_skip = self._calculate_sophisticated_reward(
|
||||
predicted_action,
|
||||
predicted_confidence,
|
||||
actual_price_change_pct,
|
||||
@ -3465,7 +3561,6 @@ class TradingOrchestrator:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in CNN long-term training: {e}")
|
||||
|
||||
async def _evaluate_and_train_on_record(self, record: Dict, current_price: float):
|
||||
"""Evaluate prediction outcome and train model"""
|
||||
try:
|
||||
@ -3528,7 +3623,7 @@ class TradingOrchestrator:
|
||||
if "price_direction" in prediction and prediction["price_direction"]:
|
||||
predicted_price_vector = prediction["price_direction"]
|
||||
|
||||
reward, was_correct = self._calculate_sophisticated_reward(
|
||||
reward, was_correct, should_skip = self._calculate_sophisticated_reward(
|
||||
predicted_action,
|
||||
prediction_confidence,
|
||||
price_change_pct,
|
||||
@ -3540,6 +3635,11 @@ class TradingOrchestrator:
|
||||
predicted_price_vector=predicted_price_vector,
|
||||
)
|
||||
|
||||
# Skip training and accuracy tracking if this is a neutral action (no position + HOLD)
|
||||
if should_skip:
|
||||
logger.debug(f"Skipping training and accuracy tracking for neutral action: {predicted_action} (no position)")
|
||||
return
|
||||
|
||||
# Update model performance tracking
|
||||
if model_name not in self.model_performance:
|
||||
self.model_performance[model_name] = {
|
||||
@ -3592,17 +3692,29 @@ class TradingOrchestrator:
|
||||
)
|
||||
|
||||
# Enhanced logging for training evaluation
|
||||
logger.info(f"Training evaluation for {model_name}:")
|
||||
logger.info(
|
||||
self.training_logger.info(f"TRAINING EVALUATION for {model_name.upper()}:")
|
||||
self.training_logger.info(
|
||||
f" Action: {predicted_action} | Confidence: {prediction_confidence:.3f}"
|
||||
)
|
||||
logger.info(
|
||||
self.training_logger.info(
|
||||
f" Price change: {price_change_pct:+.3f}% | Time: {time_diff_seconds:.1f}s"
|
||||
)
|
||||
logger.info(f" Reward: {reward:.4f} | Correct: {was_correct}")
|
||||
logger.info(
|
||||
self.training_logger.info(f" Reward: {reward:.4f} | Correct: {was_correct}")
|
||||
self.training_logger.info(
|
||||
f" Accuracy: {self.model_performance[model_name]['accuracy']:.1%} ({self.model_performance[model_name]['correct']}/{self.model_performance[model_name]['total']})"
|
||||
)
|
||||
|
||||
# Add detailed performance metrics logging
|
||||
if hasattr(self, 'model_statistics') and model_name in self.model_statistics:
|
||||
stats = self.model_statistics[model_name]
|
||||
self.training_logger.info(f" Model Statistics:")
|
||||
self.training_logger.info(f" Total inferences: {stats.total_inferences}")
|
||||
self.training_logger.info(f" Total trainings: {stats.total_trainings}")
|
||||
self.training_logger.info(f" Current loss: {stats.current_loss:.4f}" if stats.current_loss else " Current loss: N/A")
|
||||
self.training_logger.info(f" Best loss: {stats.best_loss:.4f}" if stats.best_loss else " Best loss: N/A")
|
||||
self.training_logger.info(f" Average loss: {stats.average_loss:.4f}" if stats.average_loss else " Average loss: N/A")
|
||||
self.training_logger.info(f" Inference rate: {stats.inference_rate_per_minute:.1f}/min")
|
||||
self.training_logger.info(f" Training rate: {stats.training_rate_per_minute:.1f}/min")
|
||||
|
||||
# Train the specific model based on sophisticated outcome
|
||||
await self._train_model_on_outcome(
|
||||
@ -3647,7 +3759,7 @@ class TradingOrchestrator:
|
||||
has_position: bool = None,
|
||||
current_position_pnl: float = 0.0,
|
||||
predicted_price_vector: dict = None,
|
||||
) -> tuple[float, bool]:
|
||||
) -> tuple[float, bool, bool]:
|
||||
"""
|
||||
Calculate sophisticated reward based on prediction accuracy, confidence, and price movement magnitude
|
||||
Now considers position status and current P&L when evaluating decisions
|
||||
@ -3666,7 +3778,10 @@ class TradingOrchestrator:
|
||||
predicted_price_vector: Dict with 'direction' (-1 to 1) and 'confidence' (0 to 1)
|
||||
|
||||
Returns:
|
||||
tuple: (reward, was_correct)
|
||||
tuple: (reward, was_correct, should_skip)
|
||||
- reward: The calculated reward value
|
||||
- was_correct: Whether the prediction was correct (True/False)
|
||||
- should_skip: Whether this should be skipped from accuracy calculations and training (True/False)
|
||||
"""
|
||||
try:
|
||||
# NOISE REDUCTION: Treat low-confidence signals as HOLD
|
||||
@ -3677,10 +3792,11 @@ class TradingOrchestrator:
|
||||
|
||||
# FEE-AWARE THRESHOLDS: Account for trading fees (0.05-0.06% per trade, ~0.12% round trip)
|
||||
fee_cost = 0.12 # 0.12% round trip fee cost
|
||||
movement_threshold = 0.15 # Minimum movement to be profitable after fees
|
||||
strong_movement_threshold = 0.5 # Strong movements - good profit potential
|
||||
rapid_movement_threshold = 1.0 # Rapid movements - excellent profit potential
|
||||
massive_movement_threshold = 2.0 # Massive movements - extraordinary profit potential
|
||||
pnl_threshold = 0.02 # 0.02% - minimum movement to include in PnL/accuracy calculations
|
||||
movement_threshold = 0.20 # Minimum movement to be profitable after fees (increased from 0.15%)
|
||||
strong_movement_threshold = 0.8 # Strong movements - good profit potential (increased from 0.5%)
|
||||
rapid_movement_threshold = 1.5 # Rapid movements - excellent profit potential (increased from 1.0%)
|
||||
massive_movement_threshold = 3.0 # Massive movements - extraordinary profit potential (increased from 2.0%)
|
||||
|
||||
# Determine current position status if not provided
|
||||
if has_position is None and symbol:
|
||||
@ -3693,29 +3809,48 @@ class TradingOrchestrator:
|
||||
|
||||
# Determine if prediction was directionally correct
|
||||
was_correct = False
|
||||
should_skip = False # Whether to skip from accuracy calculations and training
|
||||
directional_accuracy = 0.0
|
||||
|
||||
# Check if price movement is significant enough for PnL/accuracy calculation
|
||||
abs_price_change = abs(price_change_pct)
|
||||
include_in_accuracy = abs_price_change >= pnl_threshold
|
||||
|
||||
# Always check directional correctness for learning, but only include significant moves in accuracy
|
||||
direction_correct = False
|
||||
|
||||
if predicted_action == "BUY":
|
||||
# BUY signals need to overcome fee costs for profitability
|
||||
was_correct = price_change_pct > movement_threshold
|
||||
# Check directional correctness (always for learning)
|
||||
direction_correct = price_change_pct > 0
|
||||
|
||||
# ENHANCED FEE-AWARE REWARD STRUCTURE
|
||||
if price_change_pct > massive_movement_threshold:
|
||||
# Massive movements (2%+) - EXTRAORDINARY rewards for high confidence
|
||||
directional_accuracy = price_change_pct * 5.0 # 5x multiplier for massive moves
|
||||
if prediction_confidence > 0.8:
|
||||
directional_accuracy *= 2.0 # Additional 2x for high confidence (10x total)
|
||||
elif price_change_pct > rapid_movement_threshold:
|
||||
# Rapid movements (1%+) - EXCELLENT rewards for high confidence
|
||||
directional_accuracy = price_change_pct * 3.0 # 3x multiplier for rapid moves
|
||||
if prediction_confidence > 0.7:
|
||||
directional_accuracy *= 1.5 # Additional 1.5x for good confidence (4.5x total)
|
||||
elif price_change_pct > strong_movement_threshold:
|
||||
# Strong movements (0.5%+) - GOOD rewards
|
||||
directional_accuracy = price_change_pct * 2.0 # 2x multiplier for strong moves
|
||||
# Only consider "correct" for accuracy if movement is significant AND profitable
|
||||
if include_in_accuracy:
|
||||
was_correct = price_change_pct > movement_threshold
|
||||
else:
|
||||
# Small movements - minimal rewards (fees eat most profit)
|
||||
directional_accuracy = max(0, (price_change_pct - fee_cost)) * 0.5 # Penalty for fee cost
|
||||
# Small movement - learn direction but don't include in accuracy
|
||||
was_correct = None # Exclude from accuracy calculation
|
||||
|
||||
# ENHANCED FEE-AWARE REWARD STRUCTURE (only for significant movements)
|
||||
if include_in_accuracy:
|
||||
if price_change_pct > massive_movement_threshold:
|
||||
# Massive movements (2%+) - EXTRAORDINARY rewards for high confidence
|
||||
directional_accuracy = price_change_pct * 5.0 # 5x multiplier for massive moves
|
||||
if prediction_confidence > 0.8:
|
||||
directional_accuracy *= 2.0 # Additional 2x for high confidence (10x total)
|
||||
elif price_change_pct > rapid_movement_threshold:
|
||||
# Rapid movements (1%+) - EXCELLENT rewards for high confidence
|
||||
directional_accuracy = price_change_pct * 3.0 # 3x multiplier for rapid moves
|
||||
if prediction_confidence > 0.7:
|
||||
directional_accuracy *= 1.5 # Additional 1.5x for good confidence (4.5x total)
|
||||
elif price_change_pct > strong_movement_threshold:
|
||||
# Strong movements (0.5%+) - GOOD rewards
|
||||
directional_accuracy = price_change_pct * 2.0 # 2x multiplier for strong moves
|
||||
else:
|
||||
# Small but significant movements - minimal rewards (fees eat most profit)
|
||||
directional_accuracy = max(0, (price_change_pct - fee_cost)) * 0.5 # Penalty for fee cost
|
||||
else:
|
||||
# Very small movement - learn direction but minimal reward
|
||||
directional_accuracy = price_change_pct * 0.1 if direction_correct else -abs(price_change_pct) * 0.1
|
||||
|
||||
elif predicted_action == "SELL":
|
||||
# SELL signals need to overcome fee costs for profitability
|
||||
@ -3741,41 +3876,86 @@ class TradingOrchestrator:
|
||||
directional_accuracy = max(0, (abs_change - fee_cost)) * 0.5 # Penalty for fee cost
|
||||
|
||||
elif predicted_action == "HOLD":
|
||||
# HOLD evaluation with noise reduction - smaller rewards to reduce training noise
|
||||
# HOLD evaluation with position side awareness - considers LONG vs SHORT positions
|
||||
if has_position:
|
||||
# If we have a position, HOLD evaluation depends on P&L and price movement
|
||||
# Get position side to properly evaluate HOLD decisions
|
||||
position_side = self._get_position_side(symbol) if symbol else "LONG"
|
||||
|
||||
if current_position_pnl > 0: # Currently profitable position
|
||||
# Holding a profitable position is good if price continues favorably
|
||||
if price_change_pct > 0: # Price went up while holding profitable position - excellent
|
||||
was_correct = True
|
||||
directional_accuracy = price_change_pct * 0.8 # Reduced from 1.5 to reduce noise
|
||||
elif abs(price_change_pct) < movement_threshold: # Price stable - good
|
||||
was_correct = True
|
||||
directional_accuracy = movement_threshold * 0.5 # Reduced reward to reduce noise
|
||||
else: # Price dropped while holding profitable position - still okay but less reward
|
||||
was_correct = True
|
||||
directional_accuracy = max(0, (current_position_pnl / 100.0) - abs(price_change_pct) * 0.3)
|
||||
if position_side == "LONG":
|
||||
# For LONG positions: HOLD is good if price goes up or stays stable
|
||||
if price_change_pct > 0: # Price went up - excellent hold
|
||||
was_correct = True
|
||||
directional_accuracy = price_change_pct * 0.8
|
||||
elif abs(price_change_pct) < movement_threshold: # Price stable - good hold
|
||||
was_correct = True
|
||||
directional_accuracy = movement_threshold * 0.5
|
||||
else: # Price dropped - still okay but less reward
|
||||
was_correct = True
|
||||
directional_accuracy = max(0, (current_position_pnl / 100.0) - abs(price_change_pct) * 0.3)
|
||||
elif position_side == "SHORT":
|
||||
# For SHORT positions: HOLD is good if price goes down or stays stable
|
||||
if price_change_pct < 0: # Price went down - excellent hold
|
||||
was_correct = True
|
||||
directional_accuracy = abs(price_change_pct) * 0.8
|
||||
elif abs(price_change_pct) < movement_threshold: # Price stable - good hold
|
||||
was_correct = True
|
||||
directional_accuracy = movement_threshold * 0.5
|
||||
else: # Price went up - still okay but less reward
|
||||
was_correct = True
|
||||
directional_accuracy = max(0, (current_position_pnl / 100.0) - abs(price_change_pct) * 0.3)
|
||||
else:
|
||||
# Unknown position side - fallback to general logic
|
||||
if abs(price_change_pct) < movement_threshold:
|
||||
was_correct = True
|
||||
directional_accuracy = movement_threshold * 0.4
|
||||
else:
|
||||
was_correct = False
|
||||
directional_accuracy = max(0, movement_threshold - abs(price_change_pct)) * 0.5
|
||||
|
||||
elif current_position_pnl < 0: # Currently losing position
|
||||
# Holding a losing position is generally bad - should consider closing
|
||||
if price_change_pct > movement_threshold: # Price recovered - good hold
|
||||
was_correct = True
|
||||
directional_accuracy = price_change_pct * 0.6 # Reduced reward
|
||||
else: # Price continued down or stayed flat - bad hold
|
||||
was_correct = False
|
||||
# Penalty proportional to loss magnitude
|
||||
directional_accuracy = abs(current_position_pnl / 100.0) * 0.3 # Reduced penalty
|
||||
if position_side == "LONG":
|
||||
# For LONG positions: HOLD is good if price recovers (goes up)
|
||||
if price_change_pct > movement_threshold: # Price recovered - good hold
|
||||
was_correct = True
|
||||
directional_accuracy = price_change_pct * 0.6
|
||||
else: # Price continued down or stayed flat - bad hold
|
||||
was_correct = False
|
||||
directional_accuracy = abs(current_position_pnl / 100.0) * 0.3
|
||||
elif position_side == "SHORT":
|
||||
# For SHORT positions: HOLD is good if price recovers (goes down)
|
||||
if price_change_pct < -movement_threshold: # Price recovered - good hold
|
||||
was_correct = True
|
||||
directional_accuracy = abs(price_change_pct) * 0.6
|
||||
else: # Price continued up or stayed flat - bad hold
|
||||
was_correct = False
|
||||
directional_accuracy = abs(current_position_pnl / 100.0) * 0.3
|
||||
else:
|
||||
# Unknown position side - fallback to general logic
|
||||
if abs(price_change_pct) > movement_threshold:
|
||||
was_correct = True
|
||||
directional_accuracy = abs(price_change_pct) * 0.6
|
||||
else:
|
||||
was_correct = False
|
||||
directional_accuracy = abs(current_position_pnl / 100.0) * 0.3
|
||||
|
||||
else: # Breakeven position
|
||||
# Standard HOLD evaluation for breakeven positions
|
||||
if abs(price_change_pct) < movement_threshold: # Price stable - good
|
||||
was_correct = True
|
||||
directional_accuracy = movement_threshold * 0.4 # Reduced reward
|
||||
directional_accuracy = movement_threshold * 0.4
|
||||
else: # Price moved significantly - missed opportunity
|
||||
was_correct = False
|
||||
directional_accuracy = max(0, movement_threshold - abs(price_change_pct)) * 0.5
|
||||
else:
|
||||
# If we don't have a position, HOLD is correct if price stayed relatively stable
|
||||
was_correct = abs(price_change_pct) < movement_threshold
|
||||
directional_accuracy = max(0, movement_threshold - abs(price_change_pct)) * 0.4 # Reduced reward
|
||||
# If we don't have a position, HOLD should be skipped from accuracy calculations and training
|
||||
# No position + HOLD = NEUTRAL (no action taken, no profit/loss)
|
||||
was_correct = None # Not applicable
|
||||
should_skip = True # Skip from accuracy calculations and training
|
||||
directional_accuracy = 0.0 # No reward/penalty for neutral action
|
||||
# Force reward to 0.0 for NEUTRAL actions
|
||||
final_reward = 0.0
|
||||
return final_reward, was_correct, should_skip
|
||||
|
||||
# Calculate FEE-AWARE magnitude-based multiplier (aggressive rewards for profitable movements)
|
||||
abs_movement = abs(price_change_pct)
|
||||
@ -3877,7 +4057,7 @@ class TradingOrchestrator:
|
||||
# Clamp reward to reasonable range
|
||||
final_reward = max(-5.0, min(5.0, final_reward))
|
||||
|
||||
return final_reward, was_correct
|
||||
return final_reward, was_correct, should_skip
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating sophisticated reward: {e}")
|
||||
@ -3887,14 +4067,23 @@ class TradingOrchestrator:
|
||||
if predicted_action == "HOLD" and has_position:
|
||||
# If holding a position, HOLD is correct if price didn't drop significantly
|
||||
simple_correct = price_change_pct > -0.2 # Allow small losses while holding
|
||||
should_skip = False
|
||||
elif predicted_action == "HOLD" and not has_position:
|
||||
# No position + HOLD = NEUTRAL, should be skipped
|
||||
simple_correct = None
|
||||
should_skip = True
|
||||
# Force reward to 0.0 for NEUTRAL actions
|
||||
return 0.0, simple_correct, should_skip
|
||||
else:
|
||||
# Standard evaluation for other cases
|
||||
simple_correct = (
|
||||
(predicted_action == "BUY" and price_change_pct > 0.1)
|
||||
or (predicted_action == "SELL" and price_change_pct < -0.1)
|
||||
or (predicted_action == "HOLD" and abs(price_change_pct) < 0.1)
|
||||
)
|
||||
return (1.0 if simple_correct else -0.5, simple_correct)
|
||||
should_skip = False
|
||||
|
||||
simple_reward = 1.0 if simple_correct else -0.5 if simple_correct is not None else 0.0
|
||||
return simple_reward, simple_correct, should_skip
|
||||
|
||||
def _calculate_price_vector_loss(
|
||||
self,
|
||||
@ -4138,7 +4327,7 @@ class TradingOrchestrator:
|
||||
# Extract price vector from record if available
|
||||
predicted_price_vector = record.get("price_direction") or record.get("predicted_price_vector")
|
||||
|
||||
sophisticated_reward, _ = self._calculate_sophisticated_reward(
|
||||
sophisticated_reward, _, should_skip = self._calculate_sophisticated_reward(
|
||||
record.get("action", "HOLD"),
|
||||
record.get("confidence", 0.5),
|
||||
price_change_pct,
|
||||
@ -4149,6 +4338,11 @@ class TradingOrchestrator:
|
||||
predicted_price_vector=predicted_price_vector,
|
||||
)
|
||||
|
||||
# Skip training if this is a neutral action (no position + HOLD)
|
||||
if should_skip:
|
||||
logger.debug(f"Skipping training for neutral action: {record.get('action', 'HOLD')} (no position)")
|
||||
return
|
||||
|
||||
# Calculate price vector training loss if we have vector predictions
|
||||
if predicted_price_vector:
|
||||
vector_loss = self._calculate_price_vector_loss(
|
||||
@ -4289,7 +4483,7 @@ class TradingOrchestrator:
|
||||
async def _train_rl_model(
|
||||
self, model, model_name: str, model_input, prediction: Dict, reward: float
|
||||
) -> bool:
|
||||
"""Train RL model (DQN) with experience replay"""
|
||||
memory_size = 0 # Ensure memory_size is always defined
|
||||
try:
|
||||
# Convert prediction action to action index
|
||||
action_names = ["SELL", "HOLD", "BUY"]
|
||||
@ -4346,54 +4540,62 @@ class TradingOrchestrator:
|
||||
next_state=state, # Simplified - using same state
|
||||
done=True,
|
||||
)
|
||||
logger.debug(
|
||||
f"Added experience to {model_name}: action={prediction['action']}, reward={reward:.3f}"
|
||||
)
|
||||
logger.info(f"RL EXPERIENCE ADDED to {model_name.upper()}:")
|
||||
logger.info(f" Action: {prediction['action']} (index: {action_idx})")
|
||||
logger.info(f" Reward: {reward:.3f}")
|
||||
logger.info(f" State shape: {state.shape}")
|
||||
logger.info(f" Memory size: {memory_size}")
|
||||
|
||||
# Trigger training if enough experiences
|
||||
memory_size = len(getattr(model, "memory", []))
|
||||
batch_size = getattr(model, "batch_size", 32)
|
||||
if memory_size >= batch_size:
|
||||
logger.debug(
|
||||
f"Training {model_name} with {memory_size} experiences"
|
||||
self.training_logger.info(f"RL TRAINING STARTED for {model_name.upper()}:")
|
||||
self.training_logger.info(f" Experiences: {memory_size}")
|
||||
self.training_logger.info(f" Batch size: {batch_size}")
|
||||
self.training_logger.info(f" Action: {prediction['action']}")
|
||||
self.training_logger.info(f" Reward: {reward:.3f}")
|
||||
|
||||
# Ensure model is in training mode
|
||||
if hasattr(model, "policy_net"):
|
||||
model.policy_net.train()
|
||||
|
||||
training_start_time = time.time()
|
||||
training_loss = model.replay()
|
||||
training_duration_ms = (time.time() - training_start_time) * 1000
|
||||
|
||||
if training_loss is not None and training_loss > 0:
|
||||
self.update_model_loss(model_name, training_loss)
|
||||
self._update_model_training_statistics(
|
||||
model_name, training_loss, training_duration_ms
|
||||
)
|
||||
|
||||
# Ensure model is in training mode
|
||||
if hasattr(model, "policy_net"):
|
||||
model.policy_net.train()
|
||||
|
||||
training_start_time = time.time()
|
||||
training_loss = model.replay()
|
||||
training_duration_ms = (time.time() - training_start_time) * 1000
|
||||
|
||||
if training_loss is not None and training_loss > 0:
|
||||
self.update_model_loss(model_name, training_loss)
|
||||
self._update_model_training_statistics(
|
||||
model_name, training_loss, training_duration_ms
|
||||
)
|
||||
logger.debug(
|
||||
f"RL training completed for {model_name}: loss={training_loss:.4f}, time={training_duration_ms:.1f}ms"
|
||||
)
|
||||
return True
|
||||
elif training_loss == 0.0:
|
||||
logger.warning(
|
||||
f"RL training returned zero loss for {model_name} - possible gradient issue"
|
||||
)
|
||||
# Still update training statistics
|
||||
self._update_model_training_statistics(
|
||||
model_name, training_duration_ms=training_duration_ms
|
||||
)
|
||||
return False # Training failed
|
||||
else:
|
||||
# Still update training statistics even if no loss returned
|
||||
self._update_model_training_statistics(
|
||||
model_name, training_duration_ms=training_duration_ms
|
||||
)
|
||||
self.training_logger.info(f"RL TRAINING COMPLETED for {model_name.upper()}:")
|
||||
self.training_logger.info(f" Loss: {training_loss:.4f}")
|
||||
self.training_logger.info(f" Training time: {training_duration_ms:.1f}ms")
|
||||
self.training_logger.info(f" Experiences used: {memory_size}")
|
||||
self.training_logger.info(f" Action: {prediction['action']}")
|
||||
self.training_logger.info(f" Reward: {reward:.3f}")
|
||||
self.training_logger.info(f" State shape: {state.shape}")
|
||||
return True
|
||||
elif training_loss == 0.0:
|
||||
logger.warning(
|
||||
f"RL training returned zero loss for {model_name} - possible gradient issue"
|
||||
)
|
||||
# Still update training statistics
|
||||
self._update_model_training_statistics(
|
||||
model_name, training_duration_ms=training_duration_ms
|
||||
)
|
||||
return False # Training failed
|
||||
else:
|
||||
logger.debug(
|
||||
f"Not enough experiences for {model_name}: {memory_size}/{batch_size}"
|
||||
# Still update training statistics even if no loss returned
|
||||
self._update_model_training_statistics(
|
||||
model_name, training_duration_ms=training_duration_ms
|
||||
)
|
||||
return True # Experience added successfully, training will happen later
|
||||
else:
|
||||
logger.debug(
|
||||
f"Not enough experiences for {model_name}: {memory_size}/{batch_size}"
|
||||
)
|
||||
return True # Experience added successfully, training will happen later
|
||||
|
||||
return False
|
||||
|
||||
@ -4675,9 +4877,22 @@ class TradingOrchestrator:
|
||||
model_name, current_loss, training_duration_ms
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"CNN direct training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms"
|
||||
self.training_logger.info(
|
||||
f"CNN DIRECT TRAINING COMPLETED:"
|
||||
)
|
||||
self.training_logger.info(f" Model: {model_name}")
|
||||
self.training_logger.info(f" Loss: {current_loss:.4f}")
|
||||
self.training_logger.info(f" Training time: {training_duration_ms:.1f}ms")
|
||||
self.training_logger.info(f" Action: {actual_action}")
|
||||
self.training_logger.info(f" Reward: {reward:.4f}")
|
||||
self.training_logger.info(f" Symbol: {symbol}")
|
||||
|
||||
# Log detailed loss breakdown
|
||||
if 'price_direction_loss' in locals():
|
||||
self.training_logger.info(f" Price direction loss: {price_direction_loss:.4f}")
|
||||
if 'extrema_loss' in locals():
|
||||
logger.info(f" Extrema loss: {extrema_loss:.4f}")
|
||||
logger.info(f" Total loss: {total_loss:.4f}")
|
||||
|
||||
# Trigger long-term training on stored inference records
|
||||
if hasattr(self.cnn_model, "train_on_stored_records") and hasattr(self, "cnn_optimizer"):
|
||||
@ -4708,7 +4923,12 @@ class TradingOrchestrator:
|
||||
# Train on stored records
|
||||
long_term_loss = self.cnn_model.train_on_stored_records(self.cnn_optimizer, min_records=5)
|
||||
if long_term_loss > 0:
|
||||
logger.debug(f"CNN long-term training completed: loss={long_term_loss:.4f}")
|
||||
self.training_logger.info(f"CNN LONG-TERM TRAINING COMPLETED:")
|
||||
self.training_logger.info(f" Long-term loss: {long_term_loss:.4f}")
|
||||
self.training_logger.info(f" Records processed: {len(self.cnn_model.inference_records)}")
|
||||
self.training_logger.info(f" Price change: {price_change_pct:+.3f}%")
|
||||
self.training_logger.info(f" Current price: ${current_price:.2f}")
|
||||
self.training_logger.info(f" Inference price: ${inference_price:.2f}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in CNN long-term training: {e}")
|
||||
|
||||
@ -4722,9 +4942,10 @@ class TradingOrchestrator:
|
||||
symbol = record.get("symbol", "ETH/USDT")
|
||||
actual_action = prediction["action"]
|
||||
model.add_training_sample(symbol, actual_action, reward)
|
||||
logger.debug(
|
||||
f"Added training sample to {model_name}: action={actual_action}, reward={reward:.3f}"
|
||||
)
|
||||
logger.info(f"TRAINING SAMPLE ADDED to {model_name.upper()}:")
|
||||
logger.info(f" Action: {actual_action}")
|
||||
logger.info(f" Reward: {reward:.3f}")
|
||||
logger.info(f" Symbol: {symbol}")
|
||||
|
||||
# If model has train method, trigger training
|
||||
if hasattr(model, "train") and callable(getattr(model, "train")):
|
||||
@ -4741,9 +4962,30 @@ class TradingOrchestrator:
|
||||
self._update_model_training_statistics(
|
||||
model_name, current_loss, training_duration_ms
|
||||
)
|
||||
logger.debug(
|
||||
f"Model {model_name} training completed: loss={current_loss:.4f}"
|
||||
)
|
||||
self.training_logger.info(f"MODEL TRAINING COMPLETED for {model_name.upper()}:")
|
||||
self.training_logger.info(f" Loss: {current_loss:.4f}")
|
||||
self.training_logger.info(f" Training time: {training_duration_ms:.1f}ms")
|
||||
self.training_logger.info(f" Action: {actual_action}")
|
||||
self.training_logger.info(f" Reward: {reward:.3f}")
|
||||
self.training_logger.info(f" Symbol: {symbol}")
|
||||
|
||||
# Log additional training metrics if available
|
||||
if "accuracy" in training_results:
|
||||
self.training_logger.info(f" Accuracy: {training_results['accuracy']:.4f}")
|
||||
if "epochs" in training_results:
|
||||
self.training_logger.info(f" Epochs: {training_results['epochs']}")
|
||||
if "samples" in training_results:
|
||||
self.training_logger.info(f" Samples: {training_results['samples']}")
|
||||
|
||||
# Periodic comprehensive logging (every 10th training)
|
||||
if hasattr(self, '_training_count'):
|
||||
self._training_count += 1
|
||||
else:
|
||||
self._training_count = 1
|
||||
|
||||
if self._training_count % 10 == 0:
|
||||
self.training_logger.info("PERIODIC COMPREHENSIVE TRAINING LOG:")
|
||||
self.log_model_statistics(detailed=True)
|
||||
else:
|
||||
self._update_model_training_statistics(
|
||||
model_name, training_duration_ms=training_duration_ms
|
||||
@ -6901,7 +7143,6 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping enhanced training: {e}")
|
||||
return False
|
||||
|
||||
def get_enhanced_training_stats(self) -> Dict[str, Any]:
|
||||
"""Get enhanced training system statistics with orchestrator integration"""
|
||||
try:
|
||||
@ -7169,6 +7410,19 @@ class TradingOrchestrator:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _get_position_side(self, symbol: str) -> Optional[str]:
|
||||
"""Get the side of the current position (LONG/SHORT) or None if no position"""
|
||||
try:
|
||||
if self.trading_executor and hasattr(
|
||||
self.trading_executor, "get_current_position"
|
||||
):
|
||||
position = self.trading_executor.get_current_position(symbol)
|
||||
if position and position.get("size", 0) > 0:
|
||||
return position.get("side", "LONG").upper()
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def _calculate_position_enhanced_reward_for_dqn(self, base_reward, action, position_pnl, has_position):
|
||||
@ -8263,7 +8517,6 @@ class TradingOrchestrator:
|
||||
status[data_type][symbol] = len(queue)
|
||||
|
||||
return status
|
||||
|
||||
def get_detailed_queue_status(self) -> Dict[str, Any]:
|
||||
"""Get detailed status of all data queues with timestamps and data info"""
|
||||
detailed_status = {}
|
||||
@ -8819,7 +9072,6 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error populating initial queue data: {e}")
|
||||
|
||||
def _try_fallback_data_strategy(
|
||||
self, symbol: str, missing_data: List[Tuple[str, int, int]]
|
||||
) -> bool:
|
||||
"""
|
||||
@ -8919,4 +9171,4 @@ class TradingOrchestrator:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in fallback data strategy: {e}")
|
||||
return False
|
||||
return False
|
@ -21,9 +21,17 @@
|
||||
"training_enabled": true
|
||||
},
|
||||
"dqn_agent": {
|
||||
"inference_enabled": "inference_enabled",
|
||||
"inference_enabled": false,
|
||||
"training_enabled": false
|
||||
},
|
||||
"enhanced_cnn": {
|
||||
"inference_enabled": true,
|
||||
"training_enabled": false
|
||||
},
|
||||
"cob_rl_model": {
|
||||
"inference_enabled": false,
|
||||
"training_enabled": false
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-07-30T09:19:11.731827"
|
||||
"timestamp": "2025-07-30T11:07:48.287272"
|
||||
}
|
@ -38,21 +38,15 @@ class SafeFormatter(logging.Formatter):
|
||||
|
||||
class SafeStreamHandler(logging.StreamHandler):
|
||||
"""Stream handler that forces UTF-8 encoding where supported"""
|
||||
|
||||
def __init__(self, stream=None):
|
||||
super().__init__(stream)
|
||||
# Try to set UTF-8 encoding on stdout/stderr if supported
|
||||
if hasattr(self.stream, 'reconfigure'):
|
||||
try:
|
||||
if platform.system() == "Windows":
|
||||
# On Windows, use errors='ignore'
|
||||
self.stream.reconfigure(encoding='utf-8', errors='ignore')
|
||||
else:
|
||||
# On Unix-like systems, use backslashreplace
|
||||
self.stream.reconfigure(encoding='utf-8', errors='backslashreplace')
|
||||
except (AttributeError, OSError):
|
||||
# If reconfigure is not available or fails, continue silently
|
||||
pass
|
||||
if platform.system() == "Windows":
|
||||
# Force UTF-8 encoding on Windows
|
||||
if hasattr(stream, 'reconfigure'):
|
||||
try:
|
||||
stream.reconfigure(encoding='utf-8', errors='ignore')
|
||||
except:
|
||||
pass
|
||||
|
||||
def setup_safe_logging(log_level=logging.INFO, log_file='logs/safe_logging.log'):
|
||||
"""Setup logging with SafeFormatter and UTF-8 encoding with enhanced persistence
|
||||
@ -165,3 +159,69 @@ def setup_safe_logging(log_level=logging.INFO, log_file='logs/safe_logging.log')
|
||||
# Register atexit handler for normal shutdown
|
||||
atexit.register(flush_all_logs)
|
||||
|
||||
def setup_training_logger(log_level=logging.INFO, log_file='logs/training.log'):
|
||||
"""Setup a separate training logger that writes to training.log
|
||||
|
||||
Args:
|
||||
log_level: Logging level (default: INFO)
|
||||
log_file: Path to training log file (default: logs/training.log)
|
||||
|
||||
Returns:
|
||||
logging.Logger: The training logger instance
|
||||
"""
|
||||
# Ensure logs directory exists
|
||||
log_path = Path(log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create training logger
|
||||
training_logger = logging.getLogger('training')
|
||||
training_logger.setLevel(log_level)
|
||||
|
||||
# Clear existing handlers to avoid duplicates
|
||||
for handler in training_logger.handlers[:]:
|
||||
training_logger.removeHandler(handler)
|
||||
|
||||
# Create file handler for training logs
|
||||
try:
|
||||
encoding_kwargs = {
|
||||
"encoding": "utf-8",
|
||||
"errors": "ignore" if platform.system() == "Windows" else "backslashreplace"
|
||||
}
|
||||
|
||||
from logging.handlers import RotatingFileHandler
|
||||
file_handler = RotatingFileHandler(
|
||||
log_file,
|
||||
maxBytes=10*1024*1024, # 10MB max file size
|
||||
backupCount=5, # Keep 5 backup files
|
||||
**encoding_kwargs
|
||||
)
|
||||
file_handler.setFormatter(SafeFormatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
))
|
||||
|
||||
# Force immediate flush for training logs
|
||||
class FlushingHandler(RotatingFileHandler):
|
||||
def emit(self, record):
|
||||
super().emit(record)
|
||||
self.flush() # Force flush after each log
|
||||
|
||||
file_handler = FlushingHandler(
|
||||
log_file,
|
||||
maxBytes=10*1024*1024,
|
||||
backupCount=5,
|
||||
**encoding_kwargs
|
||||
)
|
||||
file_handler.setFormatter(SafeFormatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
))
|
||||
|
||||
training_logger.addHandler(file_handler)
|
||||
|
||||
except (OSError, IOError) as e:
|
||||
print(f"Warning: Could not create training log file {log_file}: {e}", file=sys.stderr)
|
||||
|
||||
# Prevent propagation to root logger to avoid duplicate logs
|
||||
training_logger.propagate = False
|
||||
|
||||
return training_logger
|
||||
|
||||
|
@ -543,8 +543,7 @@ class CleanTradingDashboard:
|
||||
success = True
|
||||
|
||||
if success:
|
||||
# Create callbacks for the new model
|
||||
self._create_model_toggle_callbacks(model_name)
|
||||
# Universal callback system handles new models automatically
|
||||
logger.info(f"✅ Successfully added model dynamically: {model_name}")
|
||||
return True
|
||||
else:
|
||||
@ -839,9 +838,9 @@ class CleanTradingDashboard:
|
||||
|
||||
logger.info(f"Setting up universal callbacks for {len(available_models)} models: {list(available_models.keys())}")
|
||||
|
||||
# Create callbacks for each model dynamically
|
||||
for model_name in available_models.keys():
|
||||
self._create_model_toggle_callbacks(model_name)
|
||||
# Universal callback system handles all models automatically
|
||||
# No need to create individual callbacks for each model
|
||||
logger.info(f"Universal callback system will handle {len(available_models)} models automatically")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting up universal model callbacks: {e}")
|
||||
@ -903,79 +902,7 @@ class CleanTradingDashboard:
|
||||
'transformer': {'name': 'transformer', 'type': 'fallback'}
|
||||
}
|
||||
|
||||
def _create_model_toggle_callbacks(self, model_name):
|
||||
"""Create inference and training toggle callbacks for a specific model"""
|
||||
try:
|
||||
# Create inference toggle callback
|
||||
@self.app.callback(
|
||||
Output(f'{model_name}-inference-toggle', 'value'),
|
||||
[Input(f'{model_name}-inference-toggle', 'value')],
|
||||
prevent_initial_call=True
|
||||
)
|
||||
def update_model_inference_toggle(value):
|
||||
return self._handle_model_toggle(model_name, 'inference', value)
|
||||
|
||||
# Create training toggle callback
|
||||
@self.app.callback(
|
||||
Output(f'{model_name}-training-toggle', 'value'),
|
||||
[Input(f'{model_name}-training-toggle', 'value')],
|
||||
prevent_initial_call=True
|
||||
)
|
||||
def update_model_training_toggle(value):
|
||||
return self._handle_model_toggle(model_name, 'training', value)
|
||||
|
||||
logger.debug(f"Created toggle callbacks for model: {model_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating callbacks for model {model_name}: {e}")
|
||||
|
||||
def _handle_model_toggle(self, model_name, toggle_type, value):
|
||||
"""Universal handler for model toggle changes"""
|
||||
try:
|
||||
enabled = bool(value and len(value) > 0) # Convert list to boolean
|
||||
|
||||
if self.orchestrator:
|
||||
# Map component model name back to orchestrator's expected model name
|
||||
reverse_mapping = {
|
||||
'dqn': 'dqn_agent',
|
||||
'cnn': 'enhanced_cnn',
|
||||
'decision_fusion': 'decision',
|
||||
'extrema_trainer': 'extrema_trainer',
|
||||
'cob_rl': 'cob_rl',
|
||||
'transformer': 'transformer'
|
||||
}
|
||||
|
||||
orchestrator_model_name = reverse_mapping.get(model_name, model_name)
|
||||
|
||||
# Update orchestrator toggle state
|
||||
if toggle_type == 'inference':
|
||||
self.orchestrator.set_model_toggle_state(orchestrator_model_name, inference_enabled=enabled)
|
||||
elif toggle_type == 'training':
|
||||
self.orchestrator.set_model_toggle_state(orchestrator_model_name, training_enabled=enabled)
|
||||
|
||||
logger.info(f"Model {model_name} ({orchestrator_model_name}) {toggle_type} toggle: {enabled}")
|
||||
|
||||
# Update dashboard state variables for backward compatibility
|
||||
self._update_dashboard_state_variable(model_name, toggle_type, enabled)
|
||||
|
||||
return value
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling toggle for {model_name} {toggle_type}: {e}")
|
||||
return value
|
||||
|
||||
def _update_dashboard_state_variable(self, model_name, toggle_type, enabled):
|
||||
"""Update dashboard state variables for dynamic model management"""
|
||||
try:
|
||||
# Store in dynamic model toggle states
|
||||
if model_name not in self.model_toggle_states:
|
||||
self.model_toggle_states[model_name] = {"inference_enabled": True, "training_enabled": True}
|
||||
|
||||
self.model_toggle_states[model_name][f"{toggle_type}_enabled"] = enabled
|
||||
logger.debug(f"Updated dynamic model state: {model_name}.{toggle_type}_enabled = {enabled}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error updating dynamic model state: {e}")
|
||||
# Dynamic callback functions removed - using universal callback system instead
|
||||
|
||||
def _setup_callbacks(self):
|
||||
"""Setup dashboard callbacks"""
|
||||
@ -1439,11 +1366,19 @@ class CleanTradingDashboard:
|
||||
}
|
||||
|
||||
orchestrator_name = model_mapping.get(model_name, model_name)
|
||||
self.orchestrator.set_model_toggle_state(
|
||||
orchestrator_name,
|
||||
toggle_type + '_enabled',
|
||||
is_enabled
|
||||
)
|
||||
|
||||
# Call set_model_toggle_state with correct parameters based on toggle type
|
||||
if toggle_type == 'inference':
|
||||
self.orchestrator.set_model_toggle_state(
|
||||
orchestrator_name,
|
||||
inference_enabled=is_enabled
|
||||
)
|
||||
elif toggle_type == 'training':
|
||||
self.orchestrator.set_model_toggle_state(
|
||||
orchestrator_name,
|
||||
training_enabled=is_enabled
|
||||
)
|
||||
|
||||
logger.info(f"Updated {orchestrator_name} {toggle_type}_enabled = {is_enabled}")
|
||||
|
||||
# Return all current values (no change needed)
|
||||
@ -10066,7 +10001,9 @@ def create_clean_dashboard(data_provider: Optional[DataProvider] = None, orchest
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logger.info("Received shutdown signal")
|
||||
self.shutdown() # Assuming a shutdown method exists or add one
|
||||
# Graceful shutdown - just exit
|
||||
import sys
|
||||
sys.exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Only set signal handlers if we're in the main thread
|
||||
|
@ -513,29 +513,47 @@ class ModelsTrainingPanel:
|
||||
)
|
||||
], style={"flex": "1"}),
|
||||
|
||||
# Toggle switches with pattern matching IDs
|
||||
# Interactive toggles for inference and training
|
||||
html.Div([
|
||||
# Inference toggle
|
||||
html.Div([
|
||||
html.Label("Inf", className="text-muted small me-1", style={"font-size": "10px"}),
|
||||
dcc.Checklist(
|
||||
html.Label("Inf", className="text-muted", style={
|
||||
"font-size": "9px",
|
||||
"margin-bottom": "0",
|
||||
"margin-right": "3px",
|
||||
"font-weight": "500"
|
||||
}),
|
||||
dbc.Switch(
|
||||
id={'type': 'model-toggle', 'model': model_name, 'toggle_type': 'inference'},
|
||||
options=[{"label": "", "value": True}],
|
||||
value=[True] if model_data.get('inference_enabled', True) else [],
|
||||
className="form-check-input me-2",
|
||||
style={"transform": "scale(0.7)"}
|
||||
value=['enabled'] if model_data.get('inference_enabled', True) else [],
|
||||
className="model-toggle-switch",
|
||||
style={
|
||||
"transform": "scale(0.6)",
|
||||
"margin": "0",
|
||||
"padding": "0"
|
||||
}
|
||||
)
|
||||
], className="d-flex align-items-center me-2"),
|
||||
], className="d-flex align-items-center me-2", style={"height": "18px"}),
|
||||
# Training toggle
|
||||
html.Div([
|
||||
html.Label("Trn", className="text-muted small me-1", style={"font-size": "10px"}),
|
||||
dcc.Checklist(
|
||||
html.Label("Trn", className="text-muted", style={
|
||||
"font-size": "9px",
|
||||
"margin-bottom": "0",
|
||||
"margin-right": "3px",
|
||||
"font-weight": "500"
|
||||
}),
|
||||
dbc.Switch(
|
||||
id={'type': 'model-toggle', 'model': model_name, 'toggle_type': 'training'},
|
||||
options=[{"label": "", "value": True}],
|
||||
value=[True] if model_data.get('training_enabled', True) else [],
|
||||
className="form-check-input",
|
||||
style={"transform": "scale(0.7)"}
|
||||
value=['enabled'] if model_data.get('training_enabled', True) else [],
|
||||
className="model-toggle-switch",
|
||||
style={
|
||||
"transform": "scale(0.6)",
|
||||
"margin": "0",
|
||||
"padding": "0"
|
||||
}
|
||||
)
|
||||
], className="d-flex align-items-center")
|
||||
], className="d-flex")
|
||||
], className="d-flex align-items-center", style={"height": "18px"})
|
||||
], className="d-flex align-items-center", style={"gap": "8px"})
|
||||
], className="d-flex align-items-center mb-2"),
|
||||
|
||||
# Model metrics
|
||||
|
Reference in New Issue
Block a user