memory leak fixes

Author: Dobromir Popov
Date: 2025-11-13 16:05:15 +02:00
Parent: 13b6fafaf8
Commit: b0b24f36b2
2 changed files with 52 additions and 24 deletions


@@ -1772,6 +1772,10 @@ class RealTrainingAdapter:
logger.info(f" Loaded checkpoint from epoch {checkpoint.get('epoch', 0)}")
logger.info(f" Previous best: Loss={checkpoint.get('loss', 0):.6f}, Accuracy={checkpoint.get('accuracy', 0):.2%}")
# CRITICAL: Delete checkpoint immediately to free memory
del checkpoint
gc.collect()
else:
logger.info(" No previous checkpoint found, starting fresh")
except Exception as e:
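The hunk above drops the loaded checkpoint dict as soon as its metadata has been read, instead of letting it live for the rest of training setup. A minimal sketch of the same pattern, assuming a torch checkpoint that stores a model_state_dict key (the function name and path are illustrative, not from this repo):

import gc
import torch

def restore_from_checkpoint(model: torch.nn.Module, path: str) -> int:
    """Load a checkpoint, keep only what is needed, and free the dict at once."""
    checkpoint = torch.load(path, map_location="cpu")  # stay off the GPU while copying
    model.load_state_dict(checkpoint["model_state_dict"])
    start_epoch = checkpoint.get("epoch", 0)

    # The dict still references every saved tensor; drop it explicitly so the
    # memory is reclaimable now rather than at the end of the enclosing scope.
    del checkpoint
    gc.collect()
    return start_epoch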
@@ -1785,28 +1789,35 @@ class RealTrainingAdapter:
             import torch
-            # MEMORY FIX: Use generator instead of list to avoid accumulating all batches in memory
-            def batch_generator():
-                """Generate batches on-the-fly to avoid memory accumulation"""
-                for i, data in enumerate(training_data):
-                    batch = self._convert_annotation_to_transformer_batch(data)
-                    if batch is not None:
-                        # Repeat based on repetitions parameter
-                        repetitions = data.get('repetitions', 1)
-                        for _ in range(repetitions):
-                            # Yield batch directly without storing
-                            yield batch
-                    else:
-                        logger.warning(f" Failed to convert sample {i+1}")
+            # MEMORY FIX: Pre-convert batches ONCE and cache them
+            # This avoids recreating batches every epoch (major leak!)
+            logger.info(" Pre-converting batches (one-time operation)...")
+            cached_batches = []
+            for i, data in enumerate(training_data):
+                batch = self._convert_annotation_to_transformer_batch(data)
+                if batch is not None:
+                    cached_batches.append(batch)
+                else:
+                    logger.warning(f" Failed to convert sample {i+1}")
-            # Count total batches for logging
-            total_batches = sum(data.get('repetitions', 1) for data in training_data
-                                if self._convert_annotation_to_transformer_batch(data) is not None)
+            # Clear training_data to free memory
+            training_data.clear()
+            del training_data
+            gc.collect()
+            logger.info(f" Converted {len(cached_batches)} batches, cleared source data")
+            def batch_generator():
+                """Yield pre-converted batches (no recreation)"""
+                for batch in cached_batches:
+                    yield batch
+            total_batches = len(cached_batches)
             if total_batches == 0:
                 raise Exception("No valid training batches after conversion")
-            logger.info(f" Will generate {total_batches} training batches on-the-fly (memory efficient)")
+            logger.info(f" Ready to train on {total_batches} batches")
             # MEMORY FIX: Process batches directly from generator, no grouping needed
             # Batch size of 1 (single sample) to avoid OOM
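This hunk trades the per-epoch generator for a one-time conversion: batches are built once, the raw annotation list is released, and every epoch reuses the cached tensors. A compact sketch of that cache-once pattern, assuming an adapter object exposing the same _convert_annotation_to_transformer_batch method (the helper names here are illustrative, not from the commit):

import gc

def build_batch_cache(adapter, training_data: list) -> list:
    """Convert annotations to tensor batches once, then drop the source data."""
    cached = []
    for sample in training_data:
        batch = adapter._convert_annotation_to_transformer_batch(sample)
        if batch is not None:
            cached.append(batch)
    # The annotations are no longer needed; clearing the list leaves the cached
    # batches as the only copy of each sample in memory.
    training_data.clear()
    gc.collect()
    return cached

def batch_generator(cached: list):
    """Keep a generator interface so the training loop itself is unchanged."""
    yield from cached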
@@ -1874,18 +1885,22 @@ class RealTrainingAdapter:
logger.warning(f" Batch {i + 1} returned None result - skipping")
# MEMORY FIX: Explicit cleanup after EVERY batch
del batch
del result
# Don't delete batch (it's from cache, reused)
if 'result' in locals():
del result
if torch.cuda.is_available():
torch.cuda.empty_cache()
# After optimizer step, aggressive cleanup
# After optimizer step, aggressive cleanup + memory check
if not is_accumulation_step:
gc.collect()
if torch.cuda.is_available():
torch.cuda.synchronize()
# CRITICAL: Check memory limit
memory_usage = memory_guard.check_memory(raise_on_limit=True)
except torch.cuda.OutOfMemoryError as oom_error:
logger.error(f" CUDA OOM in batch {i + 1}: {oom_error}")
# Aggressive memory cleanup on OOM
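The new per-step check calls memory_guard.check_memory(raise_on_limit=True), but the guard itself is not part of this diff. One plausible shape for such a guard, sketched with psutil purely as an assumption about what the real class might look like:

import gc
import psutil

class MemoryGuard:
    """Hypothetical stand-in for the memory_guard object used above."""

    def __init__(self, limit_mb: float = 8192.0):
        self.limit_mb = limit_mb
        self._process = psutil.Process()

    def check_memory(self, raise_on_limit: bool = False) -> float:
        """Return current RSS in MB and optionally abort when over the limit."""
        used_mb = self._process.memory_info().rss / (1024 * 1024)
        if used_mb > self.limit_mb:
            gc.collect()  # give the collector one chance to bring usage back down
            used_mb = self._process.memory_info().rss / (1024 * 1024)
            if used_mb > self.limit_mb and raise_on_limit:
                raise MemoryError(
                    f"Process RSS {used_mb:.0f} MB exceeds limit {self.limit_mb:.0f} MB"
                )
        return used_mb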
@@ -1999,8 +2014,17 @@ class RealTrainingAdapter:
             session.final_loss = session.current_loss
             session.accuracy = avg_accuracy
-            # MEMORY FIX: Final cleanup (no batch lists to clean since we used generator)
+            # MEMORY FIX: Final cleanup
             logger.info(" Final memory cleanup...")
+            # Clear cached batches
+            for batch in cached_batches:
+                for key in list(batch.keys()):
+                    if isinstance(batch[key], torch.Tensor):
+                        del batch[key]
+            cached_batches.clear()
+            del cached_batches
             gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
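A quick way to confirm the final cleanup actually releases GPU memory is to compare the CUDA allocator counters before and after it; log_cuda_memory below is a small hypothetical helper, not part of the commit:

import torch

def log_cuda_memory(tag: str) -> None:
    """Print allocated vs. reserved CUDA memory so leaks show up between runs."""
    if not torch.cuda.is_available():
        return
    allocated_mb = torch.cuda.memory_allocated() / (1024 ** 2)
    reserved_mb = torch.cuda.memory_reserved() / (1024 ** 2)
    print(f"[{tag}] allocated={allocated_mb:.1f} MB, reserved={reserved_mb:.1f} MB")

# Typical use around the cleanup block:
#   log_cuda_memory("before final cleanup")
#   ...delete cached batch tensors, gc.collect(), torch.cuda.empty_cache()...
#   log_cuda_memory("after final cleanup")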