cleanup, CNN fixes

This commit is contained in:
Dobromir Popov
2025-07-05 00:12:40 +03:00
parent ce8c00a9d1
commit 5ca7493708
18 changed files with 587 additions and 5181 deletions

View File

@@ -461,6 +461,10 @@ class DQNAgent:
action_values = q_values.cpu().data.numpy()[0]
# Calculate confidence scores
# Ensure q_values has correct shape for softmax
if q_values.dim() == 1:
q_values = q_values.unsqueeze(0)
sell_confidence = torch.softmax(q_values, dim=1)[0, 0].item()
buy_confidence = torch.softmax(q_values, dim=1)[0, 1].item()
@@ -486,6 +490,10 @@ class DQNAgent:
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
q_values = self.policy_net(state_tensor)
# Ensure q_values has correct shape for softmax
if q_values.dim() == 1:
q_values = q_values.unsqueeze(0)
# Convert Q-values to probabilities
action_probs = torch.softmax(q_values, dim=1)
action = q_values.argmax().item()