fix training - loss calculation;
added memory guard
This commit is contained in:
@@ -1744,6 +1744,21 @@ class RealTrainingAdapter:
|
|||||||
import torch
|
import torch
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
|
# Initialize memory guard (50GB limit)
|
||||||
|
from utils.memory_guard import get_memory_guard, log_memory_usage
|
||||||
|
memory_guard = get_memory_guard(max_memory_gb=50.0, warning_threshold=0.85, auto_start=True)
|
||||||
|
|
||||||
|
# Register cleanup callback
|
||||||
|
def training_cleanup():
|
||||||
|
"""Cleanup callback for memory guard"""
|
||||||
|
gc.collect()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
memory_guard.register_cleanup_callback(training_cleanup)
|
||||||
|
log_memory_usage("Training start - ")
|
||||||
|
|
||||||
# Load best checkpoint if available to continue training
|
# Load best checkpoint if available to continue training
|
||||||
try:
|
try:
|
||||||
checkpoint_dir = "models/checkpoints/transformer"
|
checkpoint_dir = "models/checkpoints/transformer"
|
||||||
@@ -1828,9 +1843,12 @@ class RealTrainingAdapter:
|
|||||||
batch_loss = float(result.get('total_loss', 0.0))
|
batch_loss = float(result.get('total_loss', 0.0))
|
||||||
batch_accuracy = float(result.get('accuracy', 0.0))
|
batch_accuracy = float(result.get('accuracy', 0.0))
|
||||||
batch_candle_accuracy = float(result.get('candle_accuracy', 0.0))
|
batch_candle_accuracy = float(result.get('candle_accuracy', 0.0))
|
||||||
|
batch_trend_accuracy = float(result.get('trend_accuracy', 0.0))
|
||||||
|
batch_action_accuracy = float(result.get('action_accuracy', 0.0))
|
||||||
batch_trend_loss = float(result.get('trend_loss', 0.0))
|
batch_trend_loss = float(result.get('trend_loss', 0.0))
|
||||||
batch_candle_loss = float(result.get('candle_loss', 0.0))
|
batch_candle_loss = float(result.get('candle_loss', 0.0))
|
||||||
batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
|
batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
|
||||||
|
batch_candle_rmse = result.get('candle_rmse', {})
|
||||||
|
|
||||||
epoch_loss += batch_loss
|
epoch_loss += batch_loss
|
||||||
epoch_accuracy += batch_accuracy
|
epoch_accuracy += batch_accuracy
|
||||||
@@ -1838,13 +1856,20 @@ class RealTrainingAdapter:
|
|||||||
|
|
||||||
# Log first batch and every 5th batch
|
# Log first batch and every 5th batch
|
||||||
if (i + 1) == 1 or (i + 1) % 5 == 0:
|
if (i + 1) == 1 or (i + 1) % 5 == 0:
|
||||||
|
# Format RMSE values (normalized space)
|
||||||
|
rmse_str = ""
|
||||||
|
if batch_candle_rmse:
|
||||||
|
rmse_str = f", RMSE: O={batch_candle_rmse.get('open', 0):.4f} H={batch_candle_rmse.get('high', 0):.4f} L={batch_candle_rmse.get('low', 0):.4f} C={batch_candle_rmse.get('close', 0):.4f}"
|
||||||
|
|
||||||
|
# Format denormalized RMSE (real prices)
|
||||||
denorm_str = ""
|
denorm_str = ""
|
||||||
if batch_candle_loss_denorm:
|
if batch_candle_loss_denorm:
|
||||||
# RMSE values now, much more reasonable
|
|
||||||
denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
|
denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
|
||||||
denorm_str = f", Real Price RMSE: {', '.join(denorm_values)}"
|
denorm_str = f", Real RMSE: {', '.join(denorm_values)}"
|
||||||
|
|
||||||
logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
|
logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, "
|
||||||
|
f"Candle Acc: {batch_accuracy:.1%}, Trend Acc: {batch_trend_accuracy:.1%}, "
|
||||||
|
f"Action Acc: {batch_action_accuracy:.1%}{rmse_str}{denorm_str}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f" Batch {i + 1} returned None result - skipping")
|
logger.warning(f" Batch {i + 1} returned None result - skipping")
|
||||||
|
|
||||||
@@ -1966,6 +1991,9 @@ class RealTrainingAdapter:
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Check memory usage
|
||||||
|
log_memory_usage(f" Epoch {epoch + 1} end - ")
|
||||||
|
|
||||||
logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
|
logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
|
||||||
|
|
||||||
session.final_loss = session.current_loss
|
session.final_loss = session.current_loss
|
||||||
@@ -1978,6 +2006,10 @@ class RealTrainingAdapter:
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Final memory check
|
||||||
|
log_memory_usage("Training complete - ")
|
||||||
|
memory_guard.stop()
|
||||||
|
|
||||||
# Log best checkpoint info
|
# Log best checkpoint info
|
||||||
try:
|
try:
|
||||||
checkpoint_dir = "models/checkpoints/transformer"
|
checkpoint_dir = "models/checkpoints/transformer"
|
||||||
|
|||||||
@@ -1407,23 +1407,63 @@ class TradingTransformerTrainer:
|
|||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
|
|
||||||
# Calculate accuracy without gradients
|
# Calculate accuracy without gradients
|
||||||
|
# PRIMARY: Next candle OHLCV prediction accuracy (realistic values)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
predictions = torch.argmax(outputs['action_logits'], dim=-1)
|
|
||||||
accuracy = (predictions == batch['actions']).float().mean()
|
|
||||||
|
|
||||||
# Calculate candle prediction accuracy (price direction)
|
|
||||||
candle_accuracy = 0.0
|
candle_accuracy = 0.0
|
||||||
if 'next_candles' in outputs and 'future_prices' in batch:
|
candle_rmse = {}
|
||||||
# Use 1m timeframe prediction as primary
|
|
||||||
if '1m' in outputs['next_candles']:
|
if 'next_candles' in outputs:
|
||||||
predicted_candle = outputs['next_candles']['1m'] # [batch, 5]
|
# Use 1m timeframe as primary metric
|
||||||
# Predicted close is the 4th value (index 3)
|
if '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
|
||||||
predicted_close_change = predicted_candle[:, 3] # Predicted close price change
|
pred_candle = outputs['next_candles']['1m'] # [batch, 5]
|
||||||
actual_close_change = batch['future_prices'] # Actual price change ratio
|
actual_candle = batch['future_candle_1m'] # [batch, 5]
|
||||||
|
|
||||||
# Check if direction matches (both positive or both negative)
|
if actual_candle is not None and pred_candle.shape == actual_candle.shape:
|
||||||
direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
|
# Calculate RMSE for each OHLCV component
|
||||||
candle_accuracy = direction_match.mean().item()
|
rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
|
||||||
|
rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
|
||||||
|
rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
|
||||||
|
rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
|
||||||
|
|
||||||
|
# Average RMSE for OHLC (exclude volume)
|
||||||
|
avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
|
||||||
|
|
||||||
|
# Convert to accuracy: lower RMSE = higher accuracy
|
||||||
|
# Normalize by price range
|
||||||
|
price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
|
||||||
|
candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
|
||||||
|
|
||||||
|
candle_rmse = {
|
||||||
|
'open': rmse_open.item(),
|
||||||
|
'high': rmse_high.item(),
|
||||||
|
'low': rmse_low.item(),
|
||||||
|
'close': rmse_close.item(),
|
||||||
|
'avg': avg_rmse.item()
|
||||||
|
}
|
||||||
|
|
||||||
|
# SECONDARY: Trend vector prediction accuracy
|
||||||
|
trend_accuracy = 0.0
|
||||||
|
if 'trend_analysis' in outputs and 'trend_target' in batch:
|
||||||
|
pred_angle = outputs['trend_analysis']['angle_radians']
|
||||||
|
pred_steepness = outputs['trend_analysis']['steepness']
|
||||||
|
|
||||||
|
actual_angle = batch['trend_target'][:, 0:1]
|
||||||
|
actual_steepness = batch['trend_target'][:, 1:2]
|
||||||
|
|
||||||
|
# Angle error (degrees)
|
||||||
|
angle_error_rad = torch.abs(pred_angle - actual_angle)
|
||||||
|
angle_error_deg = angle_error_rad * 180.0 / 3.14159
|
||||||
|
angle_accuracy = (1.0 - torch.clamp(angle_error_deg / 180.0, 0, 1)).mean()
|
||||||
|
|
||||||
|
# Steepness error (percentage)
|
||||||
|
steepness_error = torch.abs(pred_steepness - actual_steepness) / (actual_steepness + 1e-8)
|
||||||
|
steepness_accuracy = (1.0 - torch.clamp(steepness_error, 0, 1)).mean()
|
||||||
|
|
||||||
|
trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()
|
||||||
|
|
||||||
|
# LEGACY: Action accuracy (for comparison)
|
||||||
|
action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
|
||||||
|
action_accuracy = (action_predictions == batch['actions']).float().mean().item()
|
||||||
|
|
||||||
# Extract values and delete tensors to free memory
|
# Extract values and delete tensors to free memory
|
||||||
result = {
|
result = {
|
||||||
@@ -1433,14 +1473,20 @@ class TradingTransformerTrainer:
|
|||||||
'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,
|
'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,
|
||||||
'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0,
|
'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0,
|
||||||
'candle_loss_denorm': candle_losses_denorm, # Dict of denormalized losses per timeframe
|
'candle_loss_denorm': candle_losses_denorm, # Dict of denormalized losses per timeframe
|
||||||
'accuracy': accuracy.item(),
|
|
||||||
'candle_accuracy': candle_accuracy,
|
# NEW: Realistic accuracy metrics based on next candle prediction
|
||||||
|
'accuracy': candle_accuracy, # PRIMARY: Next candle prediction accuracy
|
||||||
|
'candle_accuracy': candle_accuracy, # Same as accuracy
|
||||||
|
'candle_rmse': candle_rmse, # Detailed RMSE per OHLC component
|
||||||
|
'trend_accuracy': trend_accuracy, # Trend vector accuracy (angle + steepness)
|
||||||
|
'action_accuracy': action_accuracy, # Legacy action accuracy
|
||||||
|
|
||||||
'learning_rate': self.scheduler.get_last_lr()[0]
|
'learning_rate': self.scheduler.get_last_lr()[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
# CRITICAL: Delete large tensors to free memory immediately
|
# CRITICAL: Delete large tensors to free memory immediately
|
||||||
# This prevents memory accumulation across batches
|
# This prevents memory accumulation across batches
|
||||||
del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, predictions, accuracy
|
del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, action_predictions
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|||||||
259
utils/memory_guard.py
Normal file
259
utils/memory_guard.py
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
"""
|
||||||
|
Memory Guard - Prevents Python process from exceeding memory limit
|
||||||
|
|
||||||
|
Monitors memory usage and triggers aggressive cleanup when approaching limit.
|
||||||
|
Can also enforce hard limit by raising exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Optional, Callable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryGuard:
|
||||||
|
"""
|
||||||
|
Memory usage monitor and enforcer
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Monitors RAM usage in background
|
||||||
|
- Triggers cleanup at warning threshold
|
||||||
|
- Raises exception at hard limit
|
||||||
|
- Logs memory statistics
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
max_memory_gb: float = 50.0,
|
||||||
|
warning_threshold: float = 0.85,
|
||||||
|
check_interval: float = 5.0,
|
||||||
|
auto_cleanup: bool = True):
|
||||||
|
"""
|
||||||
|
Initialize Memory Guard
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_memory_gb: Maximum allowed memory in GB (default: 50GB)
|
||||||
|
warning_threshold: Trigger cleanup at this fraction of max (default: 0.85 = 42.5GB)
|
||||||
|
check_interval: Check memory every N seconds (default: 5.0)
|
||||||
|
auto_cleanup: Automatically run gc.collect() at warning threshold
|
||||||
|
"""
|
||||||
|
self.max_memory_bytes = int(max_memory_gb * 1024 * 1024 * 1024)
|
||||||
|
self.warning_bytes = int(self.max_memory_bytes * warning_threshold)
|
||||||
|
self.check_interval = check_interval
|
||||||
|
self.auto_cleanup = auto_cleanup
|
||||||
|
|
||||||
|
self.process = psutil.Process()
|
||||||
|
self.running = False
|
||||||
|
self.monitor_thread = None
|
||||||
|
|
||||||
|
self.cleanup_callbacks = []
|
||||||
|
self.warning_count = 0
|
||||||
|
self.limit_exceeded_count = 0
|
||||||
|
|
||||||
|
logger.info(f"MemoryGuard initialized: max={max_memory_gb:.1f}GB, warning={max_memory_gb * warning_threshold:.1f}GB")
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Start memory monitoring in background thread"""
|
||||||
|
if self.running:
|
||||||
|
logger.warning("MemoryGuard already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.running = True
|
||||||
|
self.monitor_thread = threading.Thread(
|
||||||
|
target=self._monitor_loop,
|
||||||
|
daemon=True
|
||||||
|
)
|
||||||
|
self.monitor_thread.start()
|
||||||
|
logger.info("MemoryGuard monitoring started")
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stop memory monitoring"""
|
||||||
|
self.running = False
|
||||||
|
if self.monitor_thread:
|
||||||
|
self.monitor_thread.join(timeout=2.0)
|
||||||
|
logger.info("MemoryGuard monitoring stopped")
|
||||||
|
|
||||||
|
def register_cleanup_callback(self, callback: Callable):
|
||||||
|
"""
|
||||||
|
Register a callback to be called when memory warning is triggered
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: Function to call for cleanup (no arguments)
|
||||||
|
"""
|
||||||
|
self.cleanup_callbacks.append(callback)
|
||||||
|
|
||||||
|
def get_memory_usage(self) -> dict:
|
||||||
|
"""Get current memory usage statistics"""
|
||||||
|
mem_info = self.process.memory_info()
|
||||||
|
rss_bytes = mem_info.rss # Resident Set Size (actual RAM used)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'rss_bytes': rss_bytes,
|
||||||
|
'rss_gb': rss_bytes / (1024**3),
|
||||||
|
'rss_mb': rss_bytes / (1024**2),
|
||||||
|
'max_gb': self.max_memory_bytes / (1024**3),
|
||||||
|
'warning_gb': self.warning_bytes / (1024**3),
|
||||||
|
'usage_percent': (rss_bytes / self.max_memory_bytes) * 100,
|
||||||
|
'at_warning': rss_bytes >= self.warning_bytes,
|
||||||
|
'at_limit': rss_bytes >= self.max_memory_bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
def check_memory(self, raise_on_limit: bool = False) -> dict:
|
||||||
|
"""
|
||||||
|
Check current memory usage and take action if needed
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raise_on_limit: If True, raise MemoryError when limit exceeded
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Memory usage dict
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MemoryError: If raise_on_limit=True and limit exceeded
|
||||||
|
"""
|
||||||
|
usage = self.get_memory_usage()
|
||||||
|
|
||||||
|
if usage['at_limit']:
|
||||||
|
self.limit_exceeded_count += 1
|
||||||
|
logger.error(f"🔴 MEMORY LIMIT EXCEEDED: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")
|
||||||
|
|
||||||
|
# Aggressive cleanup
|
||||||
|
self._aggressive_cleanup()
|
||||||
|
|
||||||
|
if raise_on_limit:
|
||||||
|
raise MemoryError(
|
||||||
|
f"Memory limit exceeded: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB. "
|
||||||
|
f"Increase max_memory_gb or reduce batch size."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif usage['at_warning']:
|
||||||
|
self.warning_count += 1
|
||||||
|
logger.warning(f"⚠️ Memory warning: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")
|
||||||
|
|
||||||
|
if self.auto_cleanup:
|
||||||
|
self._trigger_cleanup()
|
||||||
|
|
||||||
|
return usage
|
||||||
|
|
||||||
|
def _monitor_loop(self):
|
||||||
|
"""Background monitoring loop"""
|
||||||
|
logger.info(f"MemoryGuard monitoring loop started (checking every {self.check_interval}s)")
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
self.check_memory(raise_on_limit=False)
|
||||||
|
time.sleep(self.check_interval)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in MemoryGuard monitor loop: {e}")
|
||||||
|
time.sleep(self.check_interval * 2)
|
||||||
|
|
||||||
|
def _trigger_cleanup(self):
|
||||||
|
"""Trigger cleanup callbacks and garbage collection"""
|
||||||
|
logger.info("Triggering memory cleanup...")
|
||||||
|
|
||||||
|
# Call registered callbacks
|
||||||
|
for callback in self.cleanup_callbacks:
|
||||||
|
try:
|
||||||
|
callback()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in cleanup callback: {e}")
|
||||||
|
|
||||||
|
# Run garbage collection
|
||||||
|
collected = gc.collect()
|
||||||
|
logger.info(f"Garbage collection freed {collected} objects")
|
||||||
|
|
||||||
|
# Clear CUDA cache if available
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
logger.info("Cleared CUDA cache")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _aggressive_cleanup(self):
|
||||||
|
"""Aggressive cleanup when limit exceeded"""
|
||||||
|
logger.warning("Running AGGRESSIVE memory cleanup...")
|
||||||
|
|
||||||
|
# Multiple GC passes
|
||||||
|
for i in range(3):
|
||||||
|
collected = gc.collect()
|
||||||
|
logger.info(f"GC pass {i+1}: freed {collected} objects")
|
||||||
|
|
||||||
|
# Clear CUDA cache
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
logger.info("Cleared CUDA cache and reset stats")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Call cleanup callbacks
|
||||||
|
for callback in self.cleanup_callbacks:
|
||||||
|
try:
|
||||||
|
callback()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in cleanup callback: {e}")
|
||||||
|
|
||||||
|
# Check if cleanup helped
|
||||||
|
usage = self.get_memory_usage()
|
||||||
|
logger.info(f"After cleanup: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")
|
||||||
|
|
||||||
|
|
||||||
|
# Global instance
|
||||||
|
_memory_guard = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_memory_guard(max_memory_gb: float = 50.0,
|
||||||
|
warning_threshold: float = 0.85,
|
||||||
|
auto_start: bool = True) -> MemoryGuard:
|
||||||
|
"""
|
||||||
|
Get or create global MemoryGuard instance
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_memory_gb: Maximum allowed memory in GB
|
||||||
|
warning_threshold: Trigger cleanup at this fraction of max
|
||||||
|
auto_start: Automatically start monitoring
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MemoryGuard instance
|
||||||
|
"""
|
||||||
|
global _memory_guard
|
||||||
|
|
||||||
|
if _memory_guard is None:
|
||||||
|
_memory_guard = MemoryGuard(
|
||||||
|
max_memory_gb=max_memory_gb,
|
||||||
|
warning_threshold=warning_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
if auto_start:
|
||||||
|
_memory_guard.start()
|
||||||
|
|
||||||
|
return _memory_guard
|
||||||
|
|
||||||
|
|
||||||
|
def check_memory(raise_on_limit: bool = False) -> dict:
|
||||||
|
"""
|
||||||
|
Quick check of current memory usage
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raise_on_limit: If True, raise MemoryError when limit exceeded
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Memory usage dict
|
||||||
|
"""
|
||||||
|
guard = get_memory_guard()
|
||||||
|
return guard.check_memory(raise_on_limit=raise_on_limit)
|
||||||
|
|
||||||
|
|
||||||
|
def log_memory_usage(prefix: str = ""):
|
||||||
|
"""Log current memory usage"""
|
||||||
|
usage = check_memory()
|
||||||
|
logger.info(f"{prefix}Memory: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")
|
||||||
Reference in New Issue
Block a user