diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py
index 053aadd..f364680 100644
--- a/ANNOTATE/core/real_training_adapter.py
+++ b/ANNOTATE/core/real_training_adapter.py
@@ -1772,6 +1772,10 @@ class RealTrainingAdapter:
 
                 logger.info(f" Loaded checkpoint from epoch {checkpoint.get('epoch', 0)}")
                 logger.info(f" Previous best: Loss={checkpoint.get('loss', 0):.6f}, Accuracy={checkpoint.get('accuracy', 0):.2%}")
+
+                # CRITICAL: Delete checkpoint immediately to free memory
+                del checkpoint
+                gc.collect()
             else:
                 logger.info(" No previous checkpoint found, starting fresh")
         except Exception as e:
@@ -1785,28 +1789,35 @@ class RealTrainingAdapter:
 
         import torch
 
-        # MEMORY FIX: Use generator instead of list to avoid accumulating all batches in memory
-        def batch_generator():
-            """Generate batches on-the-fly to avoid memory accumulation"""
-            for i, data in enumerate(training_data):
-                batch = self._convert_annotation_to_transformer_batch(data)
-                if batch is not None:
-                    # Repeat based on repetitions parameter
-                    repetitions = data.get('repetitions', 1)
-                    for _ in range(repetitions):
-                        # Yield batch directly without storing
-                        yield batch
-                else:
-                    logger.warning(f" Failed to convert sample {i+1}")
+        # MEMORY FIX: Pre-convert batches ONCE and cache them
+        # This avoids recreating batches every epoch (major leak!)
+        logger.info(" Pre-converting batches (one-time operation)...")
+        cached_batches = []
+        for i, data in enumerate(training_data):
+            batch = self._convert_annotation_to_transformer_batch(data)
+            if batch is not None:
+                cached_batches.append(batch)
+            else:
+                logger.warning(f" Failed to convert sample {i+1}")
 
-        # Count total batches for logging
-        total_batches = sum(data.get('repetitions', 1) for data in training_data
-                            if self._convert_annotation_to_transformer_batch(data) is not None)
+        # Clear training_data to free memory
+        training_data.clear()
+        del training_data
+        gc.collect()
+
+        logger.info(f" Converted {len(cached_batches)} batches, cleared source data")
+
+        def batch_generator():
+            """Yield pre-converted batches (no recreation)"""
+            for batch in cached_batches:
+                yield batch
+
+        total_batches = len(cached_batches)
 
         if total_batches == 0:
             raise Exception("No valid training batches after conversion")
 
-        logger.info(f" Will generate {total_batches} training batches on-the-fly (memory efficient)")
+        logger.info(f" Ready to train on {total_batches} batches")
 
         # MEMORY FIX: Process batches directly from generator, no grouping needed
         # Batch size of 1 (single sample) to avoid OOM
@@ -1874,18 +1885,22 @@ class RealTrainingAdapter:
                         logger.warning(f" Batch {i + 1} returned None result - skipping")
 
                     # MEMORY FIX: Explicit cleanup after EVERY batch
-                    del batch
-                    del result
+                    # Don't delete batch (it's from cache, reused)
+                    if 'result' in locals():
+                        del result
                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()
 
-                    # After optimizer step, aggressive cleanup
+                    # After optimizer step, aggressive cleanup + memory check
                     if not is_accumulation_step:
                         gc.collect()
                         if torch.cuda.is_available():
                             torch.cuda.synchronize()
 
+                        # CRITICAL: Check memory limit
+                        memory_usage = memory_guard.check_memory(raise_on_limit=True)
+
                 except torch.cuda.OutOfMemoryError as oom_error:
                     logger.error(f" CUDA OOM in batch {i + 1}: {oom_error}")
                     # Aggressive memory cleanup on OOM
@@ -1999,8 +2014,17 @@ class RealTrainingAdapter:
             session.final_loss = session.current_loss
             session.accuracy = avg_accuracy
 
-            # MEMORY FIX: Final cleanup (no batch lists to clean since we used generator)
+            # MEMORY FIX: Final cleanup
             logger.info(" Final memory cleanup...")
+
+            # Clear cached batches
+            for batch in cached_batches:
+                for key in list(batch.keys()):
+                    if isinstance(batch[key], torch.Tensor):
+                        del batch[key]
+            cached_batches.clear()
+            del cached_batches
+
             gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
diff --git a/utils/memory_guard.py b/utils/memory_guard.py
index d85b235..369a10e 100644
--- a/utils/memory_guard.py
+++ b/utils/memory_guard.py
@@ -118,20 +118,24 @@ class MemoryGuard:
 
         if usage['at_limit']:
             self.limit_exceeded_count += 1
-            logger.error(f"🔴 MEMORY LIMIT EXCEEDED: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")
+            logger.error(f"MEMORY LIMIT EXCEEDED: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")
 
             # Aggressive cleanup
             self._aggressive_cleanup()
 
+            # Check again after cleanup
+            usage_after = self.get_memory_usage()
+
             if raise_on_limit:
                 raise MemoryError(
                     f"Memory limit exceeded: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB. "
-                    f"Increase max_memory_gb or reduce batch size."
+                    f"After cleanup: {usage_after['rss_gb']:.2f}GB. "
+                    f"STOP TRAINING - Memory limit enforced!"
                 )
 
         elif usage['at_warning']:
             self.warning_count += 1
-            logger.warning(f"⚠️ Memory warning: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")
+            logger.warning(f"Memory warning: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")
 
             if self.auto_cleanup:
                 self._trigger_cleanup()
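
Reviewer note on the `real_training_adapter.py` hunks: the patch swaps the per-epoch generator for a one-time batch conversion and then reuses the cached batches every epoch. A minimal sketch of that caching pattern is below; `convert_sample` and `train_step` are hypothetical stand-ins for `_convert_annotation_to_transformer_batch` and the trainer's step call, not the adapter's actual API.

```python
import gc

def build_batch_cache(training_data, convert_sample):
    """Convert every sample exactly once; the cached batches are reused each epoch."""
    cached_batches = []
    for i, sample in enumerate(training_data):
        batch = convert_sample(sample)
        if batch is not None:
            cached_batches.append(batch)
        else:
            print(f"Failed to convert sample {i + 1}")
    training_data.clear()  # source samples are no longer needed once batches exist
    gc.collect()
    return cached_batches

def train(cached_batches, train_step, epochs=3):
    for _ in range(epochs):
        for batch in cached_batches:
            train_step(batch)  # do NOT delete batch here: it is shared across epochs
```

This is also why the third hunk drops `del batch`: the cached batch objects are reused, so only the per-step `result` is released.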
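
Reviewer note on the `memory_guard.py` hunk: `check_memory(raise_on_limit=True)` now re-measures usage after cleanup and aborts training with `MemoryError`. A rough standalone equivalent using `psutil` is sketched below, assuming an RSS-based limit; `MAX_GB` and the free function are illustrative only, not the `MemoryGuard` class itself.

```python
import gc
import psutil

MAX_GB = 8.0  # assumed hard limit; the real value comes from MemoryGuard's configuration

def check_memory(raise_on_limit: bool = True) -> dict:
    """Measure process RSS; on breach, clean up, re-measure, and optionally abort."""
    rss_gb = psutil.Process().memory_info().rss / (1024 ** 3)
    usage = {'rss_gb': rss_gb, 'max_gb': MAX_GB, 'at_limit': rss_gb >= MAX_GB}
    if usage['at_limit']:
        gc.collect()  # aggressive cleanup before deciding to abort
        usage['rss_after_cleanup_gb'] = psutil.Process().memory_info().rss / (1024 ** 3)
        if raise_on_limit:
            raise MemoryError(
                f"Memory limit exceeded: {rss_gb:.2f}GB / {MAX_GB:.1f}GB. "
                f"After cleanup: {usage['rss_after_cleanup_gb']:.2f}GB. Stopping training."
            )
    return usage
```

Because the training loop calls this once per optimizer step, a runaway leak fails fast instead of pushing the host into swap.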