From 36a8e256a85fd9b93b8f69db9c19fa724e255b02 Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Sat, 26 Jul 2025 23:57:03 +0300
Subject: [PATCH] fix DQN RL inference, rebuild model

---
 NN/models/dqn_agent.py | 215 +++++++++++++++++++++++++++--------------
 core/orchestrator.py   |  84 +++++++++++-----
 2 files changed, 201 insertions(+), 98 deletions(-)

diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py
index f11c13c..ce7f61b 100644
--- a/NN/models/dqn_agent.py
+++ b/NN/models/dqn_agent.py
@@ -21,6 +21,112 @@ from utils.training_integration import get_training_integration
 # Configure logger
 logger = logging.getLogger(__name__)
 
+class DQNNetwork(nn.Module):
+    """
+    Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
+    Handles 7850 input features from multi-timeframe, multi-asset data
+    """
+    def __init__(self, input_dim: int, n_actions: int):
+        super(DQNNetwork, self).__init__()
+
+        # Handle different input dimension formats
+        if isinstance(input_dim, (tuple, list)):
+            if len(input_dim) == 1:
+                self.input_size = input_dim[0]
+            else:
+                self.input_size = np.prod(input_dim)  # Flatten multi-dimensional input
+        else:
+            self.input_size = input_dim
+
+        self.n_actions = n_actions
+
+        # Deep network architecture optimized for trading features
+        self.network = nn.Sequential(
+            # Input layer
+            nn.Linear(self.input_size, 2048),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+
+            # Hidden layers
+            nn.Linear(2048, 1024),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+
+            nn.Linear(1024, 512),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Dropout(0.2),
+
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Dropout(0.2),
+
+            # Output layer for Q-values
+            nn.Linear(128, n_actions)
+        )
+
+        # Initialize weights
+        self._initialize_weights()
+
+    def _initialize_weights(self):
+        """Initialize network weights using Xavier initialization"""
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+
+    def forward(self, x):
+        """Forward pass through the network"""
+        # Ensure input is properly shaped
+        if x.dim() > 2:
+            x = x.view(x.size(0), -1)  # Flatten if needed
+        elif x.dim() == 1:
+            x = x.unsqueeze(0)  # Add batch dimension if needed
+
+        return self.network(x)
+
+    def act(self, state, explore=True):
+        """
+        Select action using epsilon-greedy policy
+
+        Args:
+            state: Current state (numpy array or tensor)
+            explore: Whether to use epsilon-greedy exploration
+
+        Returns:
+            action_idx: Selected action index
+            confidence: Confidence score
+            action_probs: Action probabilities
+        """
+        # Convert state to tensor if needed
+        if isinstance(state, np.ndarray):
+            state = torch.FloatTensor(state).to(next(self.parameters()).device)
+
+        # Ensure proper shape
+        if state.dim() == 1:
+            state = state.unsqueeze(0)
+
+        with torch.no_grad():
+            q_values = self.forward(state)
+
+            # Get action probabilities using softmax
+            action_probs = F.softmax(q_values, dim=1)
+
+            # Select action (greedy for inference)
+            action_idx = torch.argmax(q_values, dim=1).item()
+
+            # Calculate confidence as max probability
+            confidence = float(action_probs[0, action_idx].item())
+
+            # Convert probabilities to list
+            probs_list = action_probs.squeeze(0).cpu().numpy().tolist()
+
+        return action_idx, confidence, probs_list
+
 class DQNAgent:
     """
     Deep Q-Network agent for trading
@@ -80,12 +186,9 @@ class DQNAgent:
         else:
             self.device = device
 
-        # Initialize models with Enhanced CNN architecture for better performance
-        from NN.models.enhanced_cnn import EnhancedCNN
-
-        # Use Enhanced CNN for both policy and target networks
-        self.policy_net = EnhancedCNN(self.state_dim, self.n_actions)
-        self.target_net = EnhancedCNN(self.state_dim, self.n_actions)
+        # Initialize models with RL-specific network architecture
+        self.policy_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
+        self.target_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
 
         # Initialize the target network with the same weights as the policy network
         self.target_net.load_state_dict(self.policy_net.state_dict())
@@ -578,83 +681,45 @@ class DQNAgent:
             market_context: Additional market context for decision making
 
         Returns:
-            int: Action (0=BUY, 1=SELL, 2=HOLD) or None if should hold position
+            int: Action (0=BUY, 1=SELL)
         """
-
-        # Convert state to tensor
-        if isinstance(state, np.ndarray):
-            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
-        else:
-            state_tensor = state.unsqueeze(0).to(self.device)
-
-        # Get Q-values
-        policy_output = self.policy_net(state_tensor)
-        if isinstance(policy_output, dict):
-            q_values = policy_output.get('q_values', policy_output.get('Q_values', list(policy_output.values())[0]))
-        elif isinstance(policy_output, tuple):
-            q_values = policy_output[0]  # Assume first element is Q-values
-        else:
-            q_values = policy_output
-        action_values = q_values.cpu().data.numpy()[0]
-
-        # Calculate confidence scores
-        # Ensure q_values has correct shape for softmax
-        if q_values.dim() == 1:
-            q_values = q_values.unsqueeze(0)
-
-        # FIXED ACTION MAPPING: 0=BUY, 1=SELL, 2=HOLD
-        buy_confidence = torch.softmax(q_values, dim=1)[0, 0].item()
-        sell_confidence = torch.softmax(q_values, dim=1)[0, 1].item()
-
-        # Determine action based on current position and confidence thresholds
-        action = self._determine_action_with_position_management(
-            sell_confidence, buy_confidence, current_price, market_context, explore
-        )
-
-        # Update tracking
-        if current_price:
-            self.recent_prices.append(current_price)
-
-        if action is not None:
-            self.recent_actions.append(action)
-            return action
-        else:
-            # Return 1 (HOLD) as a safe default if action is None
+        try:
+            # Use the DQNNetwork's act method for consistent behavior
+            action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)
+
+            # Apply epsilon-greedy exploration if requested
+            if explore and np.random.random() <= self.epsilon:
+                action_idx = np.random.choice(self.n_actions)
+
+            # Update tracking
+            if current_price:
+                self.recent_prices.append(current_price)
+
+            self.recent_actions.append(action_idx)
+            return action_idx
+
+        except Exception as e:
+            logger.error(f"Error in act method: {e}")
+            # Fall back to action 1 (SELL) as a safe default
             return 1
 
-    def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
-        """Choose action with confidence score adapted to market regime (from Enhanced DQN)"""
-        with torch.no_grad():
-            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
-            q_values = self.policy_net(state_tensor)
-
-            # Handle case where network might return a tuple instead of tensor
-            if isinstance(q_values, tuple):
-                # If it's a tuple, take the first element (usually the main output)
-                q_values = q_values[0]
-
-            # Ensure q_values is a tensor and has correct shape for softmax
-            if not hasattr(q_values, 'dim'):
-                logger.error(f"DQN: q_values is not a tensor: {type(q_values)}")
-                # Return default action with low confidence
-                return 1, 0.1  # Default to HOLD action
-
-            if q_values.dim() == 1:
-                q_values = q_values.unsqueeze(0)
-
-            # Convert Q-values to probabilities
-            action_probs = torch.softmax(q_values, dim=1)
-            action = q_values.argmax().item()
-            base_confidence = action_probs[0, action].item()
+    def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
+        """Choose action with confidence score adapted to market regime"""
+        try:
+            # Use the DQNNetwork's act method which handles the state properly
+            action_idx, base_confidence, action_probs = self.policy_net.act(state, explore=False)
 
             # Adapt confidence based on market regime
             regime_weight = self.market_regime_weights.get(market_regime, 1.0)
             adapted_confidence = min(base_confidence * regime_weight, 1.0)
-
-        # Always return int, float
-        if action is None:
-            return 1, 0.1
-        return int(action), float(adapted_confidence)
+
+            # Return action, confidence, and probabilities (for orchestrator compatibility)
+            return int(action_idx), float(adapted_confidence), action_probs
+
+        except Exception as e:
+            logger.error(f"Error in act_with_confidence: {e}")
+            # Return default action with low confidence
+            return 1, 0.1, [0.45, 0.55]  # Default to SELL (action 1)
 
     def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
         """
diff --git a/core/orchestrator.py b/core/orchestrator.py
index 3f04be1..5f61925 100644
--- a/core/orchestrator.py
+++ b/core/orchestrator.py
@@ -283,9 +283,22 @@ class TradingOrchestrator:
         # Initialize DQN Agent
         try:
             from NN.models.dqn_agent import DQNAgent
-            state_size = self.config.rl.get('state_size', 13800)  # Enhanced with COB features
+
+            # Determine actual state size from BaseDataInput
+            try:
+                base_data = self.data_provider.build_base_data_input(self.symbol)
+                if base_data:
+                    actual_state_size = len(base_data.get_feature_vector())
+                    logger.info(f"Detected actual state size: {actual_state_size}")
+                else:
+                    actual_state_size = 7850  # Fallback based on error message
+                    logger.warning(f"Could not determine state size, using fallback: {actual_state_size}")
+            except Exception as e:
+                actual_state_size = 7850  # Fallback based on error message
+                logger.warning(f"Error determining state size: {e}, using fallback: {actual_state_size}")
+
             action_size = self.config.rl.get('action_space', 3)
-            self.rl_agent = DQNAgent(state_shape=state_size, n_actions=action_size)
+            self.rl_agent = DQNAgent(state_shape=actual_state_size, n_actions=action_size)
             self.rl_agent.to(self.device)  # Move DQN agent to the determined device
 
             # Load best checkpoint and capture initial state (using database metadata)
@@ -306,7 +319,10 @@ class TradingOrchestrator:
                         loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
                         logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
                 except Exception as e:
-                    logger.warning(f"Error loading DQN checkpoint: {e}")
+                    logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
+                    logger.info("DQN will start fresh due to checkpoint incompatibility")
+                    # Reset the agent to handle dimension mismatch
+                    checkpoint_loaded = False
 
             if not checkpoint_loaded:
                 # New model - no synthetic data, start fresh
@@ -1144,7 +1160,10 @@ class TradingOrchestrator:
 
         # Collect input data for all models
        input_data = await self._collect_model_input_data(symbol)
-
+
+        # log all registered models
+        logger.debug(f"inferencing registered models: {self.model_registry.models}")
+
         for model_name, model in self.model_registry.models.items():
             try:
                 prediction = None
@@ -2058,8 +2077,17 @@ class TradingOrchestrator:
                         )
                         predictions.append(prediction)
 
-                        # Store prediction in queue for future use
-                        self.update_data_queue('model_predictions', symbol, result)
+                        # Store prediction in SQLite database for training
+                        logger.debug(f"Added CNN prediction to database: {prediction}")
+
+                        # Store CNN prediction as inference record
+                        await self._store_inference_data_async(
+                            model_name="enhanced_cnn",
+                            model_input=base_data,
+                            prediction=prediction,
+                            timestamp=datetime.now(),
+                            symbol=symbol
+                        )
 
                 except Exception as e:
                     logger.error(f"Error using CNN adapter: {e}")
@@ -2111,6 +2139,15 @@ class TradingOrchestrator:
                     )
                     predictions.append(pred)
 
+                    # Store CNN fallback prediction as inference record
+                    await self._store_inference_data_async(
+                        model_name=model.name,
+                        model_input=base_data,
+                        prediction=pred,
+                        timestamp=datetime.now(),
+                        symbol=symbol
+                    )
+
                     # Capture for dashboard
                     current_price = self._get_current_price(symbol)
                     if current_price is not None:
@@ -2188,12 +2225,22 @@ class TradingOrchestrator:
             elif raw_q_values is not None and isinstance(raw_q_values, list):
                 q_values_for_capture = raw_q_values
 
-            # Create prediction object
+            # Create prediction object with safe probability calculation
+            probabilities = {}
+            if q_values_for_capture and len(q_values_for_capture) == len(action_names):
+                # Use actual q_values if they match the expected length
+                probabilities = {action_names[i]: float(q_values_for_capture[i]) for i in range(len(action_names))}
+            else:
+                # Use default uniform probabilities if q_values are unavailable or mismatched
+                default_prob = 1.0 / len(action_names)
+                probabilities = {name: default_prob for name in action_names}
+                if q_values_for_capture:
+                    logger.warning(f"Q-values length mismatch: expected {len(action_names)}, got {len(q_values_for_capture)}. Using default probabilities.")
+
             prediction = Prediction(
                 action=action,
                 confidence=float(confidence),
-                # Use actual q_values if available, otherwise default probabilities
-                probabilities={action_names[i]: float(q_values_for_capture[i]) if q_values_for_capture else (1.0 / len(action_names)) for i in range(len(action_names))},
+                probabilities=probabilities,
                 timeframe='mixed',  # RL uses mixed timeframes
                 timestamp=datetime.now(),
                 model_name=model.name,
@@ -2279,7 +2326,7 @@ class TradingOrchestrator:
             return None
 
     def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
-        """Get current state for RL agent - ensure compatibility with saved model"""
+        """Get current state for RL agent using unified BaseDataInput"""
         try:
             # Use unified BaseDataInput approach
             base_data = self.build_base_data_input(symbol)
@@ -2287,21 +2334,12 @@ class TradingOrchestrator:
                 logger.warning(f"Cannot build BaseDataInput for RL state: {symbol}")
                 return None
 
-            # Get unified feature vector
+            # Get unified feature vector (7850 features including all timeframes and COB data)
            feature_vector = base_data.get_feature_vector()
 
-            # Ensure compatibility with saved model (expects 403 features)
-            target_size = 403  # Match the saved model's expected input size
-            if len(feature_vector) < target_size:
-                # Pad with zeros
-                padded_state = np.zeros(target_size)
-                padded_state[:len(feature_vector)] = feature_vector
-                return padded_state
-            elif len(feature_vector) > target_size:
-                # Truncate to target size
-                return feature_vector[:target_size]
-            else:
-                return feature_vector
+            # Return the full unified feature vector for RL agent
+            # The DQN agent is now initialized with the correct size to match this
+            return feature_vector
 
         except Exception as e:
             logger.error(f"Error creating RL state for {symbol}: {e}")