training fixes
This commit is contained in:
@ -494,11 +494,8 @@ class EnhancedCNN(nn.Module):
|
||||
|
||||
return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions
|
||||
|
||||
def act(self, state, explore=True):
|
||||
def act(self, state, explore=True) -> Tuple[int, float, List[float]]:
|
||||
"""Enhanced action selection with ultra massive model predictions"""
|
||||
if explore and np.random.random() < 0.1: # 10% random exploration
|
||||
return np.random.choice(self.n_actions)
|
||||
|
||||
self.eval()
|
||||
|
||||
# Accept both NumPy arrays and already-built torch tensors
|
||||
@ -511,14 +508,16 @@ class EnhancedCNN(nn.Module):
|
||||
state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
|
||||
if state_tensor.dim() == 1:
|
||||
state_tensor = state_tensor.unsqueeze(0)
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)
|
||||
|
||||
# Apply softmax to get action probabilities
|
||||
action_probs = torch.softmax(q_values, dim=1)
|
||||
action = torch.argmax(action_probs, dim=1).item()
|
||||
|
||||
action_probs_tensor = torch.softmax(q_values, dim=1)
|
||||
action_idx = int(torch.argmax(action_probs_tensor, dim=1).item())
|
||||
confidence = float(action_probs_tensor[0, action_idx].item()) # Confidence of the chosen action
|
||||
action_probs = action_probs_tensor.squeeze(0).tolist() # Convert to list of floats for return
|
||||
|
||||
# Log advanced predictions for better decision making
|
||||
if hasattr(self, '_log_predictions') and self._log_predictions:
|
||||
# Log volatility prediction
|
||||
@ -547,7 +546,7 @@ class EnhancedCNN(nn.Module):
|
||||
logger.info(f" Market Regime: {regime_labels[regime_class]} ({regime[regime_class]:.3f})")
|
||||
logger.info(f" Risk Level: {risk_labels[risk_class]} ({risk[risk_class]:.3f})")
|
||||
|
||||
return action
|
||||
return action_idx, confidence, action_probs
|
||||
|
||||
def save(self, path):
|
||||
"""Save model weights and architecture"""
|
||||
|
Reference in New Issue
Block a user