fix: Main Problem: Batch Corruption Across Epochs
@@ -1355,27 +1355,29 @@ class TradingTransformerTrainer:
                 needs_transfer = (v.device != self.device)
                 break
 
+        # Always create a new batch_on_device dict to avoid modifying the input batch
+        # This is critical for multi-epoch training where batches are reused
+        batch_on_device = {}
+
         if needs_transfer:
-            # Move batch to device - iterate over copy of keys to avoid modification during iteration
-            batch_gpu = {}
-            for k in list(batch.keys()):  # Create list copy to avoid modification during iteration
+            # Move batch to device - create new tensors
+            for k in list(batch.keys()):
                 v = batch[k]
                 if isinstance(v, torch.Tensor):
                     # Move to device (creates GPU copy)
-                    batch_gpu[k] = v.to(self.device, non_blocking=True)
+                    batch_on_device[k] = v.to(self.device, non_blocking=True)
                 else:
-                    batch_gpu[k] = v
-
-            # Replace batch with GPU version
-            batch = batch_gpu
-            del batch_gpu
-        # else: batch is already on GPU, use it directly!
+                    batch_on_device[k] = v
+        else:
+            # Batch is already on GPU, but still create a copy of the dict
+            # to avoid modifying the original batch dict
+            for k, v in batch.items():
+                batch_on_device[k] = v
 
         # Ensure all batch tensors are on the same device as the model
         # This is critical to avoid device mismatch errors
         model_device = next(self.model.parameters()).device
-        batch_on_device = {}
-        for k, v in batch.items():
+        for k, v in list(batch_on_device.items()):
             if isinstance(v, torch.Tensor):
                 # Move tensor to model's device if it's not already there
                 if v.device != model_device:
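For a reader outside the diff, the pattern the new code enforces is: never write device-side tensors back into a batch dict that a caller may cache and replay on later epochs; always build a fresh `batch_on_device` dict per step. Below is a minimal sketch of why in-place writes corrupt reused batches, assuming cached CPU batches and a hypothetical scaling step standing in for whatever per-step preprocessing the trainer does; the helper names are illustrative, not the repository's API.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cached_batches = [{"features": torch.randn(4, 8),
                   "labels": torch.zeros(4, dtype=torch.long)}]

def prepare_in_place(batch, scale=0.5):
    # Buggy pattern: the moved/transformed tensors are written back into the
    # caller's dict, so the cached batch silently changes for later epochs.
    for k in list(batch.keys()):
        v = batch[k]
        if isinstance(v, torch.Tensor) and v.is_floating_point():
            batch[k] = v.to(device, non_blocking=True) * scale
    return batch

def prepare_copy(batch, scale=0.5):
    # Pattern the diff adopts: build a new dict each step; the input batch keeps
    # its original CPU tensors and can safely be reused across epochs.
    batch_on_device = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor) and v.is_floating_point():
            batch_on_device[k] = v.to(device, non_blocking=True) * scale
        else:
            batch_on_device[k] = v
    return batch_on_device

original = cached_batches[0]["features"].clone()
for epoch in range(2):
    for batch in cached_batches:
        _ = prepare_copy(batch)  # swap in prepare_in_place to watch the cache drift
assert torch.equal(cached_batches[0]["features"], original)  # holds only with prepare_copy

The second half of the hunk follows from the same change: once the fresh dict is the working copy, the later device check against `model_device` has to iterate `batch_on_device` rather than the original `batch`.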