From e1e453c204caab6418036b00c63fdb2e9c02443a Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Tue, 29 Jul 2025 00:09:13 +0300
Subject: [PATCH] dqn model data fix

---
 NN/models/dqn_agent.py |   2 +-
 core/orchestrator.py   | 160 ++++++++++++++++++++++-------------
 2 files changed, 88 insertions(+), 74 deletions(-)

diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py
index 7feeb6a..ac5a3d2 100644
--- a/NN/models/dqn_agent.py
+++ b/NN/models/dqn_agent.py
@@ -1167,7 +1167,7 @@ class DQNAgent:
 
         # Handle empty dictionary case
         if not state:
-            logger.error("No numerical values found in state dict, using default state")
+            logger.error("Empty state dictionary received, using default state")
             expected_size = getattr(self, 'state_size', 403)
             if isinstance(expected_size, tuple):
                 expected_size = np.prod(expected_size)
diff --git a/core/orchestrator.py b/core/orchestrator.py
index cc46031..dc1a124 100644
--- a/core/orchestrator.py
+++ b/core/orchestrator.py
@@ -6348,9 +6348,10 @@ class TradingOrchestrator:
             action = decision.action
             confidence = decision.confidence
 
-            # Get current market data for training context
-            market_data = self._get_current_market_data(symbol)
-            if not market_data:
+            # Get current market data for training context - use same data source as CNN model
+            base_data = self.build_base_data_input(symbol)
+            if not base_data:
+                logger.warning(f"No base data available for training {symbol}, skipping model training")
                 return
 
             # Track if any model was trained for checkpoint saving
@@ -6359,8 +6360,8 @@ class TradingOrchestrator:
             # Train DQN agent if available and enabled
             if self.rl_agent and hasattr(self.rl_agent, "add_experience") and self.is_model_training_enabled("dqn"):
                 try:
-                    # Create state representation
-                    state = self._create_state_for_training(symbol, market_data)
+                    # Create state representation from base_data (same as CNN model)
+                    state = self._create_state_from_base_data(symbol, base_data)
 
                     # Map action to DQN action space - CONSISTENT ACTION MAPPING
                     action_mapping = {"BUY": 0, "SELL": 1, "HOLD": 2}
@@ -6389,9 +6390,9 @@ class TradingOrchestrator:
             # Train CNN model if available and enabled
             if self.cnn_model and hasattr(self.cnn_model, "add_training_sample") and self.is_model_training_enabled("cnn"):
                 try:
-                    # Create CNN input features
-                    cnn_features = self._create_cnn_features_for_training(
-                        symbol, market_data
+                    # Create CNN input features from base_data (same as inference)
+                    cnn_features = self._create_cnn_features_from_base_data(
+                        symbol, base_data
                     )
 
                     # Create target based on action
@@ -6573,88 +6574,101 @@ class TradingOrchestrator:
     def _get_current_market_data(self, symbol: str) -> Optional[Dict]:
         """Get current market data for training context"""
         try:
-            if self.data_provider:
-                # Get recent data for training
-                df = self.data_provider.get_historical_data(symbol, "1m", limit=100)
-                if df is not None and not df.empty:
-                    return {
-                        "ohlcv": df.tail(50).to_dict("records"),  # Last 50 candles
-                        "current_price": float(df["close"].iloc[-1]),
-                        "volume": float(df["volume"].iloc[-1]),
-                        "timestamp": df.index[-1],
-                    }
-            return None
+            if not self.data_provider:
+                logger.warning(f"No data provider available for {symbol}")
+                return None
+
+            # Get recent data for training
+            df = self.data_provider.get_historical_data(symbol, "1m", limit=100)
+            if df is not None and not df.empty:
+                return {
+                    "ohlcv": df.tail(50).to_dict("records"),  # Last 50 candles
+                    "current_price": float(df["close"].iloc[-1]),
+                    "volume": float(df["volume"].iloc[-1]),
+                    "timestamp": df.index[-1],
+                }
+            else:
+                logger.warning(f"No historical data available for {symbol}")
+                return None
 
         except Exception as e:
-            logger.debug(f"Error getting market data for training: {e}")
+            logger.error(f"Error getting market data for training {symbol}: {e}")
             return None
 
-    def _create_state_for_training(self, symbol: str, market_data: Dict) -> np.ndarray:
-        """Create state representation for DQN training"""
+    def _create_state_from_base_data(self, symbol: str, base_data: Any) -> np.ndarray:
+        """Create state representation for DQN training from base_data (same as CNN model)"""
         try:
-            # Create a basic state representation
-            ohlcv_data = market_data.get("ohlcv", [])
-            if not ohlcv_data:
-                return np.zeros(100)  # Default state size
+            # Validate base_data
+            if not base_data or not hasattr(base_data, 'get_feature_vector'):
+                logger.warning(f"Invalid base_data for {symbol}: {type(base_data)}")
+                return np.zeros(403)  # Default state size for DQN
+
+            # Get feature vector from base_data (same as CNN model)
+            features = base_data.get_feature_vector()
+
+            if not features or len(features) == 0:
+                logger.warning(f"No features available from base_data for {symbol}, using default state")
+                return np.zeros(403)  # Default state size for DQN
 
-            # Extract features from recent candles
-            features = []
-            for candle in ohlcv_data[-20:]:  # Last 20 candles
-                features.extend(
-                    [
-                        candle.get("open", 0),
-                        candle.get("high", 0),
-                        candle.get("low", 0),
-                        candle.get("close", 0),
-                        candle.get("volume", 0),
-                    ]
-                )
-
-            # Pad or truncate to expected size
-            state = np.array(features[:100])
-            if len(state) < 100:
-                state = np.pad(state, (0, 100 - len(state)), "constant")
+            # Convert to numpy array
+            state = np.array(features, dtype=np.float32)
+
+            # Ensure correct dimensions for DQN (403 features)
+            if len(state) != 403:
+                if len(state) < 403:
+                    # Pad with zeros
+                    padded_state = np.zeros(403, dtype=np.float32)
+                    padded_state[:len(state)] = state
+                    state = padded_state
+                else:
+                    # Truncate
+                    state = state[:403]
 
             return state
 
         except Exception as e:
-            logger.debug(f"Error creating state for training: {e}")
-            return np.zeros(100)
+            logger.error(f"Error creating state from base_data for {symbol}: {e}")
+            return np.zeros(403)  # Default state size for DQN
 
-    def _create_cnn_features_for_training(
-        self, symbol: str, market_data: Dict
+
+
+    def _create_cnn_features_from_base_data(
+        self, symbol: str, base_data: Any
     ) -> np.ndarray:
-        """Create CNN features for training"""
+        """Create CNN features for training from base_data (same as inference)"""
         try:
-            # Similar to state creation but formatted for CNN
-            ohlcv_data = market_data.get("ohlcv", [])
-            if not ohlcv_data:
-                return np.zeros((1, 100))
+            # Validate base_data
+            if not base_data or not hasattr(base_data, 'get_feature_vector'):
+                logger.warning(f"Invalid base_data for CNN training {symbol}: {type(base_data)}")
+                return np.zeros((1, 403))  # Default CNN input size
+
+            # Get feature vector from base_data (same as CNN inference)
+            features = base_data.get_feature_vector()
+
+            if not features or len(features) == 0:
+                logger.warning(f"No features available from base_data for CNN training {symbol}, using default")
+                return np.zeros((1, 403))  # Default CNN input size
 
-            # Create feature matrix
-            features = []
-            for candle in ohlcv_data[-20:]:
-                features.extend(
-                    [
-                        candle.get("open", 0),
-                        candle.get("high", 0),
-                        candle.get("low", 0),
-                        candle.get("close", 0),
-                        candle.get("volume", 0),
-                    ]
-                )
-
-            # Reshape for CNN input
-            cnn_features = np.array(features[:100]).reshape(1, -1)
-            if cnn_features.shape[1] < 100:
-                cnn_features = np.pad(
-                    cnn_features, ((0, 0), (0, 100 - cnn_features.shape[1])), "constant"
-                )
+            # Convert to numpy array and reshape for CNN
+            cnn_features = np.array(features, dtype=np.float32).reshape(1, -1)
+
+            # Ensure correct dimensions for CNN (403 features)
+            if cnn_features.shape[1] != 403:
+                if cnn_features.shape[1] < 403:
+                    # Pad with zeros
+                    padded_features = np.zeros((1, 403), dtype=np.float32)
+                    padded_features[0, :cnn_features.shape[1]] = cnn_features[0]
+                    cnn_features = padded_features
+                else:
+                    # Truncate
+                    cnn_features = cnn_features[:, :403]
 
             return cnn_features
 
         except Exception as e:
-            logger.debug(f"Error creating CNN features for training: {e}")
-            return np.zeros((1, 100))
+            logger.error(f"Error creating CNN features from base_data for {symbol}: {e}")
+            return np.zeros((1, 403))  # Default CNN input size
+
+
 
     def _create_cob_state_for_training(self, symbol: str, cob_data: Dict) -> np.ndarray:
         """Create COB state representation for training"""
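
Both new helpers in this patch (_create_state_from_base_data and _create_cnn_features_from_base_data) reduce to the same step: take base_data.get_feature_vector() and pad or truncate it to a fixed 403-dimensional vector. A minimal standalone sketch of that shared normalization, assuming a plain list or 1-D array input; the normalize_features name and the example values are illustrative, not part of the patch:

import numpy as np

FEATURE_SIZE = 403  # fixed feature-vector length used by both helpers in this patch

def normalize_features(features, size: int = FEATURE_SIZE) -> np.ndarray:
    """Pad with zeros or truncate a raw feature vector to a fixed length."""
    vec = np.asarray(features, dtype=np.float32).ravel()
    if vec.size < size:
        padded = np.zeros(size, dtype=np.float32)
        padded[:vec.size] = vec
        return padded
    return vec[:size]

# Example: short vectors are zero-padded, long ones truncated.
assert normalize_features([1.0, 2.0, 3.0]).shape == (403,)
assert normalize_features(np.ones(500)).shape == (403,)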