fix training loss calculation;

added memory guard
Dobromir Popov
2025-11-13 15:58:42 +02:00
parent bf2a6cf96e
commit 13b6fafaf8
3 changed files with 357 additions and 20 deletions


@@ -1744,6 +1744,21 @@ class RealTrainingAdapter:
import torch
import gc
# Initialize memory guard (50GB limit)
from utils.memory_guard import get_memory_guard, log_memory_usage
memory_guard = get_memory_guard(max_memory_gb=50.0, warning_threshold=0.85, auto_start=True)
# Register cleanup callback
def training_cleanup():
"""Cleanup callback for memory guard"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
memory_guard.register_cleanup_callback(training_cleanup)
log_memory_usage("Training start - ")
# Load best checkpoint if available to continue training
try:
checkpoint_dir = "models/checkpoints/transformer"
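For context, the memory-guard wiring introduced in this hunk follows a simple pattern; the sketch below shows the same pattern in isolation, assuming only the utils.memory_guard module added by this commit. The train_one_epoch function is a hypothetical placeholder, not code from the repository.

import gc
import torch
from utils.memory_guard import get_memory_guard, log_memory_usage

def train_one_epoch():
    """Hypothetical placeholder for the real training loop."""
    pass

def run_training(num_epochs: int = 10):
    # 50GB hard limit; cleanup callbacks fire at 85% of that (42.5GB)
    guard = get_memory_guard(max_memory_gb=50.0, warning_threshold=0.85, auto_start=True)

    def training_cleanup():
        # Same cleanup the adapter registers: GC plus CUDA cache flush
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    guard.register_cleanup_callback(training_cleanup)
    log_memory_usage("Training start - ")
    try:
        for epoch in range(num_epochs):
            train_one_epoch()
            log_memory_usage(f"Epoch {epoch + 1} end - ")
    finally:
        log_memory_usage("Training complete - ")
        guard.stop()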
@@ -1828,9 +1843,12 @@ class RealTrainingAdapter:
batch_loss = float(result.get('total_loss', 0.0))
batch_accuracy = float(result.get('accuracy', 0.0))
batch_candle_accuracy = float(result.get('candle_accuracy', 0.0))
batch_trend_accuracy = float(result.get('trend_accuracy', 0.0))
batch_action_accuracy = float(result.get('action_accuracy', 0.0))
batch_trend_loss = float(result.get('trend_loss', 0.0))
batch_candle_loss = float(result.get('candle_loss', 0.0))
batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
batch_candle_rmse = result.get('candle_rmse', {})
epoch_loss += batch_loss
epoch_accuracy += batch_accuracy
@@ -1838,13 +1856,20 @@ class RealTrainingAdapter:
# Log first batch and every 5th batch
if (i + 1) == 1 or (i + 1) % 5 == 0:
# Format RMSE values (normalized space)
rmse_str = ""
if batch_candle_rmse:
rmse_str = f", RMSE: O={batch_candle_rmse.get('open', 0):.4f} H={batch_candle_rmse.get('high', 0):.4f} L={batch_candle_rmse.get('low', 0):.4f} C={batch_candle_rmse.get('close', 0):.4f}"
# Format denormalized RMSE (real prices)
denorm_str = ""
if batch_candle_loss_denorm:
# Values are now RMSE, so the magnitudes are much more reasonable
denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
denorm_str = f", Real Price RMSE: {', '.join(denorm_values)}"
denorm_str = f", Real RMSE: {', '.join(denorm_values)}"
logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, "
f"Candle Acc: {batch_accuracy:.1%}, Trend Acc: {batch_trend_accuracy:.1%}, "
f"Action Acc: {batch_action_accuracy:.1%}{rmse_str}{denorm_str}")
else:
logger.warning(f" Batch {i + 1} returned None result - skipping")
@@ -1966,6 +1991,9 @@ class RealTrainingAdapter:
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Check memory usage
log_memory_usage(f" Epoch {epoch + 1} end - ")
logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
session.final_loss = session.current_loss
@@ -1978,6 +2006,10 @@ class RealTrainingAdapter:
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Final memory check
log_memory_usage("Training complete - ")
memory_guard.stop()
# Log best checkpoint info
try:
checkpoint_dir = "models/checkpoints/transformer"


@@ -1407,23 +1407,63 @@ class TradingTransformerTrainer:
self.scheduler.step()
# Calculate accuracy without gradients
# PRIMARY: Next candle OHLCV prediction accuracy (realistic values)
with torch.no_grad():
predictions = torch.argmax(outputs['action_logits'], dim=-1)
accuracy = (predictions == batch['actions']).float().mean()
# Calculate candle prediction accuracy (price direction)
candle_accuracy = 0.0
if 'next_candles' in outputs and 'future_prices' in batch:
# Use 1m timeframe prediction as primary
if '1m' in outputs['next_candles']:
predicted_candle = outputs['next_candles']['1m'] # [batch, 5]
# Predicted close is the 4th value (index 3)
predicted_close_change = predicted_candle[:, 3] # Predicted close price change
actual_close_change = batch['future_prices'] # Actual price change ratio
candle_rmse = {}
if 'next_candles' in outputs:
# Use 1m timeframe as primary metric
if '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
pred_candle = outputs['next_candles']['1m'] # [batch, 5]
actual_candle = batch['future_candle_1m'] # [batch, 5]
# Check if direction matches (both positive or both negative)
direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
candle_accuracy = direction_match.mean().item()
if actual_candle is not None and pred_candle.shape == actual_candle.shape:
# Calculate RMSE for each OHLCV component
rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
# Average RMSE for OHLC (exclude volume)
avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
# Convert to accuracy: lower RMSE = higher accuracy
# Normalize by price range
price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
candle_rmse = {
'open': rmse_open.item(),
'high': rmse_high.item(),
'low': rmse_low.item(),
'close': rmse_close.item(),
'avg': avg_rmse.item()
}
# SECONDARY: Trend vector prediction accuracy
trend_accuracy = 0.0
if 'trend_analysis' in outputs and 'trend_target' in batch:
pred_angle = outputs['trend_analysis']['angle_radians']
pred_steepness = outputs['trend_analysis']['steepness']
actual_angle = batch['trend_target'][:, 0:1]
actual_steepness = batch['trend_target'][:, 1:2]
# Angle error (degrees)
angle_error_rad = torch.abs(pred_angle - actual_angle)
angle_error_deg = angle_error_rad * 180.0 / 3.14159
angle_accuracy = (1.0 - torch.clamp(angle_error_deg / 180.0, 0, 1)).mean()
# Steepness error (percentage)
steepness_error = torch.abs(pred_steepness - actual_steepness) / (actual_steepness + 1e-8)
steepness_accuracy = (1.0 - torch.clamp(steepness_error, 0, 1)).mean()
trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()
# LEGACY: Action accuracy (for comparison)
action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
action_accuracy = (action_predictions == batch['actions']).float().mean().item()
# Extract values and delete tensors to free memory
result = {
@@ -1433,14 +1473,20 @@ class TradingTransformerTrainer:
'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,
'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0,
'candle_loss_denorm': candle_losses_denorm, # Dict of denormalized losses per timeframe
'accuracy': accuracy.item(),
'candle_accuracy': candle_accuracy,
# NEW: Realistic accuracy metrics based on next candle prediction
'accuracy': candle_accuracy, # PRIMARY: Next candle prediction accuracy
'candle_accuracy': candle_accuracy, # Same as accuracy
'candle_rmse': candle_rmse, # Detailed RMSE per OHLC component
'trend_accuracy': trend_accuracy, # Trend vector accuracy (angle + steepness)
'action_accuracy': action_accuracy, # Legacy action accuracy
'learning_rate': self.scheduler.get_last_lr()[0]
}
# CRITICAL: Delete large tensors to free memory immediately
# This prevents memory accumulation across batches
del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, predictions, accuracy
del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, action_predictions
if torch.cuda.is_available():
torch.cuda.empty_cache()
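To make the new accuracy definitions concrete, here is a self-contained sketch that mirrors the computations above on dummy tensors; the shapes and random values are invented for illustration, and the constants match those used in train_step.

import torch

torch.manual_seed(0)
batch_size = 8

# Candle accuracy: 1 - clamp(average OHLC RMSE / price range, 0, 1)
pred_candle = torch.randn(batch_size, 5) * 0.01             # [batch, OHLCV], normalized space
actual_candle = pred_candle + torch.randn(batch_size, 5) * 0.005

rmse_per_component = torch.sqrt(torch.mean((pred_candle[:, :4] - actual_candle[:, :4]) ** 2, dim=0) + 1e-8)
avg_rmse = rmse_per_component.mean()                         # average over O, H, L, C (volume excluded)
price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()

# Trend accuracy: mean of angle accuracy and steepness accuracy
pred_angle = torch.rand(batch_size, 1) * 3.14159
actual_angle = pred_angle + torch.randn(batch_size, 1) * 0.1
pred_steepness = torch.rand(batch_size, 1)
actual_steepness = torch.clamp(pred_steepness + torch.randn(batch_size, 1) * 0.05, min=1e-3)

angle_error_deg = torch.abs(pred_angle - actual_angle) * 180.0 / 3.14159
angle_accuracy = (1.0 - torch.clamp(angle_error_deg / 180.0, 0, 1)).mean()
steepness_error = torch.abs(pred_steepness - actual_steepness) / (actual_steepness + 1e-8)
steepness_accuracy = (1.0 - torch.clamp(steepness_error, 0, 1)).mean()
trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()

print(f"candle_accuracy={candle_accuracy:.2%}, trend_accuracy={trend_accuracy:.2%}")

With a well-trained model the normalized OHLC RMSE shrinks relative to the batch's high-low range, which pushes candle_accuracy toward 100%; a prediction whose average error equals the whole range scores 0%.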

utils/memory_guard.py (new file, 259 lines)

@@ -0,0 +1,259 @@
"""
Memory Guard - prevents the Python process from exceeding a memory limit.
Monitors memory usage in the background and triggers aggressive cleanup when
approaching the limit; can also enforce a hard limit by raising MemoryError.
"""
import psutil
import gc
import logging
import threading
import time
from typing import Optional, Callable
logger = logging.getLogger(__name__)
class MemoryGuard:
"""
Memory usage monitor and enforcer
Features:
- Monitors RAM usage in background
- Triggers cleanup at warning threshold
- Raises exception at hard limit
- Logs memory statistics
"""
def __init__(self,
max_memory_gb: float = 50.0,
warning_threshold: float = 0.85,
check_interval: float = 5.0,
auto_cleanup: bool = True):
"""
Initialize Memory Guard
Args:
max_memory_gb: Maximum allowed memory in GB (default: 50GB)
warning_threshold: Trigger cleanup at this fraction of max (default: 0.85 = 42.5GB)
check_interval: Check memory every N seconds (default: 5.0)
auto_cleanup: Automatically run gc.collect() at warning threshold
"""
self.max_memory_bytes = int(max_memory_gb * 1024 * 1024 * 1024)
self.warning_bytes = int(self.max_memory_bytes * warning_threshold)
self.check_interval = check_interval
self.auto_cleanup = auto_cleanup
self.process = psutil.Process()
self.running = False
self.monitor_thread = None
self.cleanup_callbacks = []
self.warning_count = 0
self.limit_exceeded_count = 0
logger.info(f"MemoryGuard initialized: max={max_memory_gb:.1f}GB, warning={max_memory_gb * warning_threshold:.1f}GB")
def start(self):
"""Start memory monitoring in background thread"""
if self.running:
logger.warning("MemoryGuard already running")
return
self.running = True
self.monitor_thread = threading.Thread(
target=self._monitor_loop,
daemon=True
)
self.monitor_thread.start()
logger.info("MemoryGuard monitoring started")
def stop(self):
"""Stop memory monitoring"""
self.running = False
if self.monitor_thread:
self.monitor_thread.join(timeout=2.0)
logger.info("MemoryGuard monitoring stopped")
def register_cleanup_callback(self, callback: Callable):
"""
Register a callback to be called when memory warning is triggered
Args:
callback: Function to call for cleanup (no arguments)
"""
self.cleanup_callbacks.append(callback)
def get_memory_usage(self) -> dict:
"""Get current memory usage statistics"""
mem_info = self.process.memory_info()
rss_bytes = mem_info.rss # Resident Set Size (actual RAM used)
return {
'rss_bytes': rss_bytes,
'rss_gb': rss_bytes / (1024**3),
'rss_mb': rss_bytes / (1024**2),
'max_gb': self.max_memory_bytes / (1024**3),
'warning_gb': self.warning_bytes / (1024**3),
'usage_percent': (rss_bytes / self.max_memory_bytes) * 100,
'at_warning': rss_bytes >= self.warning_bytes,
'at_limit': rss_bytes >= self.max_memory_bytes
}
def check_memory(self, raise_on_limit: bool = False) -> dict:
"""
Check current memory usage and take action if needed
Args:
raise_on_limit: If True, raise MemoryError when limit exceeded
Returns:
Memory usage dict
Raises:
MemoryError: If raise_on_limit=True and limit exceeded
"""
usage = self.get_memory_usage()
if usage['at_limit']:
self.limit_exceeded_count += 1
logger.error(f"🔴 MEMORY LIMIT EXCEEDED: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")
# Aggressive cleanup
self._aggressive_cleanup()
if raise_on_limit:
raise MemoryError(
f"Memory limit exceeded: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB. "
f"Increase max_memory_gb or reduce batch size."
)
elif usage['at_warning']:
self.warning_count += 1
logger.warning(f"⚠️ Memory warning: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")
if self.auto_cleanup:
self._trigger_cleanup()
return usage
def _monitor_loop(self):
"""Background monitoring loop"""
logger.info(f"MemoryGuard monitoring loop started (checking every {self.check_interval}s)")
while self.running:
try:
self.check_memory(raise_on_limit=False)
time.sleep(self.check_interval)
except Exception as e:
logger.error(f"Error in MemoryGuard monitor loop: {e}")
time.sleep(self.check_interval * 2)
def _trigger_cleanup(self):
"""Trigger cleanup callbacks and garbage collection"""
logger.info("Triggering memory cleanup...")
# Call registered callbacks
for callback in self.cleanup_callbacks:
try:
callback()
except Exception as e:
logger.error(f"Error in cleanup callback: {e}")
# Run garbage collection
collected = gc.collect()
logger.info(f"Garbage collection freed {collected} objects")
# Clear CUDA cache if available
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
logger.info("Cleared CUDA cache")
except Exception:  # torch unavailable or CUDA not initialized - nothing to clear
pass
def _aggressive_cleanup(self):
"""Aggressive cleanup when limit exceeded"""
logger.warning("Running AGGRESSIVE memory cleanup...")
# Multiple GC passes
for i in range(3):
collected = gc.collect()
logger.info(f"GC pass {i+1}: freed {collected} objects")
# Clear CUDA cache
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
logger.info("Cleared CUDA cache and reset stats")
except Exception:  # torch unavailable or CUDA not initialized - nothing to clear
pass
# Call cleanup callbacks
for callback in self.cleanup_callbacks:
try:
callback()
except Exception as e:
logger.error(f"Error in cleanup callback: {e}")
# Check if cleanup helped
usage = self.get_memory_usage()
logger.info(f"After cleanup: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB")
# Global instance
_memory_guard = None
def get_memory_guard(max_memory_gb: float = 50.0,
warning_threshold: float = 0.85,
auto_start: bool = True) -> MemoryGuard:
"""
Get or create global MemoryGuard instance
Args:
max_memory_gb: Maximum allowed memory in GB
warning_threshold: Trigger cleanup at this fraction of max
auto_start: Automatically start monitoring
Returns:
MemoryGuard instance
"""
global _memory_guard
if _memory_guard is None:
_memory_guard = MemoryGuard(
max_memory_gb=max_memory_gb,
warning_threshold=warning_threshold
)
if auto_start:
_memory_guard.start()
return _memory_guard
def check_memory(raise_on_limit: bool = False) -> dict:
"""
Quick check of current memory usage
Args:
raise_on_limit: If True, raise MemoryError when limit exceeded
Returns:
Memory usage dict
"""
guard = get_memory_guard()
return guard.check_memory(raise_on_limit=raise_on_limit)
def log_memory_usage(prefix: str = ""):
"""Log current memory usage"""
usage = check_memory()
logger.info(f"{prefix}Memory: {usage['rss_gb']:.2f}GB / {usage['max_gb']:.1f}GB ({usage['usage_percent']:.1f}%)")