cb ws
This commit is contained in:
@ -271,10 +271,16 @@ class DQNAgent:
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
logger.info(f"DQN Agent using device: {self.device}")
|
||||
|
||||
# Initialize models with RL-specific network architecture
|
||||
self.policy_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
|
||||
self.target_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
|
||||
|
||||
# Ensure models are on the correct device
|
||||
self.policy_net = self.policy_net.to(self.device)
|
||||
self.target_net = self.target_net.to(self.device)
|
||||
|
||||
# Initialize the target network with the same weights as the policy network
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
|
||||
@ -997,11 +1003,19 @@ class DQNAgent:
|
||||
|
||||
# Convert to tensors with proper validation
|
||||
try:
|
||||
states = torch.FloatTensor(np.array(states)).to(self.device)
|
||||
actions = torch.LongTensor(np.array(actions)).to(self.device)
|
||||
rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
|
||||
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
|
||||
dones = torch.FloatTensor(np.array(dones)).to(self.device)
|
||||
# Ensure all data is on CPU first, then move to device
|
||||
states_array = np.array(states, dtype=np.float32)
|
||||
actions_array = np.array(actions, dtype=np.int64)
|
||||
rewards_array = np.array(rewards, dtype=np.float32)
|
||||
next_states_array = np.array(next_states, dtype=np.float32)
|
||||
dones_array = np.array(dones, dtype=np.float32)
|
||||
|
||||
# Convert to tensors and move to device
|
||||
states = torch.from_numpy(states_array).to(self.device)
|
||||
actions = torch.from_numpy(actions_array).to(self.device)
|
||||
rewards = torch.from_numpy(rewards_array).to(self.device)
|
||||
next_states = torch.from_numpy(next_states_array).to(self.device)
|
||||
dones = torch.from_numpy(dones_array).to(self.device)
|
||||
|
||||
# Final validation of tensor shapes
|
||||
if states.shape[0] == 0 or actions.shape[0] == 0:
|
||||
|
Reference in New Issue
Block a user