deribit
This commit is contained in:
@ -737,6 +737,44 @@ class DQNAgent:
|
||||
|
||||
return None
|
||||
|
||||
def _safe_cnn_forward(self, network, states):
    """Call ``network(states)`` and always return a 5-tuple of tensors.

    The tuple is ``(q_values, extrema, price, features, advanced)``:

    * A 5-tuple result is returned unchanged.
    * A 1-tuple, or a bare tensor, is treated as Q-values only; the other
      four slots are filled with zero tensors on the Q-values' device,
      sized from the Q-values' first dimension.
    * Any other result, or any exception raised during the forward pass,
      yields five zero tensors sized from ``states``. The exception is
      logged and not re-raised.

    Args:
        network: Callable invoked as ``network(states)``.
        states: Input tensor. ``states.size(0)`` is used as the batch size
            and ``states.device`` as the device for full fallbacks.

    Returns:
        Tuple of five tensors. Zero placeholders have shapes
        ``(batch, 3)``, ``(batch, 1)``, ``(batch, 1024)`` and
        ``(batch, 1)``; placeholder Q-values are
        ``(batch, self.n_actions)``.
    """

    def _with_defaults(batch_size, device, q_values=None):
        # Single place that builds the placeholder outputs, so the three
        # fallback paths cannot drift apart.
        # NOTE(review): these zeros are constants with no autograd history;
        # presumably a caller that backpropagates through a fully defaulted
        # result gets no gradient -- confirm against the training loop.
        if q_values is None:
            q_values = torch.zeros(batch_size, self.n_actions, device=device)
        return (
            q_values,
            torch.zeros(batch_size, 3, device=device),
            torch.zeros(batch_size, 1, device=device),
            torch.zeros(batch_size, 1024, device=device),
            torch.zeros(batch_size, 1, device=device),
        )

    try:
        result = network(states)

        if isinstance(result, tuple) and len(result) == 5:
            return result

        # Only Q-values came back: either wrapped in a 1-tuple or as a bare
        # tensor. Keep the real Q-values instead of zeroing them out.
        if isinstance(result, tuple) and len(result) == 1:
            result = result[0]
        if isinstance(result, torch.Tensor):
            return _with_defaults(result.size(0), result.device, result)

        # Unrecognised output shape: fall back to all-default tensors.
        return _with_defaults(states.size(0), states.device)
    except Exception as e:
        # Deliberate best-effort: log and hand back defaults so the caller's
        # five-way unpacking still works.
        logger.error(f"Error in CNN forward pass: {e}")
        return _with_defaults(states.size(0), states.device)
|
||||
|
||||
def replay(self, experiences=None):
|
||||
"""Train the model using experiences from memory"""
|
||||
|
||||
@ -995,17 +1033,17 @@ class DQNAgent:
|
||||
else:
|
||||
raise ValueError("Invalid arguments to _replay_standard")
|
||||
|
||||
# Get current Q values
|
||||
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
|
||||
# Get current Q values using safe wrapper
|
||||
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Enhanced Double DQN implementation
|
||||
with torch.no_grad():
|
||||
if self.use_double_dqn:
|
||||
# Double DQN: Use policy network to select actions, target network to evaluate
|
||||
policy_q_values, _, _, _, _ = self.policy_net(next_states)
|
||||
policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
|
||||
next_actions = policy_q_values.argmax(1)
|
||||
target_q_values_all, _, _, _, _ = self.target_net(next_states)
|
||||
target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
|
||||
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
|
||||
else:
|
||||
# Standard DQN: Use target network for both selection and evaluation
|
||||
|
Reference in New Issue
Block a user