try to fix input dimentions

This commit is contained in:
Dobromir Popov
2025-07-13 23:41:47 +03:00
parent bcc13a5db3
commit ebf65494a8
7 changed files with 358 additions and 145 deletions

View File

@ -454,6 +454,13 @@ class DQNAgent:
logger.error(f"Failed to move models to {self.device}: {str(e)}")
return False
def to(self, device):
"""PyTorch-style device movement method"""
self.device = device
self.policy_net = self.policy_net.to(device)
self.target_net = self.target_net.to(device)
return self
def remember(self, state: np.ndarray, action: int, reward: float,
next_state: np.ndarray, done: bool, is_extrema: bool = False):
"""