""" Memory Guard - Prevents Python process from exceeding memory limit Monitors memory usage and triggers aggressive cleanup when approaching limit. Can also enforce hard limit by raising exception. """ import psutil import gc import logging import threading import time from typing import Optional, Callable logger = logging.getLogger(__name__) class MemoryGuard: """ Memory usage monitor and enforcer Features: - Monitors RAM usage in background - Triggers cleanup at warning threshold - Raises exception at hard limit - Logs memory statistics """ def __init__(self, max_memory_gb: float = 50.0, warning_threshold: float = 0.85, check_interval: float = 2.0, auto_cleanup: bool = True): """ Initialize Memory Guard Args: max_memory_gb: Maximum allowed memory in GB (default: 50GB) warning_threshold: Trigger cleanup at this fraction of max (default: 0.85 = 42.5GB) check_interval: Check memory every N seconds (default: 5.0) auto_cleanup: Automatically run gc.collect() at warning threshold """ self.max_memory_bytes = int(max_memory_gb * 1024 * 1024 * 1024) self.warning_bytes = int(self.max_memory_bytes * warning_threshold) self.check_interval = check_interval self.auto_cleanup = auto_cleanup self.process = psutil.Process() self.running = False self.monitor_thread = None self.cleanup_callbacks = [] self.warning_count = 0 self.limit_exceeded_count = 0 logger.info(f"MemoryGuard initialized: max={max_memory_gb:.1f}GB, warning={max_memory_gb * warning_threshold:.1f}GB") def start(self): """Start memory monitoring in background thread""" if self.running: logger.warning("MemoryGuard already running") return self.running = True self.monitor_thread = threading.Thread( target=self._monitor_loop, daemon=True ) self.monitor_thread.start() logger.info("MemoryGuard monitoring started") def stop(self): """Stop memory monitoring""" self.running = False if self.monitor_thread: self.monitor_thread.join(timeout=2.0) logger.info("MemoryGuard monitoring stopped") def register_cleanup_callback(self, callback: Callable): """ Register a callback to be called when memory warning is triggered Args: callback: Function to call for cleanup (no arguments) """ self.cleanup_callbacks.append(callback) def get_memory_usage(self) -> dict: """Get current memory usage statistics""" mem_info = self.process.memory_info() rss_bytes = mem_info.rss # Resident Set Size (actual RAM used) return { 'rss_bytes': rss_bytes, 'rss_gb': rss_bytes / (1024**3), 'rss_mb': rss_bytes / (1024**2), 'max_gb': self.max_memory_bytes / (1024**3), 'warning_gb': self.warning_bytes / (1024**3), 'usage_percent': (rss_bytes / self.max_memory_bytes) * 100, 'at_warning': rss_bytes >= self.warning_bytes, 'at_limit': rss_bytes >= self.max_memory_bytes } def check_memory(self, raise_on_limit: bool = False) -> dict: """ Check current memory usage and take action if needed Args: raise_on_limit: If True, raise MemoryError when limit exceeded Returns: Memory usage dict Raises: MemoryError: If raise_on_limit=True and limit exceeded """ usage = self.get_memory_usage() if usage['at_limit']: self.limit_exceeded_count += 1 logger.error(f"MEMORY LIMIT EXCEEDED: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB") # Aggressive cleanup self._aggressive_cleanup() # Check again after cleanup usage_after = self.get_memory_usage() if raise_on_limit: raise MemoryError( f"Memory limit exceeded: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB. " f"After cleanup: {usage_after['rss_gb']:.2f}GB. " f"STOP TRAINING - Memory limit enforced!" 
                )

        elif usage['at_warning']:
            self.warning_count += 1
            logger.warning(
                f"Memory warning: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB "
                f"({usage['usage_percent']:.1f}%)"
            )

            if self.auto_cleanup:
                self._trigger_cleanup()

        return usage

    def _monitor_loop(self):
        """Background monitoring loop"""
        logger.info(f"MemoryGuard monitoring loop started (checking every {self.check_interval}s)")

        while self.running:
            try:
                self.check_memory(raise_on_limit=False)
                time.sleep(self.check_interval)
            except Exception as e:
                logger.error(f"Error in MemoryGuard monitor loop: {e}")
                time.sleep(self.check_interval * 2)

    def _trigger_cleanup(self):
        """Trigger cleanup callbacks and garbage collection"""
        logger.info("Triggering memory cleanup...")

        # Call registered callbacks
        for callback in self.cleanup_callbacks:
            try:
                callback()
            except Exception as e:
                logger.error(f"Error in cleanup callback: {e}")

        # Run garbage collection
        collected = gc.collect()
        logger.info(f"Garbage collection freed {collected} objects")

        # Clear CUDA cache if torch is installed and a GPU is available
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                logger.info("Cleared CUDA cache")
        except Exception:
            pass

    def _aggressive_cleanup(self):
        """Aggressive cleanup when the limit is exceeded"""
        logger.warning("Running AGGRESSIVE memory cleanup...")

        # Multiple GC passes
        for i in range(3):
            collected = gc.collect()
            logger.info(f"GC pass {i+1}: freed {collected} objects")

        # Clear CUDA cache if torch is installed and a GPU is available
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                torch.cuda.reset_peak_memory_stats()
                logger.info("Cleared CUDA cache and reset stats")
        except Exception:
            pass

        # Call cleanup callbacks
        for callback in self.cleanup_callbacks:
            try:
                callback()
            except Exception as e:
                logger.error(f"Error in cleanup callback: {e}")

        # Check if cleanup helped
        usage = self.get_memory_usage()
        logger.info(f"After cleanup: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")


# Global instance
_memory_guard: Optional[MemoryGuard] = None


def get_memory_guard(max_memory_gb: float = 50.0,
                     warning_threshold: float = 0.85,
                     auto_start: bool = True) -> MemoryGuard:
    """
    Get or create the global MemoryGuard instance

    Args:
        max_memory_gb: Maximum allowed memory in GB
        warning_threshold: Trigger cleanup at this fraction of max
        auto_start: Automatically start monitoring

    Returns:
        MemoryGuard instance
    """
    global _memory_guard

    if _memory_guard is None:
        _memory_guard = MemoryGuard(
            max_memory_gb=max_memory_gb,
            warning_threshold=warning_threshold
        )
        if auto_start:
            _memory_guard.start()

    return _memory_guard


def check_memory(raise_on_limit: bool = False) -> dict:
    """
    Quick check of current memory usage

    Args:
        raise_on_limit: If True, raise MemoryError when the limit is exceeded

    Returns:
        Memory usage dict
    """
    guard = get_memory_guard()
    return guard.check_memory(raise_on_limit=raise_on_limit)


def log_memory_usage(prefix: str = ""):
    """Log current memory usage"""
    usage = check_memory()
    logger.info(
        f"{prefix}Memory: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB "
        f"({usage['usage_percent']:.1f}%)"
    )
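
# --- Usage sketch (illustrative only, not part of the module API) ---
# A minimal example of how this guard might be wired up: create the global
# instance, register a cleanup callback, and log usage while the background
# monitor runs. The callback body and the 50GB budget are assumptions chosen
# for the demo, not values mandated by the module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    guard = get_memory_guard(max_memory_gb=50.0, warning_threshold=0.85)
    guard.register_cleanup_callback(lambda: logger.info("Custom cleanup callback invoked"))

    # Log usage a few times while the monitor thread checks in the background
    for _ in range(3):
        log_memory_usage(prefix="[demo] ")
        time.sleep(1.0)

    guard.stop()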