LR training wip
@@ -2530,13 +2530,20 @@ class RealTrainingAdapter:
         OPTIMIZATION: Batches are already on GPU and grouped for efficient processing.
         Each mini-batch contains 5 samples for better GPU utilization.
 
-        IMPORTANT: Creates a shallow copy of batch dict to prevent in-place modifications
-        from affecting subsequent epochs. Tensors themselves are shared (not copied).
+        CRITICAL FIX: Clone tensors for each epoch to avoid autograd version conflicts.
+        When the same tensor is used across multiple forward passes, operations like
+        .contiguous() and .view() modify the tensor's version number, breaking backprop.
         """
         for batch in grouped_batches:
-            # Create shallow copy of batch dict to prevent modifications
-            # Tensors are shared (not cloned) for memory efficiency
-            batch_copy = {k: v for k, v in batch.items()}
+            # CRITICAL: Clone all tensors to avoid version conflicts across epochs
+            # This prevents "modified by an inplace operation" errors during backward pass
+            batch_copy = {}
+            for k, v in batch.items():
+                if isinstance(v, torch.Tensor):
+                    # Clone tensor to create independent copy with fresh version number
+                    batch_copy[k] = v.clone()
+                else:
+                    batch_copy[k] = v
             yield batch_copy
 
         total_batches = len(grouped_batches)
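
The error the added clone guards against is PyTorch's autograd version-counter check: if a tensor that was saved for backward is mutated in place before .backward() runs, backprop fails. A minimal, standalone sketch (not from this repository) of the failure and of how a .clone() sidesteps it:

import torch

x = torch.randn(4, 4, requires_grad=True)
w = torch.randn(4, 4)

# Failure case: w is saved for the backward of the multiplication, and the
# in-place add_ bumps its version counter before backward runs.
y = (x * w).sum()
w.add_(1.0)
try:
    y.backward()
except RuntimeError as e:
    print("version conflict:", e)  # "... modified by an inplace operation"

# Fixed case: a cloned copy carries its own version counter, so mutating the
# original tensor no longer invalidates the tensor saved for backward.
x.grad = None
w_copy = w.clone()
y = (x * w_copy).sum()
w.add_(1.0)
y.backward()  # succeeds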
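For context, the changed hunk slots into a batch-yielding generator along these lines. The function name _iter_epoch_batches, the model_step placeholder, and the way grouped_batches is built are assumptions for illustration only, not part of the diff:

import torch
from typing import Any, Dict, Iterator, List

def _iter_epoch_batches(grouped_batches: List[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
    """Yield per-epoch copies of pre-grouped, GPU-resident mini-batches."""
    for batch in grouped_batches:
        # Clone every tensor so each epoch works on an independent copy with a
        # fresh autograd version counter; non-tensor values pass through as-is.
        batch_copy = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch_copy[k] = v.clone()
            else:
                batch_copy[k] = v
        yield batch_copy

# Hypothetical consumption: the generator is re-created every epoch, so each
# epoch trains on cloned tensors and the original batches stay untouched.
# for epoch in range(num_epochs):
#     for batch in _iter_epoch_batches(grouped_batches):
#         loss = model_step(batch)  # model_step is a placeholder
#         loss.backward()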