training fixes
This commit is contained in:
@@ -1238,14 +1238,13 @@ class TradingTransformerTrainer:
                 break

         if needs_transfer:
-            # Move batch to device and DELETE original CPU tensors to prevent memory leak
+            # Move batch to device - iterate over copy of keys to avoid modification during iteration
             batch_gpu = {}
-            for k, v in batch.items():
+            for k in list(batch.keys()):  # Create list copy to avoid modification during iteration
+                v = batch[k]
                 if isinstance(v, torch.Tensor):
                     # Move to device (creates GPU copy)
                     batch_gpu[k] = v.to(self.device, non_blocking=True)
                     # Delete CPU tensor immediately to free memory
                     del batch[k]
                 else:
                     batch_gpu[k] = v
Reference in New Issue
Block a user