try to work with ROCm (AMD) GPUs again
@@ -1229,7 +1229,9 @@ class TradingTransformerTrainer:
                      for k, v in batch.items()}
 
             # Use automatic mixed precision (FP16) for memory efficiency
-            with torch.amp.autocast('cuda', enabled=self.use_amp):
+            # Support both CUDA and ROCm (AMD) devices
+            device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
+            with torch.amp.autocast(device_type, enabled=self.use_amp and device_type != 'cpu'):
                 # Forward pass with multi-timeframe data
                 outputs = self.model(
                     price_data_1s=batch.get('price_data_1s'),
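The change above relies on the fact that ROCm builds of PyTorch report HIP devices as device.type == 'cuda', so one check routes both NVIDIA and AMD GPUs through the CUDA autocast path while AMP stays disabled on CPU. A minimal sketch of that selection in isolation; model, batch, and use_amp are illustrative stand-ins for the trainer's fields, not the project's actual API:

import torch

def forward_with_amp(model, batch, use_amp=True):
    device = next(model.parameters()).device
    # ROCm PyTorch exposes HIP GPUs as device.type == 'cuda', so this single
    # check covers NVIDIA and AMD; autocast is turned off on CPU.
    device_type = 'cuda' if device.type == 'cuda' else 'cpu'
    with torch.amp.autocast(device_type, enabled=use_amp and device_type != 'cpu'):
        return model(**batch)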
@@ -323,7 +323,8 @@ class COBRLModelInterface(ModelInterface):
         self.optimizer.zero_grad()
 
         if self.scaler:
-            with torch.amp.autocast('cuda'):
+            device_type = 'cuda' if next(self.model.parameters()).device.type == 'cuda' else 'cpu'
+            with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
                 outputs = self.model(features)
                 loss = self._calculate_loss(outputs, targets)
 
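In the scaler branch, the same device_type check gates autocast while GradScaler handles the FP16 backward and optimizer step. A hedged sketch of that flow under those assumptions; model, features, targets, and calculate_loss are illustrative names rather than the interface's actual members:

import torch

def train_step(model, optimizer, scaler, features, targets, calculate_loss):
    optimizer.zero_grad()
    device_type = 'cuda' if next(model.parameters()).device.type == 'cuda' else 'cpu'
    if scaler is not None and device_type == 'cuda':
        with torch.amp.autocast(device_type):
            outputs = model(features)
            loss = calculate_loss(outputs, targets)
        scaler.scale(loss).backward()   # scale loss to avoid FP16 gradient underflow
        scaler.step(optimizer)          # unscales gradients, then runs optimizer.step()
        scaler.update()
    else:
        outputs = model(features)
        loss = calculate_loss(outputs, targets)
        loss.backward()
        optimizer.step()
    return loss.item()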
@@ -1436,7 +1436,8 @@ class DQNAgent:
             import warnings
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore", FutureWarning)
-                with torch.amp.autocast('cuda'):
+                device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
+                with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
                     # Get current Q values and predictions
                     current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
                     current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
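The gather call at the end of this hunk selects, for each sample in the batch, the Q-value of the action that was actually taken. A tiny self-contained illustration of that indexing; the shapes and values here are made up:

import torch

q_values = torch.tensor([[0.1, 0.5, 0.2],
                         [0.7, 0.3, 0.9]])   # (batch, n_actions)
actions = torch.tensor([1, 2])                # (batch,) action indices taken
taken_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
print(taken_q)  # tensor([0.5000, 0.9000])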