diff --git a/NN/models/enhanced_cnn.py b/NN/models/enhanced_cnn.py index 2ed55b4..15bcd59 100644 --- a/NN/models/enhanced_cnn.py +++ b/NN/models/enhanced_cnn.py @@ -494,11 +494,8 @@ class EnhancedCNN(nn.Module): return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions - def act(self, state, explore=True): + def act(self, state, explore=True) -> Tuple[int, float, List[float]]: """Enhanced action selection with ultra massive model predictions""" - if explore and np.random.random() < 0.1: # 10% random exploration - return np.random.choice(self.n_actions) - self.eval() # Accept both NumPy arrays and already-built torch tensors @@ -511,14 +508,16 @@ class EnhancedCNN(nn.Module): state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device) if state_tensor.dim() == 1: state_tensor = state_tensor.unsqueeze(0) - + with torch.no_grad(): q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor) # Apply softmax to get action probabilities - action_probs = torch.softmax(q_values, dim=1) - action = torch.argmax(action_probs, dim=1).item() - + action_probs_tensor = torch.softmax(q_values, dim=1) + action_idx = int(torch.argmax(action_probs_tensor, dim=1).item()) + confidence = float(action_probs_tensor[0, action_idx].item()) # Confidence of the chosen action + action_probs = action_probs_tensor.squeeze(0).tolist() # Convert to list of floats for return + # Log advanced predictions for better decision making if hasattr(self, '_log_predictions') and self._log_predictions: # Log volatility prediction @@ -547,7 +546,7 @@ class EnhancedCNN(nn.Module): logger.info(f" Market Regime: {regime_labels[regime_class]} ({regime[regime_class]:.3f})") logger.info(f" Risk Level: {risk_labels[risk_class]} ({risk[risk_class]:.3f})") - return action + return action_idx, confidence, action_probs def save(self, path): """Save model weights and architecture""" diff --git a/core/orchestrator.py b/core/orchestrator.py index 76188a7..571d5ee 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -33,7 +33,7 @@ from .data_provider import DataProvider from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface, ModelRegistry from NN.models.cob_rl_model import COBRLModelInterface # Specific import for COB RL Interface -from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface # Import from new file +from NN.models.model_interfaces import ModelInterface as NNModelInterface, CNNModelInterface as NNCNNModelInterface, RLAgentInterface as NNRLAgentInterface, ExtremaTrainerInterface as NNExtremaTrainerInterface # Import from new file from core.extrema_trainer import ExtremaTrainer # Import ExtremaTrainer for its interface # Import COB integration for real-time market microstructure data @@ -435,15 +435,21 @@ class TradingOrchestrator: super().__init__(name) self.model = model - def predict(self, data): + def predict(self, data=None): try: - # Use available methods from ExtremaTrainer - if hasattr(self.model, 'detect_extrema'): - return self.model.detect_extrema(data) - elif hasattr(self.model, 'get_pivot_signals'): - return self.model.get_pivot_signals(data) - # Return a default prediction if no methods available - return {'action': 'HOLD', 'confidence': 0.5} + # ExtremaTrainer provides context features, not a direct prediction + # We assume 'data' here is the 'symbol' string to pass to get_context_features_for_model + if not isinstance(data, str): + logger.warning(f"ExtremaTrainerInterface.predict received non-string data: {type(data)}. Cannot get context features.") + return None + + features = self.model.get_context_features_for_model(symbol=data) + if features is not None and features.size > 0: + # The presence of features indicates a signal. We'll return a generic HOLD + # with a neutral confidence. This can be refined if ExtremaTrainer provides + # more specific BUY/SELL signals directly. + return {'action': 'HOLD', 'confidence': 0.5, 'probabilities': {'BUY': 0.33, 'SELL': 0.33, 'HOLD': 0.34}} + return None except Exception as e: logger.error(f"Error in extrema trainer prediction: {e}") return None @@ -1016,7 +1022,14 @@ class TradingOrchestrator: if hasattr(model.model, 'act'): # Flatten / reshape enhanced_features as needed… x = self._prepare_cnn_input(enhanced_features) - action_probs, confidence = model.model.act(x, explore=False) + + # Debugging: Print the type and content of x before passing to act() + logger.debug(f"CNN input (x) type: {type(x)}, shape: {x.shape}, content sample: {x.flatten()[:5]}...") + + action_idx, confidence, action_probs = model.model.act(x, explore=False) + + # Debugging: Print the type and content of the unpacked values + logger.debug(f"CNN act() returned: action_idx={action_idx} (type={type(action_idx)}), confidence={confidence} (type={type(confidence)}), action_probs={action_probs[:5]}... (type={type(action_probs)})") else: # fallback to generic predict result = model.predict(enhanced_features) @@ -1029,7 +1042,7 @@ class TradingOrchestrator: logger.warning(f"CNN inference failed for {symbol}@{timeframe}: {e}") continue # skip this timeframe entirely - # 4) If we still don’t have valid probs, skip + # 4) If we still don't have valid probs, skip if action_probs is None: continue @@ -1175,11 +1188,15 @@ class TradingOrchestrator: action_probs = prediction_result.get('probabilities', None) confidence = prediction_result.get('confidence', 0.7) else: - # Assume it's just action probabilities + # Assume it's just action probabilities (e.g., a list or numpy array) action_probs = prediction_result confidence = 0.7 # Default confidence if action_probs is not None: + # Ensure action_probs is a numpy array for argmax + if not isinstance(action_probs, np.ndarray): + action_probs = np.array(action_probs) + action_names = ['SELL', 'HOLD', 'BUY'] best_action_idx = np.argmax(action_probs) best_action = action_names[best_action_idx]