fix: batch corruption across epochs from in-place modification of reused batches

Dobromir Popov
2025-12-08 20:00:47 +02:00
parent cc555735e8
commit 81a7f27d2d
4 changed files with 205 additions and 15 deletions

@@ -1355,27 +1355,29 @@ class TradingTransformerTrainer:
                 needs_transfer = (v.device != self.device)
                 break
+        # Always create a new batch_on_device dict to avoid modifying the input batch
+        # This is critical for multi-epoch training where batches are reused
+        batch_on_device = {}
         if needs_transfer:
-            # Move batch to device - iterate over copy of keys to avoid modification during iteration
-            batch_gpu = {}
-            for k in list(batch.keys()):  # Create list copy to avoid modification during iteration
+            # Move batch to device - create new tensors
+            for k in list(batch.keys()):
                 v = batch[k]
                 if isinstance(v, torch.Tensor):
-                    # Move to device (creates GPU copy)
-                    batch_gpu[k] = v.to(self.device, non_blocking=True)
+                    batch_on_device[k] = v.to(self.device, non_blocking=True)
                 else:
-                    batch_gpu[k] = v
-            # Replace batch with GPU version
-            batch = batch_gpu
-            del batch_gpu
-        # else: batch is already on GPU, use it directly!
+                    batch_on_device[k] = v
+        else:
+            # Batch is already on GPU, but still create a copy of the dict
+            # to avoid modifying the original batch dict
+            for k, v in batch.items():
+                batch_on_device[k] = v
         # Ensure all batch tensors are on the same device as the model
         # This is critical to avoid device mismatch errors
         model_device = next(self.model.parameters()).device
-        batch_on_device = {}
-        for k, v in batch.items():
+        for k, v in list(batch_on_device.items()):
             if isinstance(v, torch.Tensor):
                 # Move tensor to model's device if it's not already there
                 if v.device != model_device:
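
Per the new comments in this hunk, the underlying bug was that the training step modified the batch dict it was handed, so a batch cached and replayed on a later epoch no longer held the tensors the data loader produced. The fix routes every transfer into a fresh batch_on_device dict and never writes back into the input. Below is a minimal sketch of that copy-on-transfer pattern; move_batch_to_device is a hypothetical helper name for illustration, not this repo's actual API, and the example batch contents are invented.

import torch

def move_batch_to_device(batch: dict, device: torch.device) -> dict:
    """Return a new dict with tensors on `device`; never mutate `batch`."""
    batch_on_device = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            # .to() creates a new tensor when the device differs, so the
            # caller's original CPU tensor is left intact for reuse.
            batch_on_device[k] = v.to(device, non_blocking=True)
        else:
            # Non-tensor values (labels, metadata) are carried over as-is.
            batch_on_device[k] = v
    return batch_on_device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {"price": torch.randn(8, 16), "symbol": "ETH/USDT"}  # hypothetical batch

for epoch in range(3):
    batch_gpu = move_batch_to_device(batch, device)
    # ... forward/backward pass on batch_gpu ...

print(batch["price"].device)  # still cpu: the cached batch survives reuse

One caveat: when a tensor already sits on the target device, Tensor.to() returns the tensor itself rather than a copy. The pattern still protects the caller because only the new dict is ever written to, which is why the else branch above (like the one in the diff) can copy dict entries without cloning tensors.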