Fix training loss calculation; add memory guard
@@ -1744,6 +1744,21 @@ class RealTrainingAdapter:
 import torch
+import gc
 
+# Initialize memory guard (50GB limit)
+from utils.memory_guard import get_memory_guard, log_memory_usage
+memory_guard = get_memory_guard(max_memory_gb=50.0, warning_threshold=0.85, auto_start=True)
+
+# Register cleanup callback
+def training_cleanup():
+    """Cleanup callback for memory guard"""
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+memory_guard.register_cleanup_callback(training_cleanup)
+log_memory_usage("Training start - ")
 
 # Load best checkpoint if available to continue training
 try:
     checkpoint_dir = "models/checkpoints/transformer"
@@ -1828,9 +1843,12 @@ class RealTrainingAdapter:
 batch_loss = float(result.get('total_loss', 0.0))
 batch_accuracy = float(result.get('accuracy', 0.0))
 batch_candle_accuracy = float(result.get('candle_accuracy', 0.0))
+batch_trend_accuracy = float(result.get('trend_accuracy', 0.0))
+batch_action_accuracy = float(result.get('action_accuracy', 0.0))
 batch_trend_loss = float(result.get('trend_loss', 0.0))
 batch_candle_loss = float(result.get('candle_loss', 0.0))
 batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
+batch_candle_rmse = result.get('candle_rmse', {})
 
 epoch_loss += batch_loss
 epoch_accuracy += batch_accuracy
@@ -1838,13 +1856,20 @@ class RealTrainingAdapter:
 # Log first batch and every 5th batch
 if (i + 1) == 1 or (i + 1) % 5 == 0:
+    # Format RMSE values (normalized space)
+    rmse_str = ""
+    if batch_candle_rmse:
+        rmse_str = f", RMSE: O={batch_candle_rmse.get('open', 0):.4f} H={batch_candle_rmse.get('high', 0):.4f} L={batch_candle_rmse.get('low', 0):.4f} C={batch_candle_rmse.get('close', 0):.4f}"
+
     # Format denormalized RMSE (real prices)
     denorm_str = ""
     if batch_candle_loss_denorm:
         # RMSE values now, much more reasonable
         denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
-        denorm_str = f", Real Price RMSE: {', '.join(denorm_values)}"
+        denorm_str = f", Real RMSE: {', '.join(denorm_values)}"
 
-    logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
+    logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, "
+                f"Candle Acc: {batch_accuracy:.1%}, Trend Acc: {batch_trend_accuracy:.1%}, "
+                f"Action Acc: {batch_action_accuracy:.1%}{rmse_str}{denorm_str}")
 else:
     logger.warning(f" Batch {i + 1} returned None result - skipping")
@@ -1966,6 +1991,9 @@ class RealTrainingAdapter:
     torch.cuda.empty_cache()
     torch.cuda.synchronize()
 
+# Check memory usage
+log_memory_usage(f" Epoch {epoch + 1} end - ")
+
 logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
 
 session.final_loss = session.current_loss
@@ -1978,6 +2006,10 @@ class RealTrainingAdapter:
     torch.cuda.empty_cache()
     torch.cuda.synchronize()
 
+# Final memory check
+log_memory_usage("Training complete - ")
+memory_guard.stop()
+
 # Log best checkpoint info
 try:
     checkpoint_dir = "models/checkpoints/transformer"
 
@@ -1407,23 +1407,63 @@ class TradingTransformerTrainer:
         self.scheduler.step()
 
-        # Calculate accuracy without gradients
+        # PRIMARY: Next candle OHLCV prediction accuracy (realistic values)
         with torch.no_grad():
-            predictions = torch.argmax(outputs['action_logits'], dim=-1)
-            accuracy = (predictions == batch['actions']).float().mean()
-
-            # Calculate candle prediction accuracy (price direction)
             candle_accuracy = 0.0
-            if 'next_candles' in outputs and 'future_prices' in batch:
-                # Use 1m timeframe prediction as primary
-                if '1m' in outputs['next_candles']:
-                    predicted_candle = outputs['next_candles']['1m']  # [batch, 5]
-                    # Predicted close is the 4th value (index 3)
-                    predicted_close_change = predicted_candle[:, 3]  # Predicted close price change
-                    actual_close_change = batch['future_prices']  # Actual price change ratio
-
-                    # Check if direction matches (both positive or both negative)
-                    direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
-                    candle_accuracy = direction_match.mean().item()
+            candle_rmse = {}
+
+            if 'next_candles' in outputs:
+                # Use 1m timeframe as primary metric
+                if '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
+                    pred_candle = outputs['next_candles']['1m']  # [batch, 5]
+                    actual_candle = batch['future_candle_1m']  # [batch, 5]
+
+                    if actual_candle is not None and pred_candle.shape == actual_candle.shape:
+                        # Calculate RMSE for each OHLCV component
+                        rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
+                        rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
+                        rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
+                        rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
+
+                        # Average RMSE for OHLC (exclude volume)
+                        avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
+
+                        # Convert to accuracy: lower RMSE = higher accuracy
+                        # Normalize by price range
+                        price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
+                        candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
+
+                        candle_rmse = {
+                            'open': rmse_open.item(),
+                            'high': rmse_high.item(),
+                            'low': rmse_low.item(),
+                            'close': rmse_close.item(),
+                            'avg': avg_rmse.item()
+                        }
+
+            # SECONDARY: Trend vector prediction accuracy
+            trend_accuracy = 0.0
+            if 'trend_analysis' in outputs and 'trend_target' in batch:
+                pred_angle = outputs['trend_analysis']['angle_radians']
+                pred_steepness = outputs['trend_analysis']['steepness']
+
+                actual_angle = batch['trend_target'][:, 0:1]
+                actual_steepness = batch['trend_target'][:, 1:2]
+
+                # Angle error (degrees)
+                angle_error_rad = torch.abs(pred_angle - actual_angle)
+                angle_error_deg = angle_error_rad * 180.0 / 3.14159
+                angle_accuracy = (1.0 - torch.clamp(angle_error_deg / 180.0, 0, 1)).mean()
+
+                # Steepness error (percentage)
+                steepness_error = torch.abs(pred_steepness - actual_steepness) / (actual_steepness + 1e-8)
+                steepness_accuracy = (1.0 - torch.clamp(steepness_error, 0, 1)).mean()
+
+                trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()
+
+            # LEGACY: Action accuracy (for comparison)
+            action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
+            action_accuracy = (action_predictions == batch['actions']).float().mean().item()
 
         # Extract values and delete tensors to free memory
         result = {
@@ -1433,14 +1473,20 @@ class TradingTransformerTrainer:
             'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,
             'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0,
             'candle_loss_denorm': candle_losses_denorm,  # Dict of denormalized losses per timeframe
-            'accuracy': accuracy.item(),
-            'candle_accuracy': candle_accuracy,
+
+            # NEW: Realistic accuracy metrics based on next candle prediction
+            'accuracy': candle_accuracy,  # PRIMARY: Next candle prediction accuracy
+            'candle_accuracy': candle_accuracy,  # Same as accuracy
+            'candle_rmse': candle_rmse,  # Detailed RMSE per OHLC component
+            'trend_accuracy': trend_accuracy,  # Trend vector accuracy (angle + steepness)
+            'action_accuracy': action_accuracy,  # Legacy action accuracy
+
             'learning_rate': self.scheduler.get_last_lr()[0]
         }
 
         # CRITICAL: Delete large tensors to free memory immediately
         # This prevents memory accumulation across batches
-        del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, predictions, accuracy
+        del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, action_predictions
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
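For quick reference, below is a minimal standalone sketch of the RMSE-based candle accuracy introduced in the hunk above, run on made-up tensors. The helper name, batch size, and dummy data are illustrative only and are not part of the commit.

# Standalone sketch (not part of the diff): RMSE-based candle accuracy on dummy tensors.
import torch

def candle_accuracy_from_rmse(pred_candle: torch.Tensor, actual_candle: torch.Tensor):
    """pred_candle / actual_candle: [batch, 5] OHLCV in the same normalized space."""
    # Per-component RMSE over the batch; epsilon keeps sqrt stable at zero error
    rmse = {
        name: torch.sqrt(torch.mean((pred_candle[:, i] - actual_candle[:, i]) ** 2) + 1e-8)
        for i, name in enumerate(['open', 'high', 'low', 'close'])
    }
    avg_rmse = sum(rmse.values()) / 4  # OHLC only, volume excluded
    # Normalize by the batch price range so the score lands in [0, 1]
    price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
    accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
    return accuracy, {k: v.item() for k, v in rmse.items()}

if __name__ == "__main__":
    torch.manual_seed(0)
    actual = torch.rand(8, 5)                  # fake normalized OHLCV batch
    pred = actual + 0.01 * torch.randn(8, 5)   # predictions close to ground truth
    acc, rmse = candle_accuracy_from_rmse(pred, actual)
    print(f"candle accuracy: {acc:.3f}, close RMSE: {rmse['close']:.4f}")

Good predictions push the average RMSE well below the batch price range, so the score approaches 1; predictions off by more than the price range clamp to 0.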
utils/memory_guard.py (new file, 259 added lines)
@@ -0,0 +1,259 @@
"""
Memory Guard - Prevents Python process from exceeding memory limit

Monitors memory usage and triggers aggressive cleanup when approaching limit.
Can also enforce hard limit by raising exception.
"""

import psutil
import gc
import logging
import threading
import time
from typing import Optional, Callable

logger = logging.getLogger(__name__)


class MemoryGuard:
    """
    Memory usage monitor and enforcer

    Features:
    - Monitors RAM usage in background
    - Triggers cleanup at warning threshold
    - Raises exception at hard limit
    - Logs memory statistics
    """

    def __init__(self,
                 max_memory_gb: float = 50.0,
                 warning_threshold: float = 0.85,
                 check_interval: float = 5.0,
                 auto_cleanup: bool = True):
        """
        Initialize Memory Guard

        Args:
            max_memory_gb: Maximum allowed memory in GB (default: 50GB)
            warning_threshold: Trigger cleanup at this fraction of max (default: 0.85 = 42.5GB)
            check_interval: Check memory every N seconds (default: 5.0)
            auto_cleanup: Automatically run gc.collect() at warning threshold
        """
        self.max_memory_bytes = int(max_memory_gb * 1024 * 1024 * 1024)
        self.warning_bytes = int(self.max_memory_bytes * warning_threshold)
        self.check_interval = check_interval
        self.auto_cleanup = auto_cleanup

        self.process = psutil.Process()
        self.running = False
        self.monitor_thread = None

        self.cleanup_callbacks = []
        self.warning_count = 0
        self.limit_exceeded_count = 0

        logger.info(f"MemoryGuard initialized: max={max_memory_gb:.1f}GB, warning={max_memory_gb * warning_threshold:.1f}GB")

    def start(self):
        """Start memory monitoring in background thread"""
        if self.running:
            logger.warning("MemoryGuard already running")
            return

        self.running = True
        self.monitor_thread = threading.Thread(
            target=self._monitor_loop,
            daemon=True
        )
        self.monitor_thread.start()
        logger.info("MemoryGuard monitoring started")

    def stop(self):
        """Stop memory monitoring"""
        self.running = False
        if self.monitor_thread:
            self.monitor_thread.join(timeout=2.0)
        logger.info("MemoryGuard monitoring stopped")

    def register_cleanup_callback(self, callback: Callable):
        """
        Register a callback to be called when memory warning is triggered

        Args:
            callback: Function to call for cleanup (no arguments)
        """
        self.cleanup_callbacks.append(callback)

    def get_memory_usage(self) -> dict:
        """Get current memory usage statistics"""
        mem_info = self.process.memory_info()
        rss_bytes = mem_info.rss  # Resident Set Size (actual RAM used)

        return {
            'rss_bytes': rss_bytes,
            'rss_gb': rss_bytes / (1024**3),
            'rss_mb': rss_bytes / (1024**2),
            'max_gb': self.max_memory_bytes / (1024**3),
            'warning_gb': self.warning_bytes / (1024**3),
            'usage_percent': (rss_bytes / self.max_memory_bytes) * 100,
            'at_warning': rss_bytes >= self.warning_bytes,
            'at_limit': rss_bytes >= self.max_memory_bytes
        }

    def check_memory(self, raise_on_limit: bool = False) -> dict:
        """
        Check current memory usage and take action if needed

        Args:
            raise_on_limit: If True, raise MemoryError when limit exceeded

        Returns:
            Memory usage dict

        Raises:
            MemoryError: If raise_on_limit=True and limit exceeded
        """
        usage = self.get_memory_usage()

        if usage['at_limit']:
            self.limit_exceeded_count += 1
            logger.error(f"🔴 MEMORY LIMIT EXCEEDED: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")

            # Aggressive cleanup
            self._aggressive_cleanup()

            if raise_on_limit:
                raise MemoryError(
                    f"Memory limit exceeded: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB. "
                    f"Increase max_memory_gb or reduce batch size."
                )

        elif usage['at_warning']:
            self.warning_count += 1
            logger.warning(f"⚠️ Memory warning: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")

            if self.auto_cleanup:
                self._trigger_cleanup()

        return usage

    def _monitor_loop(self):
        """Background monitoring loop"""
        logger.info(f"MemoryGuard monitoring loop started (checking every {self.check_interval}s)")

        while self.running:
            try:
                self.check_memory(raise_on_limit=False)
                time.sleep(self.check_interval)
            except Exception as e:
                logger.error(f"Error in MemoryGuard monitor loop: {e}")
                time.sleep(self.check_interval * 2)

    def _trigger_cleanup(self):
        """Trigger cleanup callbacks and garbage collection"""
        logger.info("Triggering memory cleanup...")

        # Call registered callbacks
        for callback in self.cleanup_callbacks:
            try:
                callback()
            except Exception as e:
                logger.error(f"Error in cleanup callback: {e}")

        # Run garbage collection
        collected = gc.collect()
        logger.info(f"Garbage collection freed {collected} objects")

        # Clear CUDA cache if available
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                logger.info("Cleared CUDA cache")
        except:
            pass

    def _aggressive_cleanup(self):
        """Aggressive cleanup when limit exceeded"""
        logger.warning("Running AGGRESSIVE memory cleanup...")

        # Multiple GC passes
        for i in range(3):
            collected = gc.collect()
            logger.info(f"GC pass {i+1}: freed {collected} objects")

        # Clear CUDA cache
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                torch.cuda.reset_peak_memory_stats()
                logger.info("Cleared CUDA cache and reset stats")
        except:
            pass

        # Call cleanup callbacks
        for callback in self.cleanup_callbacks:
            try:
                callback()
            except Exception as e:
                logger.error(f"Error in cleanup callback: {e}")

        # Check if cleanup helped
        usage = self.get_memory_usage()
        logger.info(f"After cleanup: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")


# Global instance
_memory_guard = None


def get_memory_guard(max_memory_gb: float = 50.0,
                     warning_threshold: float = 0.85,
                     auto_start: bool = True) -> MemoryGuard:
    """
    Get or create global MemoryGuard instance

    Args:
        max_memory_gb: Maximum allowed memory in GB
        warning_threshold: Trigger cleanup at this fraction of max
        auto_start: Automatically start monitoring

    Returns:
        MemoryGuard instance
    """
    global _memory_guard

    if _memory_guard is None:
        _memory_guard = MemoryGuard(
            max_memory_gb=max_memory_gb,
            warning_threshold=warning_threshold
        )

        if auto_start:
            _memory_guard.start()

    return _memory_guard


def check_memory(raise_on_limit: bool = False) -> dict:
    """
    Quick check of current memory usage

    Args:
        raise_on_limit: If True, raise MemoryError when limit exceeded

    Returns:
        Memory usage dict
    """
    guard = get_memory_guard()
    return guard.check_memory(raise_on_limit=raise_on_limit)


def log_memory_usage(prefix: str = ""):
    """Log current memory usage"""
    usage = check_memory()
    logger.info(f"{prefix}Memory: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")
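
A minimal usage sketch of the module above, mirroring the wiring in the RealTrainingAdapter hunk. It assumes utils.memory_guard is importable from the project root; the 8 GB limit, the throwaway allocations, and the callback choice are example values only.

# Usage sketch (not part of the diff): wire up the guard, do some work, check memory, shut down.
import gc
from utils.memory_guard import get_memory_guard, log_memory_usage

guard = get_memory_guard(max_memory_gb=8.0, warning_threshold=0.85, auto_start=True)
guard.register_cleanup_callback(gc.collect)  # any zero-argument callable works

log_memory_usage("Example run - ")
for step in range(3):
    _ = [bytearray(10_000_000) for _ in range(5)]  # simulate transient allocations
    guard.check_memory(raise_on_limit=False)       # manual check between steps
log_memory_usage("Example done - ")
guard.stop()

The background thread already polls every check_interval seconds; the explicit check_memory() calls here just show how to force a check at known points, for example at batch or epoch boundaries.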