fix training - loss calculation;
added memory guard
This commit is contained in:
259
utils/memory_guard.py
Normal file
259
utils/memory_guard.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Memory Guard - Prevents Python process from exceeding memory limit
|
||||
|
||||
Monitors memory usage and triggers aggressive cleanup when approaching limit.
|
||||
Can also enforce hard limit by raising exception.
|
||||
"""
|
||||
|
||||
import psutil
|
||||
import gc
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryGuard:
    """
    Memory usage monitor and enforcer.

    Features:
    - Monitors RAM usage in a background daemon thread
    - Triggers cleanup (callbacks + gc + CUDA cache) at warning threshold
    - Optionally raises MemoryError at the hard limit
    - Logs memory statistics
    """

    def __init__(self,
                 max_memory_gb: float = 50.0,
                 warning_threshold: float = 0.85,
                 check_interval: float = 5.0,
                 auto_cleanup: bool = True):
        """
        Initialize Memory Guard.

        Args:
            max_memory_gb: Maximum allowed memory in GB (default: 50GB)
            warning_threshold: Trigger cleanup at this fraction of max (default: 0.85 = 42.5GB)
            check_interval: Check memory every N seconds (default: 5.0)
            auto_cleanup: Automatically run cleanup at warning threshold
        """
        self.max_memory_bytes = int(max_memory_gb * 1024 * 1024 * 1024)
        self.warning_bytes = int(self.max_memory_bytes * warning_threshold)
        self.check_interval = check_interval
        self.auto_cleanup = auto_cleanup

        # Handle to the current process; RSS is read from here on every check.
        self.process = psutil.Process()
        self.running = False
        # Background daemon thread created by start(); None while stopped.
        self.monitor_thread: Optional[threading.Thread] = None

        # User callbacks invoked (in registration order) whenever cleanup runs.
        self.cleanup_callbacks: list = []
        self.warning_count = 0
        self.limit_exceeded_count = 0

        logger.info(f"MemoryGuard initialized: max={max_memory_gb:.1f}GB, warning={max_memory_gb * warning_threshold:.1f}GB")

    def start(self):
        """Start memory monitoring in background thread"""
        if self.running:
            logger.warning("MemoryGuard already running")
            return

        self.running = True
        self.monitor_thread = threading.Thread(
            target=self._monitor_loop,
            daemon=True  # daemon so the thread never blocks interpreter shutdown
        )
        self.monitor_thread.start()
        logger.info("MemoryGuard monitoring started")

    def stop(self):
        """Stop memory monitoring"""
        self.running = False
        if self.monitor_thread:
            # Bounded join: worst case the loop is mid-sleep for check_interval.
            self.monitor_thread.join(timeout=2.0)
            # FIX: drop the dead-thread reference so start() can cleanly restart.
            self.monitor_thread = None
        logger.info("MemoryGuard monitoring stopped")

    def register_cleanup_callback(self, callback: Callable):
        """
        Register a callback to be called when memory warning is triggered.

        Args:
            callback: Function to call for cleanup (no arguments)
        """
        self.cleanup_callbacks.append(callback)

    def get_memory_usage(self) -> dict:
        """Get current memory usage statistics as a dict (bytes/MB/GB, thresholds, flags)."""
        mem_info = self.process.memory_info()
        rss_bytes = mem_info.rss  # Resident Set Size (actual RAM used)

        return {
            'rss_bytes': rss_bytes,
            'rss_gb': rss_bytes / (1024**3),
            'rss_mb': rss_bytes / (1024**2),
            'max_gb': self.max_memory_bytes / (1024**3),
            'warning_gb': self.warning_bytes / (1024**3),
            'usage_percent': (rss_bytes / self.max_memory_bytes) * 100,
            'at_warning': rss_bytes >= self.warning_bytes,
            'at_limit': rss_bytes >= self.max_memory_bytes
        }

    def check_memory(self, raise_on_limit: bool = False) -> dict:
        """
        Check current memory usage and take action if needed.

        Args:
            raise_on_limit: If True, raise MemoryError when limit exceeded

        Returns:
            Memory usage dict (see get_memory_usage)

        Raises:
            MemoryError: If raise_on_limit=True and limit exceeded
        """
        usage = self.get_memory_usage()

        if usage['at_limit']:
            self.limit_exceeded_count += 1
            logger.error(f"🔴 MEMORY LIMIT EXCEEDED: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")

            # Try to reclaim memory before (optionally) failing hard.
            self._aggressive_cleanup()

            if raise_on_limit:
                raise MemoryError(
                    f"Memory limit exceeded: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB. "
                    f"Increase max_memory_gb or reduce batch size."
                )

        elif usage['at_warning']:
            self.warning_count += 1
            logger.warning(f"⚠️ Memory warning: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")

            if self.auto_cleanup:
                self._trigger_cleanup()

        return usage

    def _monitor_loop(self):
        """Background monitoring loop; runs until stop() flips self.running."""
        logger.info(f"MemoryGuard monitoring loop started (checking every {self.check_interval}s)")

        while self.running:
            try:
                self.check_memory(raise_on_limit=False)
                time.sleep(self.check_interval)
            except Exception as e:
                logger.error(f"Error in MemoryGuard monitor loop: {e}")
                # Back off so a persistent failure doesn't spam the log.
                time.sleep(self.check_interval * 2)

    def _run_callbacks(self):
        """Invoke registered cleanup callbacks, isolating each one's failures."""
        for callback in self.cleanup_callbacks:
            try:
                callback()
            except Exception as e:
                logger.error(f"Error in cleanup callback: {e}")

    def _clear_cuda_cache(self, reset_stats: bool = False):
        """Best-effort release of cached CUDA memory; no-op without torch/CUDA.

        Args:
            reset_stats: Also reset peak-memory statistics (used by aggressive cleanup).
        """
        # FIX: was a bare `except:` — narrowed to Exception so SystemExit /
        # KeyboardInterrupt are no longer swallowed. Still best-effort.
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                if reset_stats:
                    torch.cuda.reset_peak_memory_stats()
                    logger.info("Cleared CUDA cache and reset stats")
                else:
                    logger.info("Cleared CUDA cache")
        except Exception:
            pass

    def _trigger_cleanup(self):
        """Trigger cleanup callbacks and garbage collection (warning-level response)."""
        logger.info("Triggering memory cleanup...")

        # Call registered callbacks first so they can drop references
        # before the GC pass below.
        self._run_callbacks()

        # Run garbage collection
        collected = gc.collect()
        logger.info(f"Garbage collection freed {collected} objects")

        # Clear CUDA cache if available
        self._clear_cuda_cache()

    def _aggressive_cleanup(self):
        """Aggressive cleanup when the hard limit is exceeded."""
        logger.warning("Running AGGRESSIVE memory cleanup...")

        # Multiple GC passes: later passes can collect cycles whose members
        # were only freed by the earlier passes.
        for i in range(3):
            collected = gc.collect()
            logger.info(f"GC pass {i+1}: freed {collected} objects")

        # Clear CUDA cache and reset its peak-memory statistics
        self._clear_cuda_cache(reset_stats=True)

        # Call cleanup callbacks
        self._run_callbacks()

        # Check if cleanup helped
        usage = self.get_memory_usage()
        logger.info(f"After cleanup: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")
|
||||
|
||||
|
||||
# Global instance — lazily created by get_memory_guard(); not for direct use.
_memory_guard: Optional[MemoryGuard] = None
|
||||
|
||||
|
||||
def get_memory_guard(max_memory_gb: float = 50.0,
                     warning_threshold: float = 0.85,
                     auto_start: bool = True) -> MemoryGuard:
    """
    Get or create the process-wide MemoryGuard singleton.

    Note: the size arguments only take effect on the first call, when the
    singleton is actually constructed.

    Args:
        max_memory_gb: Maximum allowed memory in GB
        warning_threshold: Trigger cleanup at this fraction of max
        auto_start: Automatically start monitoring

    Returns:
        MemoryGuard instance
    """
    global _memory_guard

    guard = _memory_guard
    if guard is None:
        guard = MemoryGuard(max_memory_gb=max_memory_gb,
                            warning_threshold=warning_threshold)
        _memory_guard = guard

    if auto_start:
        guard.start()  # no-op (with a warning) if already running

    return guard
|
||||
|
||||
|
||||
def check_memory(raise_on_limit: bool = False) -> dict:
    """
    Quick check of current memory usage via the global MemoryGuard.

    Args:
        raise_on_limit: If True, raise MemoryError when limit exceeded

    Returns:
        Memory usage dict
    """
    return get_memory_guard().check_memory(raise_on_limit=raise_on_limit)
|
||||
|
||||
|
||||
def log_memory_usage(prefix: str = ""):
    """Log the process's current memory usage at INFO level, with an optional prefix."""
    stats = check_memory()
    logger.info(
        f"{prefix}Memory: {stats['rss_gb']:.2f}GB / {stats['max_gb']:.1f}GB "
        f"({stats['usage_percent']:.1f}%)"
    )
|
||||
Reference in New Issue
Block a user