try to work with ROCKm (AMD) GPUs again

This commit is contained in:
Dobromir Popov
2025-11-12 18:12:47 +02:00
parent 8354aec830
commit 4a5c3fc943
6 changed files with 16 additions and 10 deletions

View File

@@ -323,7 +323,8 @@ class COBRLModelInterface(ModelInterface):
self.optimizer.zero_grad()
if self.scaler:
with torch.amp.autocast('cuda'):
device_type = 'cuda' if next(self.model.parameters()).device.type == 'cuda' else 'cpu'
with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
outputs = self.model(features)
loss = self._calculate_loss(outputs, targets)