fixed model training
This commit is contained in:
@@ -1436,7 +1436,7 @@ class DQNAgent:
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", FutureWarning)
|
||||
with torch.cuda.amp.autocast():
|
||||
with torch.amp.autocast('cuda'):
|
||||
# Get current Q values and predictions
|
||||
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
Reference in New Issue
Block a user