memory leak fixes
@@ -1772,6 +1772,10 @@ class RealTrainingAdapter:
                     logger.info(f" Loaded checkpoint from epoch {checkpoint.get('epoch', 0)}")
                     logger.info(f" Previous best: Loss={checkpoint.get('loss', 0):.6f}, Accuracy={checkpoint.get('accuracy', 0):.2%}")
 
+                    # CRITICAL: Delete checkpoint immediately to free memory
+                    del checkpoint
+                    gc.collect()
+
                 else:
                     logger.info(" No previous checkpoint found, starting fresh")
             except Exception as e:
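Stand-alone, the checkpoint change reduces to load, apply, then drop the dict. A minimal sketch, assuming a standard torch checkpoint dict; the function name and the 'model_state_dict' key are illustrative, not taken from this commit:

    import gc
    import torch

    def load_previous_best(model, checkpoint_path):
        """Restore prior training state, then free the checkpoint dict right away."""
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])  # assumed key name
        epoch = checkpoint.get("epoch", 0)
        best_loss = checkpoint.get("loss", float("inf"))
        # The checkpoint dict holds a full copy of every tensor; drop it as soon
        # as the state is applied so it does not linger for the rest of training.
        del checkpoint
        gc.collect()
        return epoch, best_loss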
@@ -1785,28 +1789,35 @@ class RealTrainingAdapter:
 
             import torch
 
-            # MEMORY FIX: Use generator instead of list to avoid accumulating all batches in memory
-            def batch_generator():
-                """Generate batches on-the-fly to avoid memory accumulation"""
-                for i, data in enumerate(training_data):
-                    batch = self._convert_annotation_to_transformer_batch(data)
-                    if batch is not None:
-                        # Repeat based on repetitions parameter
-                        repetitions = data.get('repetitions', 1)
-                        for _ in range(repetitions):
-                            # Yield batch directly without storing
-                            yield batch
-                    else:
-                        logger.warning(f" Failed to convert sample {i+1}")
+            # MEMORY FIX: Pre-convert batches ONCE and cache them
+            # This avoids recreating batches every epoch (major leak!)
+            logger.info(" Pre-converting batches (one-time operation)...")
+            cached_batches = []
+            for i, data in enumerate(training_data):
+                batch = self._convert_annotation_to_transformer_batch(data)
+                if batch is not None:
+                    cached_batches.append(batch)
+                else:
+                    logger.warning(f" Failed to convert sample {i+1}")
 
-            # Count total batches for logging
-            total_batches = sum(data.get('repetitions', 1) for data in training_data
-                                if self._convert_annotation_to_transformer_batch(data) is not None)
+            # Clear training_data to free memory
+            training_data.clear()
+            del training_data
+            gc.collect()
+
+            logger.info(f" Converted {len(cached_batches)} batches, cleared source data")
+
+            def batch_generator():
+                """Yield pre-converted batches (no recreation)"""
+                for batch in cached_batches:
+                    yield batch
+
+            total_batches = len(cached_batches)
 
             if total_batches == 0:
                 raise Exception("No valid training batches after conversion")
 
-            logger.info(f" Will generate {total_batches} training batches on-the-fly (memory efficient)")
+            logger.info(f" Ready to train on {total_batches} batches")
 
             # MEMORY FIX: Process batches directly from generator, no grouping needed
             # Batch size of 1 (single sample) to avoid OOM
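The caching change above boils down to convert-once, then re-yield the same batches every epoch. A minimal sketch of that pattern; `precache_batches` and `convert` are hypothetical stand-ins for the adapter's `_convert_annotation_to_transformer_batch`:

    import gc
    from typing import Any, Callable, Dict, Iterator, List, Optional

    def precache_batches(samples: List[dict],
                         convert: Callable[[dict], Optional[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """Convert every sample exactly once; failed conversions are skipped."""
        cached = []
        for sample in samples:
            batch = convert(sample)
            if batch is not None:
                cached.append(batch)
        # Release the raw annotation data so only the converted batches stay resident.
        samples.clear()
        gc.collect()
        return cached

    def batch_generator(cached: List[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
        """Re-yield the same pre-converted batches every epoch instead of rebuilding them."""
        yield from cached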
@@ -1874,18 +1885,22 @@ class RealTrainingAdapter:
                         logger.warning(f" Batch {i + 1} returned None result - skipping")
 
                     # MEMORY FIX: Explicit cleanup after EVERY batch
-                    del batch
-                    del result
+                    # Don't delete batch (it's from cache, reused)
+                    if 'result' in locals():
+                        del result
 
                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()
 
-                    # After optimizer step, aggressive cleanup
+                    # After optimizer step, aggressive cleanup + memory check
                     if not is_accumulation_step:
                         gc.collect()
                         if torch.cuda.is_available():
                             torch.cuda.synchronize()
 
+                        # CRITICAL: Check memory limit
+                        memory_usage = memory_guard.check_memory(raise_on_limit=True)
+
                 except torch.cuda.OutOfMemoryError as oom_error:
                     logger.error(f" CUDA OOM in batch {i + 1}: {oom_error}")
                     # Aggressive memory cleanup on OOM
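Per batch, the cleanup amounts to a step like the sketch below. This is an assumption-heavy illustration, not the adapter's real loop: the model is assumed to take a dict of tensors and return a scalar loss, and `memory_guard` stands in for the guard used above:

    import gc
    import torch

    def train_step(model, optimizer, batch, memory_guard=None):
        """One optimizer step with explicit post-step cleanup (dict-of-tensors batch assumed)."""
        optimizer.zero_grad(set_to_none=True)
        loss = model(**batch)            # assumption: the model returns a scalar loss
        loss.backward()
        optimizer.step()
        loss_value = loss.item()         # copy to a Python float before deleting the tensor

        # Drop the reference that keeps the loss tensor alive, then return cached CUDA blocks.
        del loss
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Mirrors memory_guard.check_memory(raise_on_limit=True) in the hunk above.
        if memory_guard is not None:
            memory_guard.check_memory(raise_on_limit=True)
        return loss_value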
@@ -1999,8 +2014,17 @@ class RealTrainingAdapter:
             session.final_loss = session.current_loss
             session.accuracy = avg_accuracy
 
-            # MEMORY FIX: Final cleanup (no batch lists to clean since we used generator)
+            # MEMORY FIX: Final cleanup
             logger.info(" Final memory cleanup...")
+
+            # Clear cached batches
+            for batch in cached_batches:
+                for key in list(batch.keys()):
+                    if isinstance(batch[key], torch.Tensor):
+                        del batch[key]
+            cached_batches.clear()
+            del cached_batches
+
             gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
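The teardown of the cache is the same loop pulled into a helper; a minimal sketch that mirrors the hunk above:

    import gc
    import torch

    def release_cached_batches(cached_batches):
        """Drop every tensor held in the cache so the allocator can return the memory."""
        for batch in cached_batches:
            for key in list(batch.keys()):
                if isinstance(batch[key], torch.Tensor):
                    del batch[key]
        cached_batches.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()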
@@ -118,20 +118,24 @@ class MemoryGuard:
 
         if usage['at_limit']:
             self.limit_exceeded_count += 1
-            logger.error(f"🔴 MEMORY LIMIT EXCEEDED: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")
+            logger.error(f"MEMORY LIMIT EXCEEDED: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")
 
             # Aggressive cleanup
             self._aggressive_cleanup()
 
+            # Check again after cleanup
+            usage_after = self.get_memory_usage()
+
             if raise_on_limit:
                 raise MemoryError(
                     f"Memory limit exceeded: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB. "
-                    f"Increase max_memory_gb or reduce batch size."
+                    f"After cleanup: {usage_after['rss_gb']:.2f}GB. "
+                    f"STOP TRAINING - Memory limit enforced!"
                 )
 
         elif usage['at_warning']:
             self.warning_count += 1
-            logger.warning(f"⚠️ Memory warning: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")
+            logger.warning(f"Memory warning: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")
 
             if self.auto_cleanup:
                 self._trigger_cleanup()
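For context, a stripped-down guard with the same check-then-raise behaviour could look like the sketch below. It is psutil-based and purely illustrative: the real MemoryGuard's attributes, thresholds, and `_aggressive_cleanup()` are not shown in this commit, so a plain `gc.collect()` stands in:

    import gc
    import psutil

    class SimpleMemoryGuard:
        """Minimal stand-in for the MemoryGuard above; attribute names are illustrative."""

        def __init__(self, max_gb: float = 16.0, warning_ratio: float = 0.85):
            self.max_gb = max_gb
            self.warning_gb = max_gb * warning_ratio
            self.process = psutil.Process()

        def get_memory_usage(self) -> dict:
            rss_gb = self.process.memory_info().rss / 1024 ** 3
            return {
                "rss_gb": rss_gb,
                "max_gb": self.max_gb,
                "usage_percent": 100.0 * rss_gb / self.max_gb,
                "at_limit": rss_gb >= self.max_gb,
                "at_warning": rss_gb >= self.warning_gb,
            }

        def check_memory(self, raise_on_limit: bool = False) -> dict:
            usage = self.get_memory_usage()
            if usage["at_limit"]:
                gc.collect()  # stand-in for the real _aggressive_cleanup()
                usage_after = self.get_memory_usage()
                if raise_on_limit:
                    raise MemoryError(
                        f"Memory limit exceeded: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB. "
                        f"After cleanup: {usage_after['rss_gb']:.2f}GB."
                    )
            return usage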