try to work with ROCKm (AMD) GPUs again
This commit is contained in:
@@ -323,7 +323,8 @@ class COBRLModelInterface(ModelInterface):
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
if self.scaler:
|
||||
with torch.amp.autocast('cuda'):
|
||||
device_type = 'cuda' if next(self.model.parameters()).device.type == 'cuda' else 'cpu'
|
||||
with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
|
||||
outputs = self.model(features)
|
||||
loss = self._calculate_loss(outputs, targets)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user