memory leak fixes
@@ -1772,6 +1772,10 @@ class RealTrainingAdapter:
                 logger.info(f" Loaded checkpoint from epoch {checkpoint.get('epoch', 0)}")
                 logger.info(f" Previous best: Loss={checkpoint.get('loss', 0):.6f}, Accuracy={checkpoint.get('accuracy', 0):.2%}")
+
+                # CRITICAL: Delete checkpoint immediately to free memory
+                del checkpoint
+                gc.collect()
             else:
                 logger.info(" No previous checkpoint found, starting fresh")
         except Exception as e:
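The pattern in this hunk generalizes: copy what you need out of the loaded dict, then drop the only reference before training starts, since the dict returned by `torch.load` keeps every tensor alive for as long as it is reachable. A minimal standalone sketch, assuming a hypothetical checkpoint path and the conventional `model_state_dict`/`optimizer_state_dict` keys (neither is shown in this diff):

    import gc
    import torch

    def restore_latest(model, optimizer, path="checkpoints/best.pt"):  # hypothetical path
        """Load a checkpoint, copy its state into live objects, then free it."""
        checkpoint = torch.load(path, map_location="cpu")  # CPU load avoids a second GPU copy
        model.load_state_dict(checkpoint["model_state_dict"])          # assumed key
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])  # assumed key
        start_epoch = checkpoint.get("epoch", 0)
        del checkpoint   # the dict still owns every tensor; drop it now
        gc.collect()
        return start_epoch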
@@ -1785,28 +1789,35 @@ class RealTrainingAdapter:
 
             import torch
 
-            # MEMORY FIX: Use generator instead of list to avoid accumulating all batches in memory
-            def batch_generator():
-                """Generate batches on-the-fly to avoid memory accumulation"""
-                for i, data in enumerate(training_data):
-                    batch = self._convert_annotation_to_transformer_batch(data)
-                    if batch is not None:
-                        # Repeat based on repetitions parameter
-                        repetitions = data.get('repetitions', 1)
-                        for _ in range(repetitions):
-                            # Yield batch directly without storing
-                            yield batch
-                    else:
-                        logger.warning(f" Failed to convert sample {i+1}")
-
-            # Count total batches for logging
-            total_batches = sum(data.get('repetitions', 1) for data in training_data
-                                if self._convert_annotation_to_transformer_batch(data) is not None)
+            # MEMORY FIX: Pre-convert batches ONCE and cache them
+            # This avoids recreating batches every epoch (major leak!)
+            logger.info(" Pre-converting batches (one-time operation)...")
+            cached_batches = []
+            for i, data in enumerate(training_data):
+                batch = self._convert_annotation_to_transformer_batch(data)
+                if batch is not None:
+                    cached_batches.append(batch)
+                else:
+                    logger.warning(f" Failed to convert sample {i+1}")
+
+            # Clear training_data to free memory
+            training_data.clear()
+            del training_data
+            gc.collect()
+
+            logger.info(f" Converted {len(cached_batches)} batches, cleared source data")
+
+            def batch_generator():
+                """Yield pre-converted batches (no recreation)"""
+                for batch in cached_batches:
+                    yield batch
+
+            total_batches = len(cached_batches)
 
             if total_batches == 0:
                 raise Exception("No valid training batches after conversion")
 
-            logger.info(f" Will generate {total_batches} training batches on-the-fly (memory efficient)")
+            logger.info(f" Ready to train on {total_batches} batches")
 
             # MEMORY FIX: Process batches directly from generator, no grouping needed
             # Batch size of 1 (single sample) to avoid OOM
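This hunk swaps a generator that re-ran `_convert_annotation_to_transformer_batch` on every epoch for a single conversion pass whose results are reused. A self-contained sketch of the same cache-once idea, with hypothetical names (`samples`, `convert`) standing in for the adapter's internals:

    import gc
    from typing import Callable, Dict, Iterator, List, Optional

    import torch

    Batch = Dict[str, torch.Tensor]

    def build_cached_batches(samples: List[dict],
                             convert: Callable[[dict], Optional[Batch]]) -> List[Batch]:
        """Convert every sample exactly once; later epochs only walk the cache."""
        cache = [b for b in (convert(s) for s in samples) if b is not None]
        samples.clear()  # the raw annotations are no longer needed
        gc.collect()
        return cache

    def batch_generator(cache: List[Batch]) -> Iterator[Batch]:
        """Per-epoch iterator over the same tensors, with no reconversion."""
        yield from cache

The trade-off is explicit: peak memory now includes every converted batch up front, in exchange for zero per-epoch allocation churn. Note also that the new `total_batches = len(cached_batches)` no longer multiplies by the old `repetitions` value; the cached path yields each sample once per epoch.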
@@ -1874,18 +1885,22 @@ class RealTrainingAdapter:
                         logger.warning(f" Batch {i + 1} returned None result - skipping")
 
-                    # MEMORY FIX: Explicit cleanup after EVERY batch
-                    del batch
-                    del result
+                    # Don't delete batch (it's from cache, reused)
+                    if 'result' in locals():
+                        del result
 
                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()
 
-                    # After optimizer step, aggressive cleanup
+                    # After optimizer step, aggressive cleanup + memory check
                     if not is_accumulation_step:
                         gc.collect()
                         if torch.cuda.is_available():
                             torch.cuda.synchronize()
+
+                        # CRITICAL: Check memory limit
+                        memory_usage = memory_guard.check_memory(raise_on_limit=True)
 
                 except torch.cuda.OutOfMemoryError as oom_error:
                     logger.error(f" CUDA OOM in batch {i + 1}: {oom_error}")
                     # Aggressive memory cleanup on OOM
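`memory_guard` is defined elsewhere; only the `check_memory(raise_on_limit=True)` call is visible in this diff. Below is a guess at the interface from that single call site, so every name and threshold is an assumption; the RSS reading uses `psutil`:

    import gc
    import psutil  # assumed dependency, not shown in this commit

    class MemoryGuard:
        """Hypothetical sketch: fail fast before the OS OOM-killer acts."""

        def __init__(self, limit_gb: float = 8.0):
            self.limit_bytes = int(limit_gb * 1024 ** 3)
            self._process = psutil.Process()

        def check_memory(self, raise_on_limit: bool = False) -> float:
            """Return current RSS in GB; optionally raise once over the limit."""
            rss = self._process.memory_info().rss
            if rss > self.limit_bytes:
                gc.collect()  # last attempt to shed garbage before giving up
                rss = self._process.memory_info().rss
                if raise_on_limit and rss > self.limit_bytes:
                    raise MemoryError(
                        f"RSS {rss / 1024 ** 3:.2f} GB exceeds "
                        f"{self.limit_bytes / 1024 ** 3:.2f} GB limit"
                    )
            return rss / 1024 ** 3

Raising here turns a slow death by swap into an ordinary exception the training loop can catch and log, mirroring how the `torch.cuda.OutOfMemoryError` branch handles GPU exhaustion.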
@@ -1999,8 +2014,17 @@ class RealTrainingAdapter:
             session.final_loss = session.current_loss
             session.accuracy = avg_accuracy
 
-            # MEMORY FIX: Final cleanup (no batch lists to clean since we used generator)
+            # MEMORY FIX: Final cleanup
             logger.info(" Final memory cleanup...")
 
+            # Clear cached batches
+            for batch in cached_batches:
+                for key in list(batch.keys()):
+                    if isinstance(batch[key], torch.Tensor):
+                        del batch[key]
+            cached_batches.clear()
+            del cached_batches
+
             gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
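The final-cleanup hunk, restated as a reusable helper; the function name is invented for illustration. Deleting each tensor entry breaks the last references so `gc.collect()` can reclaim them, after which `empty_cache()` lets the CUDA caching allocator return the freed blocks to the driver:

    import gc
    from typing import Any, Dict, List

    import torch

    def free_batches(batches: List[Dict[str, Any]]) -> None:
        """Drop every tensor reference held by cached batches, then release GPU blocks."""
        for batch in batches:
            for key in list(batch.keys()):  # list() so we can delete while iterating
                if isinstance(batch[key], torch.Tensor):
                    del batch[key]
        batches.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

`empty_cache()` mainly matters for co-tenant processes on the same GPU; within this process, PyTorch's caching allocator would have reused the freed blocks anyway.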