try to fix input dimentions

This commit is contained in:
Dobromir Popov
2025-07-13 23:41:47 +03:00
parent bcc13a5db3
commit ebf65494a8
7 changed files with 358 additions and 145 deletions

View File

@ -250,6 +250,12 @@ class COBRLModelInterface(ModelInterface):
logger.info(f"COB RL Model Interface initialized on {self.device}")
def to(self, device):
"""PyTorch-style device movement method"""
self.device = device
self.model = self.model.to(device)
return self
def predict(self, cob_features: np.ndarray) -> Dict[str, Any]:
"""Make prediction using the model"""
self.model.eval()

View File

@ -454,6 +454,13 @@ class DQNAgent:
logger.error(f"Failed to move models to {self.device}: {str(e)}")
return False
def to(self, device):
"""PyTorch-style device movement method"""
self.device = device
self.policy_net = self.policy_net.to(device)
self.target_net = self.target_net.to(device)
return self
def remember(self, state: np.ndarray, action: int, reward: float,
next_state: np.ndarray, done: bool, is_extrema: bool = False):
"""

View File

@ -498,10 +498,20 @@ class EnhancedCNN(nn.Module):
"""Enhanced action selection with ultra massive model predictions"""
if explore and np.random.random() < 0.1: # 10% random exploration
return np.random.choice(self.n_actions)
self.eval()
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
# Accept both NumPy arrays and already-built torch tensors
if isinstance(state, torch.Tensor):
state_tensor = state.detach().to(self.device)
if state_tensor.dim() == 1:
state_tensor = state_tensor.unsqueeze(0)
else:
# Convert to tensor **directly on the target device** to avoid intermediate CPU copies
state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
if state_tensor.dim() == 1:
state_tensor = state_tensor.unsqueeze(0)
with torch.no_grad():
q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)