try to work with ROCKm (AMD) GPUs again

This commit is contained in:
Dobromir Popov
2025-11-12 18:12:47 +02:00
parent 8354aec830
commit 4a5c3fc943
6 changed files with 16 additions and 10 deletions

View File

@@ -1229,7 +1229,9 @@ class TradingTransformerTrainer:
for k, v in batch.items()}
# Use automatic mixed precision (FP16) for memory efficiency
with torch.amp.autocast('cuda', enabled=self.use_amp):
# Support both CUDA and ROCm (AMD) devices
device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
with torch.amp.autocast(device_type, enabled=self.use_amp and device_type != 'cpu'):
# Forward pass with multi-timeframe data
outputs = self.model(
price_data_1s=batch.get('price_data_1s'),