fixed model training
This commit is contained in:
@@ -1229,7 +1229,7 @@ class TradingTransformerTrainer:
|
||||
for k, v in batch.items()}
|
||||
|
||||
# Use automatic mixed precision (FP16) for memory efficiency
|
||||
with torch.cuda.amp.autocast(enabled=self.use_amp):
|
||||
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
||||
# Forward pass with multi-timeframe data
|
||||
outputs = self.model(
|
||||
price_data_1s=batch.get('price_data_1s'),
|
||||
|
||||
Reference in New Issue
Block a user