try to fix input dimentions

This commit is contained in:
Dobromir Popov
2025-07-13 23:41:47 +03:00
parent bcc13a5db3
commit ebf65494a8
7 changed files with 358 additions and 145 deletions

View File

@ -498,10 +498,20 @@ class EnhancedCNN(nn.Module):
"""Enhanced action selection with ultra massive model predictions"""
if explore and np.random.random() < 0.1: # 10% random exploration
return np.random.choice(self.n_actions)
self.eval()
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
# Accept both NumPy arrays and already-built torch tensors
if isinstance(state, torch.Tensor):
state_tensor = state.detach().to(self.device)
if state_tensor.dim() == 1:
state_tensor = state_tensor.unsqueeze(0)
else:
# Convert to tensor **directly on the target device** to avoid intermediate CPU copies
state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
if state_tensor.dim() == 1:
state_tensor = state_tensor.unsqueeze(0)
with torch.no_grad():
q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)