"""
Memory Guard - Prevents the Python process from exceeding a memory limit.

Monitors memory usage and triggers aggressive cleanup when approaching the limit.
Can also enforce a hard limit by raising an exception.
"""
|
|
|
|
import psutil
|
|
import gc
|
|
import logging
|
|
import threading
|
|
import time
|
|
from typing import Optional, Callable
|
|
|
|
# Module-level logger; every message in this module is emitted through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MemoryGuard:
    """
    Memory usage monitor and enforcer.

    Features:
    - Monitors the process RSS in a background daemon thread
    - Triggers cleanup (callbacks, gc, CUDA cache) at the warning threshold
    - Runs aggressive cleanup at the hard limit and can raise MemoryError
    - Logs memory statistics

    Public attributes:
        max_memory_bytes / warning_bytes: thresholds in bytes
        check_interval: seconds between background checks
        auto_cleanup: whether the warning threshold triggers cleanup
        running: True while the background monitor is supposed to run
        monitor_thread: the most recently started monitor thread (or None)
        cleanup_callbacks: registered zero-argument cleanup callables
        warning_count / limit_exceeded_count: how often each threshold was hit
    """

    def __init__(self,
                 max_memory_gb: float = 50.0,
                 warning_threshold: float = 0.85,
                 check_interval: float = 2.0,
                 auto_cleanup: bool = True):
        """
        Initialize Memory Guard

        Args:
            max_memory_gb: Maximum allowed memory in GB (default: 50GB)
            warning_threshold: Trigger cleanup at this fraction of max (default: 0.85 = 42.5GB)
            check_interval: Check memory every N seconds (default: 2.0)
            auto_cleanup: Automatically run cleanup at warning threshold

        Raises:
            ValueError: If max_memory_gb does not yield a positive byte limit
                or check_interval is negative
        """
        self.max_memory_bytes = int(max_memory_gb * 1024 * 1024 * 1024)
        # A zero limit would make usage_percent divide by zero on every check.
        if self.max_memory_bytes <= 0:
            raise ValueError(f"max_memory_gb must be positive, got {max_memory_gb!r}")
        if check_interval < 0:
            raise ValueError(f"check_interval must be non-negative, got {check_interval!r}")

        self.warning_bytes = int(self.max_memory_bytes * warning_threshold)
        self.check_interval = check_interval
        self.auto_cleanup = auto_cleanup

        self.process = psutil.Process()
        self.running = False
        self.monitor_thread: Optional[threading.Thread] = None
        # Wakes the monitor thread immediately on stop(); replaced on every start().
        self._stop_event = threading.Event()

        self.cleanup_callbacks = []
        self.warning_count = 0
        self.limit_exceeded_count = 0

        logger.info(f"MemoryGuard initialized: max={max_memory_gb:.1f}GB, warning={max_memory_gb * warning_threshold:.1f}GB")

    def start(self):
        """Start memory monitoring in background thread (no-op if already running)"""
        if self.running:
            logger.warning("MemoryGuard already running")
            return

        # Each monitor thread gets its own event, so a thread that outlived a
        # previous stop() still sees *its* event set and exits instead of
        # running alongside the new one.
        stop_event = threading.Event()
        self._stop_event = stop_event
        self.running = True
        self.monitor_thread = threading.Thread(
            target=self._monitor_loop,
            args=(stop_event,),
            daemon=True
        )
        self.monitor_thread.start()
        logger.info("MemoryGuard monitoring started")

    def stop(self):
        """Stop memory monitoring and wait (up to 2 seconds) for the thread to exit"""
        self.running = False
        self._stop_event.set()  # interrupt the inter-check wait right away
        thread = self.monitor_thread
        # Joining the current thread raises RuntimeError (e.g. stop() called
        # from a cleanup callback running on the monitor thread).
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=2.0)
        logger.info("MemoryGuard monitoring stopped")

    def register_cleanup_callback(self, callback: Callable):
        """
        Register a callback to be called when memory cleanup is triggered

        Args:
            callback: Function to call for cleanup (no arguments)

        Raises:
            TypeError: If callback is not callable
        """
        if not callable(callback):
            raise TypeError(f"cleanup callback must be callable, got {type(callback).__name__}")
        self.cleanup_callbacks.append(callback)

    def get_memory_usage(self) -> dict:
        """
        Get current memory usage statistics

        Returns:
            Dict with keys rss_bytes, rss_gb, rss_mb, max_gb, warning_gb,
            usage_percent (RSS as a percentage of the limit), at_warning and
            at_limit (booleans comparing RSS against the two thresholds)
        """
        mem_info = self.process.memory_info()
        rss_bytes = mem_info.rss  # Resident Set Size (actual RAM used)

        return {
            'rss_bytes': rss_bytes,
            'rss_gb': rss_bytes / (1024**3),
            'rss_mb': rss_bytes / (1024**2),
            'max_gb': self.max_memory_bytes / (1024**3),
            'warning_gb': self.warning_bytes / (1024**3),
            'usage_percent': (rss_bytes / self.max_memory_bytes) * 100,
            'at_warning': rss_bytes >= self.warning_bytes,
            'at_limit': rss_bytes >= self.max_memory_bytes
        }

    def check_memory(self, raise_on_limit: bool = False) -> dict:
        """
        Check current memory usage and take action if needed

        At or above the hard limit, aggressive cleanup always runs. At or
        above the warning threshold, cleanup runs only when auto_cleanup is set.

        Args:
            raise_on_limit: If True, raise MemoryError when limit exceeded

        Returns:
            Memory usage dict measured *before* any cleanup ran

        Raises:
            MemoryError: If raise_on_limit=True and the limit was exceeded at
                check time (raised even if cleanup brought usage back down;
                the message reports both readings)
        """
        usage = self.get_memory_usage()

        if usage['at_limit']:
            self.limit_exceeded_count += 1
            logger.error(f"MEMORY LIMIT EXCEEDED: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")

            # Aggressive cleanup
            self._aggressive_cleanup()

            # Check again after cleanup
            usage_after = self.get_memory_usage()

            if raise_on_limit:
                raise MemoryError(
                    f"Memory limit exceeded: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB. "
                    f"After cleanup: {usage_after['rss_gb']:.2f}GB. "
                    "STOP TRAINING - Memory limit enforced!"
                )

        elif usage['at_warning']:
            self.warning_count += 1
            logger.warning(f"Memory warning: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")

            if self.auto_cleanup:
                self._trigger_cleanup()

        return usage

    def _monitor_loop(self, stop_event: Optional[threading.Event] = None):
        """
        Background monitoring loop

        Args:
            stop_event: Event that ends this loop when set; defaults to the
                guard's current stop event
        """
        if stop_event is None:
            stop_event = self._stop_event
        logger.info(f"MemoryGuard monitoring loop started (checking every {self.check_interval}s)")

        while self.running and not stop_event.is_set():
            try:
                self.check_memory(raise_on_limit=False)
                delay = self.check_interval
            except Exception as e:
                logger.error(f"Error in MemoryGuard monitor loop: {e}")
                delay = self.check_interval * 2  # back off after a failed check

            # Unlike time.sleep(), this returns as soon as stop() is called.
            if stop_event.wait(delay):
                break

    def _run_cleanup_callbacks(self):
        """Call every registered cleanup callback, logging (not raising) failures"""
        # Iterate over a snapshot so callbacks registered meanwhile run next time.
        for callback in list(self.cleanup_callbacks):
            try:
                callback()
            except Exception as e:
                logger.error(f"Error in cleanup callback: {e}")

    def _clear_cuda_cache(self, reset_peak_stats: bool = False) -> bool:
        """
        Best-effort release of PyTorch's cached CUDA memory

        Args:
            reset_peak_stats: Also reset CUDA peak memory statistics

        Returns:
            True if the CUDA cache was cleared, False otherwise
        """
        import sys

        # Only touch torch if the process already loaded it: importing it here
        # would allocate memory exactly when we are trying to free some, and
        # without torch loaded there is no CUDA cache to clear.
        torch = sys.modules.get("torch")
        if torch is None:
            return False

        try:
            if not torch.cuda.is_available():
                return False
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            if reset_peak_stats:
                torch.cuda.reset_peak_memory_stats()
        except Exception as e:
            # Cleanup is best-effort; never let a CUDA error kill the monitor.
            logger.debug(f"CUDA cache cleanup skipped: {e}")
            return False
        return True

    def _trigger_cleanup(self):
        """Trigger cleanup callbacks and garbage collection"""
        logger.info("Triggering memory cleanup...")

        # Call registered callbacks
        self._run_cleanup_callbacks()

        # Run garbage collection
        collected = gc.collect()
        logger.info(f"Garbage collection freed {collected} objects")

        # Clear CUDA cache if available
        if self._clear_cuda_cache():
            logger.info("Cleared CUDA cache")

    def _aggressive_cleanup(self):
        """Aggressive cleanup when limit exceeded"""
        logger.warning("Running AGGRESSIVE memory cleanup...")

        # Callbacks first, so whatever they release is reclaimed by the GC
        # passes and the CUDA cache flush below.
        self._run_cleanup_callbacks()

        # Multiple GC passes
        for i in range(3):
            collected = gc.collect()
            logger.info(f"GC pass {i+1}: freed {collected} objects")

        # Clear CUDA cache
        if self._clear_cuda_cache(reset_peak_stats=True):
            logger.info("Cleared CUDA cache and reset stats")

        # Check if cleanup helped
        usage = self.get_memory_usage()
        logger.info(f"After cleanup: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")
|
|
|
|
|
|
# Global (process-wide) MemoryGuard instance, created lazily by get_memory_guard().
_memory_guard = None


def get_memory_guard(max_memory_gb: float = 50.0,
                     warning_threshold: float = 0.85,
                     auto_start: bool = True) -> MemoryGuard:
    """
    Get or create global MemoryGuard instance

    The size arguments are only used when the instance is first created;
    on later calls they are ignored and the existing instance is returned.

    Args:
        max_memory_gb: Maximum allowed memory in GB
        warning_threshold: Trigger cleanup at this fraction of max
        auto_start: Start monitoring if it is not already running

    Returns:
        MemoryGuard instance
    """
    global _memory_guard

    if _memory_guard is None:
        _memory_guard = MemoryGuard(
            max_memory_gb=max_memory_gb,
            warning_threshold=warning_threshold
        )

    # Skip start() when the monitor is already running: calling it anyway makes
    # every lookup log a spurious "MemoryGuard already running" warning.
    if auto_start and not _memory_guard.running:
        _memory_guard.start()

    return _memory_guard
|
|
|
|
|
|
def check_memory(raise_on_limit: bool = False) -> dict:
    """
    Run a one-off memory check on the global MemoryGuard.

    The guard is obtained through get_memory_guard() with its default
    arguments, so it is created on first use.

    Args:
        raise_on_limit: Forwarded to MemoryGuard.check_memory; when True a
            MemoryError is raised if the limit is exceeded

    Returns:
        The memory usage dict produced by MemoryGuard.check_memory
    """
    return get_memory_guard().check_memory(raise_on_limit=raise_on_limit)
|
|
|
|
|
|
def log_memory_usage(prefix: str = ""):
    """
    Log current memory usage at INFO level.

    The figures come from the module-level check_memory(), so this also
    performs whatever cleanup that check triggers.

    Args:
        prefix: Text placed directly in front of the message (no separator
            is added)
    """
    stats = check_memory()
    message = (
        f"{prefix}Memory: {stats['rss_gb']:.2f}GB / {stats['max_gb']:.1f}GB "
        f"({stats['usage_percent']:.1f}%)"
    )
    logger.info(message)
|