wip training
@@ -1229,21 +1229,30 @@ class TradingTransformerTrainer:
 if not is_accumulation_step or self.current_accumulation_step == 1:
     self.optimizer.zero_grad(set_to_none=True)

-# Move batch to device and DELETE original CPU tensors to prevent memory leak
-# CRITICAL: Store original keys to delete CPU tensors after moving to GPU
-batch_gpu = {}
-for k, v in batch.items():
-    if isinstance(v, torch.Tensor):
-        # Move to device (creates GPU copy)
-        batch_gpu[k] = v.to(self.device, non_blocking=True)
-        # Delete CPU tensor immediately to free memory
-        del batch[k]
-    else:
-        batch_gpu[k] = v
-
-# Replace batch with GPU version
-batch = batch_gpu
-del batch_gpu
+# OPTIMIZATION: Only move batch to device if not already there
+# Check if first tensor is already on correct device
+needs_transfer = False
+for v in batch.values():
+    if isinstance(v, torch.Tensor):
+        needs_transfer = (v.device != self.device)
+        break
+
+if needs_transfer:
+    # Move batch to device and DELETE original CPU tensors to prevent memory leak
+    batch_gpu = {}
+    for k, v in batch.items():
+        if isinstance(v, torch.Tensor):
+            # Move to device (creates GPU copy)
+            batch_gpu[k] = v.to(self.device, non_blocking=True)
+            # Delete CPU tensor immediately to free memory
+            del batch[k]
+        else:
+            batch_gpu[k] = v
+
+    # Replace batch with GPU version
+    batch = batch_gpu
+    del batch_gpu
+# else: batch is already on GPU, use it directly!

 # Use automatic mixed precision (FP16) for memory efficiency
 # Support both CUDA and ROCm (AMD) devices
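Pulled out of the diff for readability: the new hunk probes the first tensor in the batch to decide whether a host-to-device copy is needed at all, and only then rebuilds the dict on the GPU while dropping the CPU references. Below is a minimal standalone sketch of that pattern, not the trainer's actual method. The helper name move_batch_if_needed is not in the commit, and the sketch iterates over a snapshot of the keys because deleting entries inside "for k, v in batch.items()", as the committed hunk does, would raise a "dictionary changed size during iteration" error at runtime.

import torch

def move_batch_if_needed(batch: dict, device: torch.device) -> dict:
    """Hypothetical helper illustrating the conditional transfer added by this hunk."""
    # Probe the first tensor: if it already lives on `device`, skip the copy entirely.
    needs_transfer = False
    for v in batch.values():
        if isinstance(v, torch.Tensor):
            needs_transfer = (v.device != device)
            break
    if not needs_transfer:
        return batch  # already on the target device, use it directly

    batch_gpu = {}
    for k in list(batch.keys()):  # snapshot the keys so `del batch[k]` is safe
        v = batch[k]
        if isinstance(v, torch.Tensor):
            # non_blocking=True lets the copy overlap compute when the source
            # tensor sits in pinned (page-locked) host memory.
            batch_gpu[k] = v.to(device, non_blocking=True)
            # Drop the CPU reference immediately so the host tensor can be freed.
            del batch[k]
        else:
            batch_gpu[k] = v
    return batch_gpu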
Reference in New Issue
Block a user
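The trailing context lines point at the next step, running the forward and backward pass under automatic mixed precision on both CUDA and ROCm. That code is outside this hunk; the following is only a sketch of the stock torch.amp pattern, with the model, optimizer, criterion, and batch keys as placeholder assumptions rather than anything taken from the commit. ROCm builds of PyTorch report HIP GPUs under the "cuda" device type, which is why a single autocast/GradScaler path can cover both vendors.

import torch
import torch.nn as nn

# Placeholders standing in for the trainer's real model, optimizer, and batch.
device = torch.device("cuda")  # ROCm builds also expose AMD GPUs as "cuda"
model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
batch = {"features": torch.randn(8, 16, device=device),
         "targets": torch.randn(8, 1, device=device)}

scaler = torch.cuda.amp.GradScaler()

# Run the forward pass in FP16; the scaler rescales the loss so that
# small FP16 gradients do not underflow during backward.
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    outputs = model(batch["features"])
    loss = criterion(outputs, batch["targets"])

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)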