This commit is contained in:
Dobromir Popov
2025-07-14 17:56:09 +03:00
parent d53a2ba75d
commit 4a55c5ff03
11 changed files with 1509 additions and 118 deletions

View File

@ -737,6 +737,44 @@ class DQNAgent:
return None
def _safe_cnn_forward(self, network, states):
"""Safely call CNN forward method ensuring we always get 5 return values"""
try:
result = network(states)
if isinstance(result, tuple) and len(result) == 5:
return result
elif isinstance(result, tuple) and len(result) == 1:
# Handle case where only q_values are returned (like in empty tensor case)
q_values = result[0]
batch_size = q_values.size(0)
device = q_values.device
default_extrema = torch.zeros(batch_size, 3, device=device)
default_price = torch.zeros(batch_size, 1, device=device)
default_features = torch.zeros(batch_size, 1024, device=device)
default_advanced = torch.zeros(batch_size, 1, device=device)
return q_values, default_extrema, default_price, default_features, default_advanced
else:
# Fallback: create all default tensors
batch_size = states.size(0)
device = states.device
default_q_values = torch.zeros(batch_size, self.n_actions, device=device)
default_extrema = torch.zeros(batch_size, 3, device=device)
default_price = torch.zeros(batch_size, 1, device=device)
default_features = torch.zeros(batch_size, 1024, device=device)
default_advanced = torch.zeros(batch_size, 1, device=device)
return default_q_values, default_extrema, default_price, default_features, default_advanced
except Exception as e:
logger.error(f"Error in CNN forward pass: {e}")
# Fallback: create all default tensors
batch_size = states.size(0)
device = states.device
default_q_values = torch.zeros(batch_size, self.n_actions, device=device)
default_extrema = torch.zeros(batch_size, 3, device=device)
default_price = torch.zeros(batch_size, 1, device=device)
default_features = torch.zeros(batch_size, 1024, device=device)
default_advanced = torch.zeros(batch_size, 1, device=device)
return default_q_values, default_extrema, default_price, default_features, default_advanced
def replay(self, experiences=None):
"""Train the model using experiences from memory"""
@ -995,17 +1033,17 @@ class DQNAgent:
else:
raise ValueError("Invalid arguments to _replay_standard")
# Get current Q values
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
# Get current Q values using safe wrapper
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Enhanced Double DQN implementation
with torch.no_grad():
if self.use_double_dqn:
# Double DQN: Use policy network to select actions, target network to evaluate
policy_q_values, _, _, _, _ = self.policy_net(next_states)
policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
next_actions = policy_q_values.argmax(1)
target_q_values_all, _, _, _, _ = self.target_net(next_states)
target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
else:
# Standard DQN: Use target network for both selection and evaluation