fix memory leak and training loss

Dobromir Popov
2025-11-12 18:17:21 +02:00
parent 4a5c3fc943
commit 4f43d0d466
2 changed files with 76 additions and 97 deletions


@@ -1766,124 +1766,93 @@ class RealTrainingAdapter:
import torch
# Convert all training samples to transformer format
converted_batches = []
for i, data in enumerate(training_data):
batch = self._convert_annotation_to_transformer_batch(data)
if batch is not None:
# Repeat based on repetitions parameter
# IMPORTANT: Clone each batch to avoid in-place operation issues when reusing tensors
repetitions = data.get('repetitions', 1)
for _ in range(repetitions):
# Clone all tensors in the batch to ensure independence
cloned_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in batch.items()}
converted_batches.append(cloned_batch)
else:
logger.warning(f" Failed to convert sample {i+1}")
# MEMORY FIX: Use generator instead of list to avoid accumulating all batches in memory
def batch_generator():
"""Generate batches on-the-fly to avoid memory accumulation"""
for i, data in enumerate(training_data):
batch = self._convert_annotation_to_transformer_batch(data)
if batch is not None:
# Repeat based on repetitions parameter
repetitions = data.get('repetitions', 1)
for _ in range(repetitions):
# Yield batch directly without storing
yield batch
else:
logger.warning(f" Failed to convert sample {i+1}")
if not converted_batches:
# Count total batches for logging
total_batches = sum(data.get('repetitions', 1) for data in training_data
if self._convert_annotation_to_transformer_batch(data) is not None)
if total_batches == 0:
raise Exception("No valid training batches after conversion")
logger.info(f" Converted {len(training_data)} samples to {len(converted_batches)} training batches")
logger.info(f" Will generate {total_batches} training batches on-the-fly (memory efficient)")
# Use batch size of 1 to avoid OOM with large sequence lengths
# With 5 timeframes * 100 candles = 500 sequence positions per sample
# Batch size of 1 ensures we don't exceed GPU memory (8GB)
mini_batch_size = 1 # Process one sample at a time to avoid OOM
# MEMORY FIX: Process batches directly from generator, no grouping needed
# Batch size of 1 (single sample) to avoid OOM
logger.info(f" Processing batches individually (batch_size=1) for memory efficiency")
def _combine_batches(batch_list: List[Dict[str, 'torch.Tensor']]) -> Dict[str, 'torch.Tensor']:
combined: Dict[str, 'torch.Tensor'] = {}
keys = batch_list[0].keys()
for key in keys:
values = [b[key] for b in batch_list if b[key] is not None]
# Skip keys where all values are None
if not values:
combined[key] = None
continue
# Special handling for non-tensor keys (like norm_params which is a dict)
if key == 'norm_params':
# Keep norm_params as a list of dicts (one per sample in batch)
combined[key] = values
continue
# For tensors, concatenate them
try:
combined[key] = torch.cat(values, dim=0)
except (RuntimeError, TypeError) as concat_error:
# If concatenation fails (e.g., not a tensor), keep as list
logger.debug(f"Could not concatenate key '{key}', keeping as list: {concat_error}")
combined[key] = values
return combined
grouped_batches: List[Dict[str, torch.Tensor]] = []
current_group: List[Dict[str, torch.Tensor]] = []
for batch in converted_batches:
current_group.append(batch)
if len(current_group) >= mini_batch_size:
grouped_batches.append(_combine_batches(current_group))
current_group = []
if current_group:
grouped_batches.append(_combine_batches(current_group))
logger.info(f" Grouped into {len(grouped_batches)} mini-batches (target size {mini_batch_size})")
# Train using train_step for each mini-batch with gradient accumulation
# Accumulate gradients over multiple batches to simulate larger batch size
accumulation_steps = 5 # Accumulate 5 batches before optimizer step
# MEMORY FIX: Train using generator with aggressive memory cleanup
# Reduced accumulation steps from 5 to 2 for less memory usage
accumulation_steps = 2 # Accumulate 2 batches before optimizer step
import gc
for epoch in range(session.total_epochs):
epoch_loss = 0.0
epoch_accuracy = 0.0
num_batches = 0
# Clear CUDA cache before epoch
# MEMORY FIX: Aggressive cleanup before epoch
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
for i, batch in enumerate(grouped_batches):
# Generate batches fresh for each epoch
for i, batch in enumerate(batch_generator()):
try:
# Determine if this is an accumulation step or optimizer step
is_accumulation_step = (i + 1) % accumulation_steps != 0
# Call the trainer's train_step method with proper batch format
# Call the trainer's train_step method
result = trainer.train_step(batch, accumulate_gradients=is_accumulation_step)
if result is not None:
batch_loss = result.get('total_loss', 0.0)
batch_accuracy = result.get('accuracy', 0.0)
batch_candle_accuracy = result.get('candle_accuracy', 0.0)
batch_trend_loss = result.get('trend_loss', 0.0)
batch_candle_loss = result.get('candle_loss', 0.0)
# MEMORY FIX: Detach all tensor values to break computation graph
batch_loss = float(result.get('total_loss', 0.0))
batch_accuracy = float(result.get('accuracy', 0.0))
batch_candle_accuracy = float(result.get('candle_accuracy', 0.0))
batch_trend_loss = float(result.get('trend_loss', 0.0))
batch_candle_loss = float(result.get('candle_loss', 0.0))
batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
epoch_loss += batch_loss
epoch_accuracy += batch_accuracy
num_batches += 1
# Log first batch and every 5th batch for debugging
# Log first batch and every 5th batch
if (i + 1) == 1 or (i + 1) % 5 == 0:
# Format denormalized losses if available
denorm_str = ""
if batch_candle_loss_denorm:
# RMSE values now, much more reasonable
denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
denorm_str = f", Real Price Error: {', '.join(denorm_values)}"
denorm_str = f", Real Price RMSE: {', '.join(denorm_values)}"
logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
else:
logger.warning(f" Batch {i + 1} returned None result - skipping")
# CRITICAL: Clear CUDA cache after EVERY batch to prevent memory accumulation
# NOTE: We do NOT delete batch tensors here because they are reused across epochs
# Deleting them would cause "At least one timeframe must be provided" error on epoch 2+
# MEMORY FIX: Explicit cleanup after EVERY batch
del batch
del result
if torch.cuda.is_available():
torch.cuda.empty_cache()
# After optimizer step (not accumulation), force garbage collection
# After optimizer step, aggressive cleanup
if not is_accumulation_step:
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.synchronize()
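
The hunk above replaces the pre-built converted_batches list with an on-the-fly generator and tightens gradient accumulation from 5 steps to 2, freeing each batch right after it is used. Below is a minimal, self-contained sketch of that pattern. The names convert, batch_generator, and train_one_epoch are illustrative stand-ins; only the trainer.train_step(batch, accumulate_gradients=...) call and the repetitions field are taken from the diff itself, and the exact signatures in the real adapter may differ.

import gc
from typing import Callable, Dict, Iterator, List, Optional

import torch


def batch_generator(
    training_data: List[dict],
    convert: Callable[[dict], Optional[Dict[str, torch.Tensor]]],
) -> Iterator[Dict[str, torch.Tensor]]:
    """Yield batches one at a time so the full converted dataset never sits in memory."""
    for sample in training_data:
        batch = convert(sample)
        if batch is None:
            continue
        # Repeat the sample as requested without materializing copies up front.
        for _ in range(sample.get('repetitions', 1)):
            yield batch


def train_one_epoch(trainer, training_data, convert, accumulation_steps: int = 2) -> float:
    epoch_loss, num_batches = 0.0, 0
    for i, batch in enumerate(batch_generator(training_data, convert)):
        # Step the optimizer only every `accumulation_steps` batches; in between,
        # gradients accumulate to simulate a larger effective batch size.
        is_accumulation_step = (i + 1) % accumulation_steps != 0
        result = trainer.train_step(batch, accumulate_gradients=is_accumulation_step)
        if result is not None:
            # float() extracts plain Python scalars so no tensor (and no graph) is retained.
            epoch_loss += float(result.get('total_loss', 0.0))
            num_batches += 1
        # Drop per-batch references immediately, then return cached CUDA blocks.
        del batch, result
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if not is_accumulation_step:
            gc.collect()
    return epoch_loss / max(num_batches, 1)

Because the generator is re-created each epoch, batches are rebuilt fresh instead of being cloned and reused, which is what removes the old per-sample cloning and the end-of-training cleanup of grouped_batches.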
@@ -1891,13 +1860,27 @@ class RealTrainingAdapter:
except torch.cuda.OutOfMemoryError as oom_error:
logger.error(f" CUDA OOM in batch {i + 1}: {oom_error}")
# Aggressive memory cleanup on OOM
if 'batch' in locals():
del batch
if 'result' in locals():
del result
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Reset optimizer state to prevent corruption
# Reset optimizer state
trainer.optimizer.zero_grad(set_to_none=True)
logger.warning(f" Skipping batch {i + 1} due to OOM, optimizer state reset")
continue
except Exception as e:
logger.error(f" Error in batch {i + 1}: {e}")
# Cleanup on error
if 'batch' in locals():
del batch
if 'result' in locals():
del result
gc.collect()
continue
except Exception as e:
logger.error(f" Error in batch {i + 1}: {e}")
import traceback
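
This hunk adds explicit cleanup inside the OOM and generic exception handlers. A hedged sketch of the recovery pattern is below; it assumes a recent PyTorch where torch.cuda.OutOfMemoryError exists, and safe_train_step is a hypothetical wrapper name, not a function in the repository.

import gc

import torch


def safe_train_step(trainer, batch, accumulate_gradients: bool):
    """Run one train step; on CUDA OOM, drop the batch and reset gradients instead of crashing."""
    try:
        return trainer.train_step(batch, accumulate_gradients=accumulate_gradients)
    except torch.cuda.OutOfMemoryError:
        # Release local references before asking the allocator to give back cached blocks.
        del batch
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        # set_to_none=True frees the gradient tensors outright instead of zero-filling them.
        trainer.optimizer.zero_grad(set_to_none=True)
        return None  # Caller treats None as "skip this batch"

Resetting the optimizer gradients after an OOM matters because a half-accumulated gradient from the failed batch would otherwise leak into the next optimizer step.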
@@ -1973,30 +1956,23 @@ class RealTrainingAdapter:
except Exception as e:
logger.warning(f" Failed to save checkpoint: {e}")
# Clear CUDA cache after each epoch
# MEMORY FIX: Aggressive epoch-level cleanup
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
session.final_loss = session.current_loss
session.accuracy = avg_accuracy
# Cleanup: Delete batch tensors after all epochs are complete
logger.info(" Cleaning up batch data...")
for batch in grouped_batches:
for key in list(batch.keys()):
if isinstance(batch[key], torch.Tensor):
del batch[key]
batch.clear()
grouped_batches.clear()
converted_batches.clear()
# Final memory cleanup
# MEMORY FIX: Final cleanup (no batch lists to clean since we used generator)
logger.info(" Final memory cleanup...")
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
import gc
gc.collect()
torch.cuda.synchronize()
# Log best checkpoint info
try:
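
The same three-call cleanup sequence now runs before each epoch, after each epoch, and once at the end of training. A small helper like the one below (the name release_cuda_memory is hypothetical) captures the pattern the diff repeats inline.

import gc

import torch


def release_cuda_memory() -> None:
    """Collect Python garbage, return cached CUDA blocks to the driver, and wait for the GPU."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()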


@@ -1307,12 +1307,15 @@ class TradingTransformerTrainer:
candle_losses_detail[tf] = tf_loss.item()
# ALSO calculate denormalized loss for better interpretability
# Use RMSE (Root Mean Square Error) instead of MSE for realistic values
if tf in norm_params:
with torch.no_grad():
pred_denorm = self.denormalize_candle(pred_candle, norm_params[tf])
target_denorm = self.denormalize_candle(target_candle, norm_params[tf])
denorm_loss = self.price_criterion(pred_denorm, target_denorm)
candle_losses_denorm[tf] = denorm_loss.item()
# Use RMSE instead of MSE to get interpretable dollar values
mse = torch.mean((pred_denorm - target_denorm) ** 2)
rmse = torch.sqrt(mse + 1e-8) # Add epsilon for numerical stability
candle_losses_denorm[tf] = rmse.item()
# Average loss across available timeframes
if timeframe_losses:
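
The second file changes only the reported denormalized candle error: instead of logging the raw MSE of denormalized prices (squared dollars), it logs the RMSE, which is in the same units as price and reads directly as a dollar error. A minimal sketch is below; denormalize stands in for TradingTransformerTrainer.denormalize_candle, and the contents of norm_params are an assumption for illustration.

import torch


def denormalized_rmse(pred: torch.Tensor, target: torch.Tensor,
                      denormalize, norm_params: dict) -> float:
    """RMSE between denormalized predicted and target candles, in price units."""
    with torch.no_grad():
        # Map normalized model outputs back to real prices before measuring error.
        pred_denorm = denormalize(pred, norm_params)
        target_denorm = denormalize(target, norm_params)
        mse = torch.mean((pred_denorm - target_denorm) ** 2)
        # sqrt restores price units; the epsilon mirrors the committed code's guard
        # against sqrt of an exactly zero error.
        return torch.sqrt(mse + 1e-8).item()

This value is only logged for interpretability; the training objective still uses the normalized candle loss computed above it.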