From c5a9e75ee7ef51e86f6e16f22a563a1da5350d47 Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Wed, 30 Jul 2025 00:32:35 +0300
Subject: [PATCH] vector predictions inference fix

---
 NN/models/dqn_agent.py        | 41 ++++++++++++++++++++++++-----------
 NN/models/standardized_cnn.py |  2 +-
 core/orchestrator.py          |  1 +
 data/ui_state.json            |  4 ++--
 web/clean_dashboard.py        |  2 +-
 5 files changed, 33 insertions(+), 17 deletions(-)

diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py
index 815858e..4a66194 100644
--- a/NN/models/dqn_agent.py
+++ b/NN/models/dqn_agent.py
@@ -169,7 +169,12 @@ class DQNNetwork(nn.Module):
         # Combine value and advantage for Q-values
         q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
 
-        return q_values, regime_pred, price_direction_pred, volatility_pred, features
+        # Add placeholder multi-timeframe predictions for compatibility
+        batch_size = q_values.size(0)
+        device = q_values.device
+        multi_timeframe_pred = torch.zeros(batch_size, 12, device=device)  # 3 timeframes * 4 values each
+
+        return q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred
 
     def act(self, state, explore=True):
         """
@@ -197,7 +202,7 @@ class DQNNetwork(nn.Module):
             state = state.unsqueeze(0)
 
         with torch.no_grad():
-            q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
+            q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred = self.forward(state)
 
             # Price direction predictions are processed in the agent's act method
             # This is just the network forward pass
@@ -781,7 +786,7 @@ class DQNAgent:
         # Process price direction predictions from the network
         # Get the raw predictions from the network's forward pass
        with torch.no_grad():
-            q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state)
+            q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred = self.policy_net.forward(state)
 
        if price_direction_pred is not None:
            self.process_price_direction_predictions(price_direction_pred)
@@ -826,7 +831,7 @@ class DQNAgent:
 
         # Get network outputs
         with torch.no_grad():
-            q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state_tensor)
+            q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred = self.policy_net.forward(state_tensor)
 
         # Process price direction predictions
         if price_direction_pred is not None:
@@ -1025,11 +1030,18 @@ class DQNAgent:
         return None
 
     def _safe_cnn_forward(self, network, states):
-        """Safely call CNN forward method ensuring we always get 5 return values"""
+        """Safely call CNN forward method ensuring we always get 6 return values"""
         try:
             result = network(states)
-            if isinstance(result, tuple) and len(result) == 5:
+            if isinstance(result, tuple) and len(result) == 6:
                 return result
+            elif isinstance(result, tuple) and len(result) == 5:
+                # Handle legacy 5-value return by adding default multi_timeframe_pred
+                q_values, extrema_pred, price_pred, features, advanced_pred = result
+                batch_size = q_values.size(0)
+                device = q_values.device
+                default_multi_timeframe = torch.zeros(batch_size, 12, device=device)  # 3 timeframes * 4 values each
+                return q_values, extrema_pred, price_pred, features, advanced_pred, default_multi_timeframe
             elif isinstance(result, tuple) and len(result) == 1:
                 # Handle case where only q_values are returned (like in empty tensor case)
                 q_values = result[0]
@@ -1039,7 +1051,8 @@ class DQNAgent:
                 default_price = torch.zeros(batch_size, 1, device=device)
                 default_features = torch.zeros(batch_size, 1024, device=device)
                 default_advanced = torch.zeros(batch_size, 1, device=device)
-                return q_values, default_extrema, default_price, default_features, default_advanced
+                default_multi_timeframe = torch.zeros(batch_size, 12, device=device)
+                return q_values, default_extrema, default_price, default_features, default_advanced, default_multi_timeframe
             else:
                 # Fallback: create all default tensors
                 batch_size = states.size(0)
@@ -1049,7 +1062,8 @@ class DQNAgent:
                 default_price = torch.zeros(batch_size, 1, device=device)
                 default_features = torch.zeros(batch_size, 1024, device=device)
                 default_advanced = torch.zeros(batch_size, 1, device=device)
-                return default_q_values, default_extrema, default_price, default_features, default_advanced
+                default_multi_timeframe = torch.zeros(batch_size, 12, device=device)
+                return default_q_values, default_extrema, default_price, default_features, default_advanced, default_multi_timeframe
         except Exception as e:
             logger.error(f"Error in CNN forward pass: {e}")
             # Fallback: create all default tensors
@@ -1060,7 +1074,8 @@ class DQNAgent:
             default_price = torch.zeros(batch_size, 1, device=device)
             default_features = torch.zeros(batch_size, 1024, device=device)
             default_advanced = torch.zeros(batch_size, 1, device=device)
-            return default_q_values, default_extrema, default_price, default_features, default_advanced
+            default_multi_timeframe = torch.zeros(batch_size, 12, device=device)
+            return default_q_values, default_extrema, default_price, default_features, default_advanced, default_multi_timeframe
 
     def replay(self, experiences=None):
         """Train the model using experiences from memory"""
@@ -1437,20 +1452,20 @@ class DQNAgent:
                 warnings.simplefilter("ignore", FutureWarning)
                 with torch.cuda.amp.autocast():
                     # Get current Q values and predictions
-                    current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
+                    current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred, current_multi_timeframe_pred = self._safe_cnn_forward(self.policy_net, states)
                     current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
 
                     # Get next Q values from target network
                     with torch.no_grad():
                         if self.use_double_dqn:
                             # Double DQN
-                            policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
+                            policy_q_values, _, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
                             next_actions = policy_q_values.argmax(1)
-                            target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
+                            target_q_values_all, _, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
                             next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
                         else:
                             # Standard DQN
-                            next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
+                            next_q_values, _, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
                             next_q_values = next_q_values.max(1)[0]
 
                     # Ensure consistent shapes
diff --git a/NN/models/standardized_cnn.py b/NN/models/standardized_cnn.py
index f9ceb05..1e6adb4 100644
--- a/NN/models/standardized_cnn.py
+++ b/NN/models/standardized_cnn.py
@@ -162,7 +162,7 @@ class StandardizedCNN(nn.Module):
         cnn_input = processed_features.unsqueeze(1)  # Add sequence dimension
 
         try:
-            q_values, extrema_pred, price_pred, cnn_features, advanced_pred = self.enhanced_cnn(cnn_input)
+            q_values, extrema_pred, price_pred, cnn_features, advanced_pred, multi_timeframe_pred = self.enhanced_cnn(cnn_input)
         except Exception as e:
             logger.warning(f"Enhanced CNN forward pass failed: {e}, using fallback")
             # Fallback to direct processing
diff --git a/core/orchestrator.py b/core/orchestrator.py
index 204aa8e..1ab1962 100644
--- a/core/orchestrator.py
+++ b/core/orchestrator.py
@@ -4841,6 +4841,7 @@ class TradingOrchestrator:
                     price_pred,
                     features_refined,
                     advanced_pred,
+                    multi_timeframe_pred,
                 ) = self.cnn_model(features_tensor)
 
                 # Convert to probabilities using softmax
diff --git a/data/ui_state.json b/data/ui_state.json
index 3b167b4..8348fc1 100644
--- a/data/ui_state.json
+++ b/data/ui_state.json
@@ -14,7 +14,7 @@
     },
     "decision_fusion": {
       "inference_enabled": false,
-      "training_enabled": true
+      "training_enabled": false
     },
     "transformer": {
       "inference_enabled": false,
@@ -25,5 +25,5 @@
       "training_enabled": true
     }
   },
-  "timestamp": "2025-07-29T23:33:51.882579"
+  "timestamp": "2025-07-30T00:17:57.738273"
 }
\ No newline at end of file
diff --git a/web/clean_dashboard.py b/web/clean_dashboard.py
index ccaf1ab..b2d9698 100644
--- a/web/clean_dashboard.py
+++ b/web/clean_dashboard.py
@@ -7266,7 +7266,7 @@ class CleanTradingDashboard:
 
             # Get prediction from CNN model
             with torch.no_grad():
-                q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.cnn_adapter(features_tensor)
+                q_values, extrema_pred, price_pred, features_refined, advanced_pred, multi_timeframe_pred = self.cnn_adapter(features_tensor)
 
             # Convert to probabilities using softmax
             action_probs = torch.softmax(q_values, dim=1)
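
Note on the pattern above: _safe_cnn_forward is effectively an arity shim. It normalizes legacy 5-tuple forward() outputs to the new 6-tuple by appending a zeroed multi-timeframe placeholder, so every caller can unpack six values unconditionally. A minimal standalone sketch of that shim, assuming only that the first tuple element is the Q-value tensor (the function and argument names here are illustrative, not from the patch):

import torch

def pad_multi_timeframe(outputs, n_timeframes=3, n_values=4):
    # Already the new 6-tuple: pass through unchanged.
    if isinstance(outputs, tuple) and len(outputs) == 6:
        return outputs
    # Legacy 5-tuple: append a zeroed (batch, n_timeframes * n_values)
    # placeholder, matching the "3 timeframes * 4 values each" layout above.
    q_values = outputs[0]
    placeholder = torch.zeros(
        q_values.size(0), n_timeframes * n_values, device=q_values.device
    )
    return (*outputs, placeholder)

# Usage: a legacy 5-tuple gains an (8, 12) zero tensor as its 6th element.
legacy = tuple(torch.randn(8, 4) for _ in range(5))
assert pad_multi_timeframe(legacy)[5].shape == (8, 12)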
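The flat 12-value placeholder encodes 3 timeframes with 4 values each, per the inline comments. Downstream code that wants per-timeframe access only needs a view; note the patch does not define what the 4 values per timeframe mean, so the labels in this sketch are hypothetical:

import torch

batch_size = 8
multi_timeframe_pred = torch.zeros(batch_size, 12)  # as produced by forward()

# Reshape flat (batch, 12) -> (batch, 3 timeframes, 4 values).
per_timeframe = multi_timeframe_pred.view(batch_size, 3, 4)

# Hypothetical timeframe labels; the patch leaves the semantics undefined.
for i, label in enumerate(["tf_short", "tf_mid", "tf_long"]):
    print(label, per_timeframe[:, i, :].shape)  # torch.Size([8, 4])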
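Since five call sites change arity in one commit, a small smoke test helps catch any unpacking path the patch missed (for example the fallback branches inside _safe_cnn_forward). A sketch, assuming any nn.Module whose forward now returns the new tuple; the names in the usage comment are illustrative:

import torch

def assert_six_outputs(net, sample_input):
    """Regression check: forward() should return a 6-tuple whose last
    element is the (batch, 12) multi-timeframe placeholder."""
    with torch.no_grad():
        out = net(sample_input)
    assert isinstance(out, tuple) and len(out) == 6, f"got {len(out)} outputs"
    assert out[-1].shape == (sample_input.size(0), 12)

# e.g. assert_six_outputs(policy_net, torch.randn(4, state_dim))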