try to fix input dimentions
This commit is contained in:
@ -250,6 +250,12 @@ class COBRLModelInterface(ModelInterface):
|
||||
|
||||
logger.info(f"COB RL Model Interface initialized on {self.device}")
|
||||
|
||||
def to(self, device):
|
||||
"""PyTorch-style device movement method"""
|
||||
self.device = device
|
||||
self.model = self.model.to(device)
|
||||
return self
|
||||
|
||||
def predict(self, cob_features: np.ndarray) -> Dict[str, Any]:
|
||||
"""Make prediction using the model"""
|
||||
self.model.eval()
|
||||
|
@ -454,6 +454,13 @@ class DQNAgent:
|
||||
logger.error(f"Failed to move models to {self.device}: {str(e)}")
|
||||
return False
|
||||
|
||||
def to(self, device):
|
||||
"""PyTorch-style device movement method"""
|
||||
self.device = device
|
||||
self.policy_net = self.policy_net.to(device)
|
||||
self.target_net = self.target_net.to(device)
|
||||
return self
|
||||
|
||||
def remember(self, state: np.ndarray, action: int, reward: float,
|
||||
next_state: np.ndarray, done: bool, is_extrema: bool = False):
|
||||
"""
|
||||
|
@ -498,10 +498,20 @@ class EnhancedCNN(nn.Module):
|
||||
"""Enhanced action selection with ultra massive model predictions"""
|
||||
if explore and np.random.random() < 0.1: # 10% random exploration
|
||||
return np.random.choice(self.n_actions)
|
||||
|
||||
|
||||
self.eval()
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||
|
||||
|
||||
# Accept both NumPy arrays and already-built torch tensors
|
||||
if isinstance(state, torch.Tensor):
|
||||
state_tensor = state.detach().to(self.device)
|
||||
if state_tensor.dim() == 1:
|
||||
state_tensor = state_tensor.unsqueeze(0)
|
||||
else:
|
||||
# Convert to tensor **directly on the target device** to avoid intermediate CPU copies
|
||||
state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
|
||||
if state_tensor.dim() == 1:
|
||||
state_tensor = state_tensor.unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)
|
||||
|
||||
|
Reference in New Issue
Block a user