try to work with ROCKm (AMD) GPUs again

This commit is contained in:
Dobromir Popov
2025-11-12 18:12:47 +02:00
parent 8354aec830
commit 4a5c3fc943
6 changed files with 16 additions and 10 deletions

View File

@@ -1436,7 +1436,8 @@ class DQNAgent:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
with torch.amp.autocast('cuda'):
device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
# Get current Q values and predictions
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)