training fixes

This commit is contained in:
Dobromir Popov
2025-11-17 21:05:06 +02:00
parent 259ee9b14a
commit a8d59a946e
4 changed files with 119 additions and 60 deletions

View File

@@ -1238,14 +1238,13 @@ class TradingTransformerTrainer:
break
if needs_transfer:
# Move batch to device and DELETE original CPU tensors to prevent memory leak
# Move batch to device - iterate over copy of keys to avoid modification during iteration
batch_gpu = {}
for k, v in batch.items():
for k in list(batch.keys()): # Create list copy to avoid modification during iteration
v = batch[k]
if isinstance(v, torch.Tensor):
# Move to device (creates GPU copy)
batch_gpu[k] = v.to(self.device, non_blocking=True)
# Delete CPU tensor immediately to free memory
del batch[k]
else:
batch_gpu[k] = v