fixed model training

This commit is contained in:
Dobromir Popov
2025-11-12 18:11:19 +02:00
parent 7cb4201bc0
commit 8354aec830
6 changed files with 9 additions and 9 deletions

View File

@@ -1229,7 +1229,7 @@ class TradingTransformerTrainer:
for k, v in batch.items()}
# Use automatic mixed precision (FP16) for memory efficiency
with torch.cuda.amp.autocast(enabled=self.use_amp):
with torch.amp.autocast('cuda', enabled=self.use_amp):
# Forward pass with multi-timeframe data
outputs = self.model(
price_data_1s=batch.get('price_data_1s'),

View File

@@ -323,7 +323,7 @@ class COBRLModelInterface(ModelInterface):
self.optimizer.zero_grad()
if self.scaler:
with torch.cuda.amp.autocast():
with torch.amp.autocast('cuda'):
outputs = self.model(features)
loss = self._calculate_loss(outputs, targets)

View File

@@ -1436,7 +1436,7 @@ class DQNAgent:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
with torch.cuda.amp.autocast():
with torch.amp.autocast('cuda'):
# Get current Q values and predictions
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)