Files
gogo2/utils/memory_guard.py
Dobromir Popov b0b24f36b2 memory leak fixes
2025-11-13 16:05:15 +02:00

264 lines
8.5 KiB
Python

"""
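Memory Guard - Helps keep the Python process under a memory limit.
Monitors resident memory (RSS) in the background and triggers cleanup when usage approaches the limit, with a more aggressive cleanup once the limit is reached.
Can also enforce a hard limit by raising MemoryError from check_memory(raise_on_limit=True).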
"""
import psutil
import gc
import logging
import threading
import time
from typing import Optional, Callable
logger = logging.getLogger(__name__)
class MemoryGuard:
    """
    Memory usage monitor and enforcer.

    Features:
    - Monitors the process's resident set size (RSS) in a background daemon thread
    - Triggers cleanup (callbacks, gc, CUDA cache) at the warning threshold
    - Runs aggressive cleanup at the hard limit, and raises MemoryError there
      when check_memory(raise_on_limit=True) is used
    - Logs memory statistics
    """

    def __init__(self,
                 max_memory_gb: float = 50.0,
                 warning_threshold: float = 0.85,
                 check_interval: float = 5.0,
                 auto_cleanup: bool = True):
        """
        Initialize Memory Guard.

        Args:
            max_memory_gb: Maximum allowed memory in GB (default: 50GB). Must be positive.
            warning_threshold: Trigger cleanup at this fraction of max, in (0, 1]
                (default: 0.85 = 42.5GB).
            check_interval: Check memory every N seconds (default: 5.0). Must be positive.
            auto_cleanup: Automatically run cleanup at the warning threshold.

        Raises:
            ValueError: If max_memory_gb, warning_threshold or check_interval is out of range.
        """
        # Validate before creating the psutil handle so a bad config fails fast.
        # A zero limit would otherwise divide by zero in get_memory_usage(), and a
        # non-positive interval would make the monitor thread spin.
        max_memory_bytes = int(max_memory_gb * 1024 * 1024 * 1024)
        if max_memory_bytes <= 0:
            raise ValueError(f"max_memory_gb must be positive, got {max_memory_gb!r}")
        if not 0.0 < warning_threshold <= 1.0:
            raise ValueError(f"warning_threshold must be in (0, 1], got {warning_threshold!r}")
        if not check_interval > 0:
            raise ValueError(f"check_interval must be positive, got {check_interval!r}")

        self.max_memory_bytes = max_memory_bytes
        self.warning_bytes = int(self.max_memory_bytes * warning_threshold)
        self.check_interval = check_interval
        self.auto_cleanup = auto_cleanup
        self.process = psutil.Process()
        self.running = False
        self.monitor_thread: Optional[threading.Thread] = None
        # Stop signal for the current monitor thread. It is replaced on every start(),
        # so an older thread that is still winding down cannot be revived by a restart.
        self._stop_event = threading.Event()
        self.cleanup_callbacks: list = []
        self.warning_count = 0
        self.limit_exceeded_count = 0
        logger.info(f"MemoryGuard initialized: max={max_memory_gb:.1f}GB, warning={max_memory_gb * warning_threshold:.1f}GB")

    def start(self):
        """Start memory monitoring in a background daemon thread (no-op with a warning if already running)."""
        if self.running:
            logger.warning("MemoryGuard already running")
            return
        self.running = True
        # Each thread gets its own event. A previous thread whose join() timed out
        # keeps its (already set) event and exits instead of running alongside this one.
        self._stop_event = threading.Event()
        self.monitor_thread = threading.Thread(
            target=self._monitor_loop,
            args=(self._stop_event,),
            daemon=True
        )
        self.monitor_thread.start()
        logger.info("MemoryGuard monitoring started")

    def stop(self):
        """Stop memory monitoring and wait up to 2 seconds for the monitor thread to exit."""
        self.running = False
        # Wake the monitor thread now instead of letting it finish its sleep.
        self._stop_event.set()
        thread = self.monitor_thread
        # join() on the current thread raises RuntimeError. That happens if stop() is
        # called from a cleanup callback running on the monitor thread itself.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=2.0)
            if thread.is_alive():
                logger.warning("MemoryGuard monitor thread did not exit within 2.0s")
        logger.info("MemoryGuard monitoring stopped")

    def register_cleanup_callback(self, callback: Callable):
        """
        Register a callback to be called when memory cleanup is triggered.

        Args:
            callback: Function to call for cleanup (no arguments)

        Raises:
            TypeError: If callback is not callable.
        """
        # Reject bad callbacks here rather than logging an error on every cleanup.
        if not callable(callback):
            raise TypeError(f"cleanup callback must be callable, got {type(callback).__name__}")
        self.cleanup_callbacks.append(callback)

    def get_memory_usage(self) -> dict:
        """
        Get current memory usage statistics.

        Returns:
            Dict with the RSS in bytes/GB/MB, the configured max and warning levels
            in GB, usage as a percentage of max, and at_warning / at_limit flags
            (both inclusive comparisons).
        """
        mem_info = self.process.memory_info()
        rss_bytes = mem_info.rss  # Resident Set Size (actual RAM used)
        return {
            'rss_bytes': rss_bytes,
            'rss_gb': rss_bytes / (1024**3),
            'rss_mb': rss_bytes / (1024**2),
            'max_gb': self.max_memory_bytes / (1024**3),
            'warning_gb': self.warning_bytes / (1024**3),
            'usage_percent': (rss_bytes / self.max_memory_bytes) * 100,
            'at_warning': rss_bytes >= self.warning_bytes,
            'at_limit': rss_bytes >= self.max_memory_bytes
        }

    def check_memory(self, raise_on_limit: bool = False) -> dict:
        """
        Check current memory usage and take action if needed.

        At the limit, aggressive cleanup runs. At the warning threshold, regular
        cleanup runs if auto_cleanup is enabled.

        Args:
            raise_on_limit: If True, raise MemoryError when the limit is exceeded

        Returns:
            Memory usage dict measured *before* any cleanup

        Raises:
            MemoryError: If raise_on_limit=True and the limit is exceeded
        """
        usage = self.get_memory_usage()
        if usage['at_limit']:
            self.limit_exceeded_count += 1
            logger.error(f"MEMORY LIMIT EXCEEDED: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")
            # Aggressive cleanup also reports usage measured after the cleanup.
            usage_after = self._aggressive_cleanup()
            if raise_on_limit:
                raise MemoryError(
                    f"Memory limit exceeded: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB. "
                    f"After cleanup: {usage_after['rss_gb']:.2f}GB. "
                    f"STOP TRAINING - Memory limit enforced!"
                )
        elif usage['at_warning']:
            self.warning_count += 1
            logger.warning(f"Memory warning: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")
            if self.auto_cleanup:
                self._trigger_cleanup()
        return usage

    def _monitor_loop(self, stop_event: Optional[threading.Event] = None):
        """Background monitoring loop. Exits when stop_event is set or self.running is cleared."""
        if stop_event is None:
            stop_event = self._stop_event
        logger.info(f"MemoryGuard monitoring loop started (checking every {self.check_interval}s)")
        while self.running and not stop_event.is_set():
            try:
                self.check_memory(raise_on_limit=False)
                delay = self.check_interval
            except Exception as e:
                logger.exception(f"Error in MemoryGuard monitor loop: {e}")
                # Back off after a failure so a persistent error doesn't flood the log.
                delay = self.check_interval * 2
            # Event.wait is an interruptible sleep, so stop() takes effect immediately.
            stop_event.wait(delay)

    def _run_cleanup_callbacks(self):
        """Invoke every registered cleanup callback, logging (not raising) failures."""
        # Iterate a snapshot so a concurrent register_cleanup_callback() cannot
        # change the list mid-loop.
        for callback in list(self.cleanup_callbacks):
            try:
                callback()
            except Exception as e:
                logger.exception(f"Error in cleanup callback: {e}")

    def _clear_cuda_cache(self, reset_peak_stats: bool = False):
        """Release cached CUDA memory when torch with CUDA is available (best effort)."""
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                if reset_peak_stats:
                    torch.cuda.reset_peak_memory_stats()
                    logger.info("Cleared CUDA cache and reset stats")
                else:
                    logger.info("Cleared CUDA cache")
        except ImportError:
            pass  # torch not installed: nothing to clear
        except Exception as e:
            # Cleanup stays best effort, but failures are reported. Catching Exception
            # rather than everything lets KeyboardInterrupt/SystemExit propagate.
            logger.warning(f"Failed to clear CUDA cache: {e}")

    def _trigger_cleanup(self):
        """Run cleanup callbacks, then garbage collection, then clear the CUDA cache."""
        logger.info("Triggering memory cleanup...")
        self._run_cleanup_callbacks()
        collected = gc.collect()
        logger.info(f"Garbage collection freed {collected} objects")
        self._clear_cuda_cache()

    def _aggressive_cleanup(self) -> dict:
        """
        Aggressive cleanup when the limit is exceeded.

        Returns:
            Memory usage dict measured after the cleanup.
        """
        logger.warning("Running AGGRESSIVE memory cleanup...")
        # Run callbacks first so the references they drop can be collected by the
        # GC passes and released from the CUDA cache below.
        self._run_cleanup_callbacks()
        # Multiple GC passes
        for i in range(3):
            collected = gc.collect()
            logger.info(f"GC pass {i+1}: freed {collected} objects")
        self._clear_cuda_cache(reset_peak_stats=True)
        # Check if cleanup helped
        usage = self.get_memory_usage()
        logger.info(f"After cleanup: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")
        return usage
# Global instance, created lazily by get_memory_guard()
_memory_guard = None
# Serializes creation and start of the shared instance so concurrent first
# calls cannot create (and start monitoring on) two guards.
_memory_guard_lock = threading.Lock()


def get_memory_guard(max_memory_gb: float = 50.0,
                     warning_threshold: float = 0.85,
                     auto_start: bool = True) -> MemoryGuard:
    """
    Get or create the global MemoryGuard instance.

    max_memory_gb and warning_threshold only take effect when the instance is
    first created; later calls return the existing instance unchanged.

    Args:
        max_memory_gb: Maximum allowed memory in GB
        warning_threshold: Trigger cleanup at this fraction of max
        auto_start: Start monitoring if the instance is not already running
    Returns:
        MemoryGuard instance
    """
    global _memory_guard
    with _memory_guard_lock:
        if _memory_guard is None:
            _memory_guard = MemoryGuard(
                max_memory_gb=max_memory_gb,
                warning_threshold=warning_threshold
            )
        # Only start an idle guard. Calling start() on a running guard logs
        # "already running", and this function runs on every module-level
        # check_memory() / log_memory_usage() call.
        if auto_start and not _memory_guard.running:
            _memory_guard.start()
        return _memory_guard
def check_memory(raise_on_limit: bool = False) -> dict:
    """
    Check memory usage through the shared global MemoryGuard.

    The global guard is obtained via get_memory_guard() with its default
    arguments, so it is created (and started) on first use.

    Args:
        raise_on_limit: When True, MemoryError is raised if usage is at or
            above the guard's hard limit
    Returns:
        Memory usage dict as produced by MemoryGuard.check_memory
    """
    shared_guard = get_memory_guard()
    usage_stats = shared_guard.check_memory(raise_on_limit=raise_on_limit)
    return usage_stats
def log_memory_usage(prefix: str = ""):
    """
    Log the current memory usage at INFO level.

    Goes through the module-level check_memory(), so it shares its side
    effects: the global guard may be created/started, and cleanup may run
    if usage is at the warning threshold or the limit.

    Args:
        prefix: Text prepended verbatim to the log message
    """
    usage = check_memory()
    message = f"{prefix}Memory: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)"
    logger.info(message)