reduce memory usage

Dobromir Popov
2025-11-13 17:34:31 +02:00
parent b0b24f36b2
commit 70e8ede8d3
3 changed files with 163 additions and 24 deletions


@@ -1808,9 +1808,26 @@ class RealTrainingAdapter:
logger.info(f" Converted {len(cached_batches)} batches, cleared source data")
def batch_generator():
"""Yield pre-converted batches (no recreation)"""
"""
Yield pre-converted batches with proper memory management
CRITICAL: Each batch must be cloned and detached to prevent:
1. GPU memory accumulation across epochs
2. Computation graph retention
3. Autograd version-counter errors when cached tensors are later modified in place
"""
for batch in cached_batches:
yield batch
# Clone and detach each tensor in the batch
# This creates a fresh copy without gradient history
cloned_batch = {}
for key, value in batch.items():
if isinstance(value, torch.Tensor):
# detach() removes from computation graph
# clone() creates new memory (prevents aliasing)
cloned_batch[key] = value.detach().clone()
else:
cloned_batch[key] = value
yield cloned_batch
total_batches = len(cached_batches)
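
The detach-and-clone pattern used in batch_generator() can be checked in isolation. The snippet below is a standalone illustration of what the two calls guarantee (tensor names are made up for the example), not code from this repository:

import torch

# Stand-in for a cached batch tensor that is still attached to a computation graph
src = torch.randn(4, 8, requires_grad=True)
feat = src * 2  # feat carries a grad_fn, i.e. a reference back into the graph

copy = feat.detach().clone()

# detach(): the copy is cut out of the graph, so yielding it cannot retain the graph
assert copy.grad_fn is None and not copy.requires_grad

# clone(): the copy owns fresh storage, so training cannot alias the cached original
assert copy.data_ptr() != feat.data_ptr()
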
@@ -1823,9 +1840,18 @@ class RealTrainingAdapter:
# Batch size of 1 (single sample) to avoid OOM
logger.info(f" Processing batches individually (batch_size=1) for memory efficiency")
# MEMORY FIX: Train using generator with aggressive memory cleanup
# Reduced accumulation steps from 5 to 2 for less memory usage
accumulation_steps = 2 # Accumulate 2 batches before optimizer step
# MEMORY OPTIMIZATION: Configure gradient accumulation
# Process samples one at a time, accumulate gradients over multiple samples
# This reduces peak memory by ~50% compared to batching
accumulation_steps = max(2, min(5, total_batches)) # 2-5 steps based on data size
logger.info(f" Gradient accumulation: {accumulation_steps} steps")
logger.info(f" Effective batch size: {accumulation_steps} (processed as {accumulation_steps} × batch_size=1)")
# Configure trainer for gradient accumulation
if hasattr(trainer, 'set_gradient_accumulation_steps'):
trainer.set_gradient_accumulation_steps(accumulation_steps)
logger.info(f" Trainer configured for automatic gradient accumulation")
import gc
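
The calls above (set_gradient_accumulation_steps, reset_gradient_accumulation, and train_step without an accumulate_gradients flag) assume the trainer now steps the optimizer internally every N samples. A minimal sketch of that contract, assuming a plain PyTorch model and AdamW; this is illustrative only, not the repository's actual trainer:

import torch

class AccumulatingTrainer:
    """Sketch: steps the optimizer once every `accumulation_steps` train_step calls."""

    def __init__(self, model: torch.nn.Module, lr: float = 1e-4):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        self.accumulation_steps = 1
        self._step_count = 0

    def set_gradient_accumulation_steps(self, steps: int):
        self.accumulation_steps = max(1, steps)

    def reset_gradient_accumulation(self):
        # Called at the start of each epoch so no stale gradients carry over
        self._step_count = 0
        self.optimizer.zero_grad(set_to_none=True)

    def train_step(self, batch: dict) -> dict:
        # "inputs"/"targets" keys and MSE loss are assumptions for the sketch
        outputs = self.model(batch["inputs"])
        loss = torch.nn.functional.mse_loss(outputs, batch["targets"])
        # Scale so the accumulated gradient matches a full-batch update
        (loss / self.accumulation_steps).backward()
        self._step_count += 1
        if self._step_count % self.accumulation_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)
        return {"loss": loss.detach().item()}
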
@@ -1840,14 +1866,16 @@ class RealTrainingAdapter:
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Reset gradient accumulation counter at start of epoch
if hasattr(trainer, 'reset_gradient_accumulation'):
trainer.reset_gradient_accumulation()
# Generate batches fresh for each epoch
for i, batch in enumerate(batch_generator()):
try:
# Determine if this is an accumulation step or optimizer step
is_accumulation_step = (i + 1) % accumulation_steps != 0
# Call the trainer's train_step method
result = trainer.train_step(batch, accumulate_gradients=is_accumulation_step)
# Trainer now handles gradient accumulation automatically
result = trainer.train_step(batch)
if result is not None:
# MEMORY FIX: Detach all tensor values to break computation graph
@@ -1885,18 +1913,29 @@ class RealTrainingAdapter:
logger.warning(f" Batch {i + 1} returned None result - skipping")
# MEMORY FIX: Explicit cleanup after EVERY batch
# Don't delete batch (it's from cache, reused)
# Delete result dict to free memory
if 'result' in locals():
del result
# Delete the cloned batch (it's a fresh copy, safe to delete)
if 'batch' in locals():
for key in list(batch.keys()):
if isinstance(batch[key], torch.Tensor):
del batch[key]
del batch
# Clear CUDA cache after every batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
# After optimizer step, aggressive cleanup + memory check
if not is_accumulation_step:
# After optimizer step, aggressive cleanup
# Check if this was an optimizer step (not accumulation)
is_optimizer_step = ((i + 1) % accumulation_steps == 0)
if is_optimizer_step:
gc.collect()
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
# CRITICAL: Check memory limit
memory_usage = memory_guard.check_memory(raise_on_limit=True)
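
memory_guard itself is project code not shown in this diff; a rough sketch of the check_memory(raise_on_limit=True) contract it appears to implement, using psutil and torch.cuda (class name, limits, and return shape are assumptions, not the repository's values):

import psutil
import torch

class MemoryGuard:
    """Sketch: aborts training before the process exhausts host RAM or GPU VRAM."""

    def __init__(self, max_ram_gb: float = 16.0, max_vram_gb: float = 10.0):
        self.max_ram_bytes = max_ram_gb * 1024 ** 3
        self.max_vram_bytes = max_vram_gb * 1024 ** 3

    def check_memory(self, raise_on_limit: bool = False) -> dict:
        ram_used = psutil.Process().memory_info().rss
        vram_used = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
        usage = {"ram_bytes": ram_used, "vram_bytes": vram_used}
        if raise_on_limit and (ram_used > self.max_ram_bytes or vram_used > self.max_vram_bytes):
            raise MemoryError(f"Training memory limit exceeded: {usage}")
        return usage
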