fix training loss calculation; add memory guard
@@ -1744,6 +1744,21 @@ class RealTrainingAdapter:
         import torch
+        import gc
 
+        # Initialize memory guard (50GB limit)
+        from utils.memory_guard import get_memory_guard, log_memory_usage
+        memory_guard = get_memory_guard(max_memory_gb=50.0, warning_threshold=0.85, auto_start=True)
+
+        # Register cleanup callback
+        def training_cleanup():
+            """Cleanup callback for memory guard"""
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()
+
+        memory_guard.register_cleanup_callback(training_cleanup)
+        log_memory_usage("Training start - ")
 
         # Load best checkpoint if available to continue training
         try:
             checkpoint_dir = "models/checkpoints/transformer"
@@ -1828,9 +1843,12 @@ class RealTrainingAdapter:
                     batch_loss = float(result.get('total_loss', 0.0))
                     batch_accuracy = float(result.get('accuracy', 0.0))
+                    batch_candle_accuracy = float(result.get('candle_accuracy', 0.0))
+                    batch_trend_accuracy = float(result.get('trend_accuracy', 0.0))
+                    batch_action_accuracy = float(result.get('action_accuracy', 0.0))
                     batch_trend_loss = float(result.get('trend_loss', 0.0))
                     batch_candle_loss = float(result.get('candle_loss', 0.0))
                     batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
                     batch_candle_rmse = result.get('candle_rmse', {})
 
                     epoch_loss += batch_loss
                     epoch_accuracy += batch_accuracy
@@ -1838,13 +1856,20 @@ class RealTrainingAdapter:
 
                     # Log first batch and every 5th batch
                     if (i + 1) == 1 or (i + 1) % 5 == 0:
+                        # Format RMSE values (normalized space)
+                        rmse_str = ""
+                        if batch_candle_rmse:
+                            rmse_str = f", RMSE: O={batch_candle_rmse.get('open', 0):.4f} H={batch_candle_rmse.get('high', 0):.4f} L={batch_candle_rmse.get('low', 0):.4f} C={batch_candle_rmse.get('close', 0):.4f}"
+
                         # Format denormalized RMSE (real prices)
                         denorm_str = ""
                         if batch_candle_loss_denorm:
                             # RMSE values now, much more reasonable
                             denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
-                            denorm_str = f", Real Price RMSE: {', '.join(denorm_values)}"
+                            denorm_str = f", Real RMSE: {', '.join(denorm_values)}"
 
-                        logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
+                        logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, "
+                                    f"Candle Acc: {batch_accuracy:.1%}, Trend Acc: {batch_trend_accuracy:.1%}, "
+                                    f"Action Acc: {batch_action_accuracy:.1%}{rmse_str}{denorm_str}")
                 else:
                     logger.warning(f" Batch {i + 1} returned None result - skipping")
 
@@ -1966,6 +1991,9 @@ class RealTrainingAdapter:
                 torch.cuda.empty_cache()
                 torch.cuda.synchronize()
 
+            # Check memory usage
+            log_memory_usage(f" Epoch {epoch + 1} end - ")
+
             logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
 
             session.final_loss = session.current_loss
@@ -1978,6 +2006,10 @@ class RealTrainingAdapter:
             torch.cuda.empty_cache()
             torch.cuda.synchronize()
 
+        # Final memory check
+        log_memory_usage("Training complete - ")
+        memory_guard.stop()
+
         # Log best checkpoint info
         try:
             checkpoint_dir = "models/checkpoints/transformer"
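
The commit calls into a `utils/memory_guard.py` helper that is not part of this diff. Based only on the call sites above (`get_memory_guard(max_memory_gb, warning_threshold, auto_start)`, `log_memory_usage(prefix)`, `register_cleanup_callback(fn)`, `stop()`), a minimal sketch of that module might look like the following; the `psutil` dependency, the polling thread, and the 5-second interval are assumptions, not anything shown in the commit:

```python
import logging
import threading

import psutil  # assumed dependency for process memory stats

logger = logging.getLogger(__name__)


def log_memory_usage(prefix: str = "") -> None:
    """Log the current process RSS in GB with an optional prefix."""
    rss_gb = psutil.Process().memory_info().rss / (1024 ** 3)
    logger.info(f"{prefix}memory usage: {rss_gb:.2f} GB")


class MemoryGuard:
    """Background watchdog that fires cleanup callbacks near a memory limit."""

    def __init__(self, max_memory_gb: float, warning_threshold: float,
                 poll_seconds: float = 5.0):
        self.max_memory_gb = max_memory_gb
        self.warning_threshold = warning_threshold
        self.poll_seconds = poll_seconds
        self._callbacks = []
        self._stop_event = threading.Event()
        self._thread = None

    def register_cleanup_callback(self, callback) -> None:
        self._callbacks.append(callback)

    def start(self) -> None:
        # Daemon thread so a forgotten guard never blocks interpreter exit
        if self._thread is None:
            self._thread = threading.Thread(target=self._watch, daemon=True)
            self._thread.start()

    def stop(self) -> None:
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join(timeout=self.poll_seconds)

    def _watch(self) -> None:
        # Event.wait returns True once stop() sets the flag, ending the loop
        while not self._stop_event.wait(self.poll_seconds):
            rss_gb = psutil.Process().memory_info().rss / (1024 ** 3)
            if rss_gb >= self.max_memory_gb * self.warning_threshold:
                logger.warning(f"Memory guard: {rss_gb:.2f} GB exceeds "
                               f"{self.warning_threshold:.0%} of the "
                               f"{self.max_memory_gb} GB limit; running cleanup")
                for callback in self._callbacks:
                    callback()


_guard = None  # module-level singleton


def get_memory_guard(max_memory_gb: float = 50.0, warning_threshold: float = 0.85,
                     auto_start: bool = True) -> MemoryGuard:
    """Return the shared guard, starting its watcher thread if requested."""
    global _guard
    if _guard is None:
        _guard = MemoryGuard(max_memory_gb, warning_threshold)
        if auto_start:
            _guard.start()
    return _guard
```

Note the cleanup order in `training_cleanup()` above: `gc.collect()` runs first so Python drops dead tensor references, which is what lets the subsequent `torch.cuda.empty_cache()` actually return the freed blocks to the driver.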
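
The reworked batch log reports candle loss both in normalized space and as per-timeframe "Real RMSE" in dollars. The diff does not show how `candle_loss_denorm` is computed; assuming min-max price normalization, RMSE in normalized space converts back to price units by scaling with the price range. A sketch of that conversion, where the function name, price bounds, and the `1m` timeframe key are all hypothetical:

```python
import torch

def denormalize_rmse(pred: torch.Tensor, target: torch.Tensor,
                     price_min: float, price_max: float) -> float:
    """RMSE of min-max normalized predictions, rescaled to price units.

    Assumes inputs were normalized as (x - price_min) / (price_max - price_min),
    so a normalized error scales back by the price range.
    """
    mse = torch.mean((pred - target) ** 2)
    return float(torch.sqrt(mse)) * (price_max - price_min)

# Hypothetical usage with one timeframe's normalized close predictions:
pred = torch.tensor([0.52, 0.48, 0.61])
target = torch.tensor([0.50, 0.49, 0.60])
rmse_usd = denormalize_rmse(pred, target, price_min=60_000.0, price_max=70_000.0)
print(f"1m=${rmse_usd:.2f}")  # same shape as the ", Real RMSE: ..." log fragment
```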