From 060fdd28b44c7e860d22e94123627112df492354 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Mon, 8 Sep 2025 12:13:50 +0300 Subject: [PATCH] enable training --- core/orchestrator.py | 118 +++---------------------------------------- main_clean.py | 2 +- 2 files changed, 8 insertions(+), 112 deletions(-) diff --git a/core/orchestrator.py b/core/orchestrator.py index 1433c55..8c8b4ab 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -1330,110 +1330,6 @@ class TradingOrchestrator: logger.debug(f"Error building COB state for {symbol}: {e}") return None - def _combine_predictions(self, symbol: str, price: float, predictions: List[Prediction], - timestamp: datetime) -> TradingDecision: - # Call act_with_confidence and handle different return formats - result = model.model.act_with_confidence(state) - - if len(result) == 3: - # EnhancedCNN format: (action, confidence, q_values) - action_idx, confidence, raw_q_values = result - elif len(result) == 2: - # DQN format: (action, confidence) - action_idx, confidence = result - raw_q_values = None - else: - logger.warning(f"Unexpected result format from RL model: {result}") - return None - else: - # Fallback to standard act method - action_idx = model.model.act(state) - confidence = 0.6 # Default confidence - raw_q_values = None - - # Convert action index to action name - action_names = ['BUY', 'SELL', 'HOLD'] - if 0 <= action_idx < len(action_names): - action = action_names[action_idx] - else: - logger.warning(f"Invalid action index from RL model: {action_idx}") - return None - - # Store prediction in database for tracking - if (hasattr(self, 'enhanced_training_system') and - self.enhanced_training_system and - hasattr(self.enhanced_training_system, 'store_model_prediction')): - - current_price = self._get_current_price_safe(symbol) - if current_price > 0: - prediction_id = self.enhanced_training_system.store_model_prediction( - model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN", - symbol=symbol, - prediction_type=action, - confidence=confidence, - current_price=current_price - ) - logger.debug(f"Stored DQN prediction {prediction_id} for {symbol}") - - # Create prediction object - prediction = Prediction( - model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN", - symbol=symbol, - signal=action, - confidence=confidence, - reasoning=f"DQN agent prediction with Q-values: {raw_q_values}", - features=state.tolist() if isinstance(state, np.ndarray) else [], - metadata={ - 'action_idx': action_idx, - 'q_values': raw_q_values.tolist() if raw_q_values is not None else None, - 'state_size': len(state) if state is not None else 0 - } - ) - - return prediction - - except Exception as e: - logger.error(f"Error getting RL prediction for {symbol}: {e}") - return None - raw_q_values = None # No raw q_values from simple act - else: - logger.error(f"RL model {model.name} has no act method") - return None - - action_names = ['SELL', 'HOLD', 'BUY'] - action = action_names[action_idx] - - # Convert raw_q_values to list if they are a tensor - q_values_for_capture = None - if raw_q_values is not None and hasattr(raw_q_values, 'tolist'): - q_values_for_capture = raw_q_values.tolist() - elif raw_q_values is not None and isinstance(raw_q_values, list): - q_values_for_capture = raw_q_values - - # Create prediction object - prediction = Prediction( - action=action, - confidence=float(confidence), - # Use actual q_values if available, otherwise default probabilities - probabilities={action_names[i]: float(q_values_for_capture[i]) if q_values_for_capture else (1.0 / len(action_names)) for i in range(len(action_names))}, - timeframe='mixed', # RL uses mixed timeframes - timestamp=datetime.now(), - model_name=model.name, - metadata={'state_size': len(state)} - ) - - # Capture DQN prediction for dashboard visualization - current_price = self._get_current_price(symbol) - if current_price: - # Only pass q_values if they exist, otherwise pass empty list - q_values_to_pass = q_values_for_capture if q_values_for_capture is not None else [] - self.capture_dqn_prediction(symbol, action_idx, float(confidence), current_price, q_values_to_pass) - - return prediction - - except Exception as e: - logger.error(f"Error getting RL prediction: {e}") - return None async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]: """Get prediction from generic model""" @@ -1747,7 +1643,7 @@ class TradingOrchestrator: ) if needs_refresh: - result = load_best_checkpoint(model_name) + result = load_best_checkpoint(model_name) self._checkpoint_cache[model_name] = result self._checkpoint_cache_time[model_name] = current_time @@ -1892,15 +1788,15 @@ class TradingOrchestrator: # Initialize enhanced training system directly (no external training_integration module needed) try: from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem - - self.enhanced_training_system = EnhancedRealtimeTrainingSystem( - orchestrator=self, - data_provider=self.data_provider, + + self.enhanced_training_system = EnhancedRealtimeTrainingSystem( + orchestrator=self, + data_provider=self.data_provider, dashboard=None ) - + logger.info("✅ Enhanced training system initialized successfully") - + # Auto-start training by default logger.info("🚀 Auto-starting enhanced real-time training...") self.start_enhanced_training() diff --git a/main_clean.py b/main_clean.py index 3965ef2..93cc506 100644 --- a/main_clean.py +++ b/main_clean.py @@ -33,7 +33,7 @@ def create_safe_orchestrator() -> Optional[TradingOrchestrator]: try: # Create orchestrator with basic configuration (uses correct constructor parameters) orchestrator = TradingOrchestrator( - enhanced_rl_training=False # Disable problematic training initially + enhanced_rl_training=True # Enable RL training for model improvement ) logger.info("Trading orchestrator created successfully")