remove dummy data, improve training , follow architecture

This commit is contained in:
Dobromir Popov
2025-07-04 23:51:35 +03:00
parent e8b9c05148
commit ce8c00a9d1
13 changed files with 435 additions and 838 deletions

View File

@ -451,7 +451,13 @@ class DQNAgent:
state_tensor = state.unsqueeze(0).to(self.device)
# Get Q-values
q_values = self.policy_net(state_tensor)
policy_output = self.policy_net(state_tensor)
if isinstance(policy_output, dict):
q_values = policy_output.get('q_values', policy_output.get('Q_values', list(policy_output.values())[0]))
elif isinstance(policy_output, tuple):
q_values = policy_output[0] # Assume first element is Q-values
else:
q_values = policy_output
action_values = q_values.cpu().data.numpy()[0]
# Calculate confidence scores