Fix memory leak and training loss

Author: Dobromir Popov
Date:   2025-11-12 18:17:21 +02:00
parent 4a5c3fc943
commit 4f43d0d466
2 changed files with 76 additions and 97 deletions


@@ -1766,124 +1766,93 @@ class RealTrainingAdapter:
         import torch
-        # Convert all training samples to transformer format
-        converted_batches = []
-        for i, data in enumerate(training_data):
-            batch = self._convert_annotation_to_transformer_batch(data)
-            if batch is not None:
-                # Repeat based on repetitions parameter
-                # IMPORTANT: Clone each batch to avoid in-place operation issues when reusing tensors
-                repetitions = data.get('repetitions', 1)
-                for _ in range(repetitions):
-                    # Clone all tensors in the batch to ensure independence
-                    cloned_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
-                                    for k, v in batch.items()}
-                    converted_batches.append(cloned_batch)
-            else:
-                logger.warning(f" Failed to convert sample {i+1}")
-        if not converted_batches:
+        # MEMORY FIX: Use generator instead of list to avoid accumulating all batches in memory
+        def batch_generator():
+            """Generate batches on-the-fly to avoid memory accumulation"""
+            for i, data in enumerate(training_data):
+                batch = self._convert_annotation_to_transformer_batch(data)
+                if batch is not None:
+                    # Repeat based on repetitions parameter
+                    repetitions = data.get('repetitions', 1)
+                    for _ in range(repetitions):
+                        # Yield batch directly without storing
+                        yield batch
+                else:
+                    logger.warning(f" Failed to convert sample {i+1}")
+        # Count total batches for logging
+        total_batches = sum(data.get('repetitions', 1) for data in training_data
+                            if self._convert_annotation_to_transformer_batch(data) is not None)
+        if total_batches == 0:
             raise Exception("No valid training batches after conversion")
-        logger.info(f" Converted {len(training_data)} samples to {len(converted_batches)} training batches")
-        # Use batch size of 1 to avoid OOM with large sequence lengths
-        # With 5 timeframes * 100 candles = 500 sequence positions per sample
-        # Batch size of 1 ensures we don't exceed GPU memory (8GB)
-        mini_batch_size = 1 # Process one sample at a time to avoid OOM
-        def _combine_batches(batch_list: List[Dict[str, 'torch.Tensor']]) -> Dict[str, 'torch.Tensor']:
-            combined: Dict[str, 'torch.Tensor'] = {}
-            keys = batch_list[0].keys()
-            for key in keys:
-                values = [b[key] for b in batch_list if b[key] is not None]
-                # Skip keys where all values are None
-                if not values:
-                    combined[key] = None
-                    continue
-                # Special handling for non-tensor keys (like norm_params which is a dict)
-                if key == 'norm_params':
-                    # Keep norm_params as a list of dicts (one per sample in batch)
-                    combined[key] = values
-                    continue
-                # For tensors, concatenate them
-                try:
-                    combined[key] = torch.cat(values, dim=0)
-                except (RuntimeError, TypeError) as concat_error:
-                    # If concatenation fails (e.g., not a tensor), keep as list
-                    logger.debug(f"Could not concatenate key '{key}', keeping as list: {concat_error}")
-                    combined[key] = values
-            return combined
-        grouped_batches: List[Dict[str, torch.Tensor]] = []
-        current_group: List[Dict[str, torch.Tensor]] = []
-        for batch in converted_batches:
-            current_group.append(batch)
-            if len(current_group) >= mini_batch_size:
-                grouped_batches.append(_combine_batches(current_group))
-                current_group = []
-        if current_group:
-            grouped_batches.append(_combine_batches(current_group))
-        logger.info(f" Grouped into {len(grouped_batches)} mini-batches (target size {mini_batch_size})")
-        # Train using train_step for each mini-batch with gradient accumulation
-        # Accumulate gradients over multiple batches to simulate larger batch size
-        accumulation_steps = 5 # Accumulate 5 batches before optimizer step
+        logger.info(f" Will generate {total_batches} training batches on-the-fly (memory efficient)")
+        # MEMORY FIX: Process batches directly from generator, no grouping needed
+        # Batch size of 1 (single sample) to avoid OOM
+        logger.info(f" Processing batches individually (batch_size=1) for memory efficiency")
+        # MEMORY FIX: Train using generator with aggressive memory cleanup
+        # Reduced accumulation steps from 5 to 2 for less memory usage
+        accumulation_steps = 2 # Accumulate 2 batches before optimizer step
+        import gc
         for epoch in range(session.total_epochs):
             epoch_loss = 0.0
             epoch_accuracy = 0.0
             num_batches = 0
-            # Clear CUDA cache before epoch
+            # MEMORY FIX: Aggressive cleanup before epoch
+            gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+                torch.cuda.synchronize()
-            for i, batch in enumerate(grouped_batches):
+            # Generate batches fresh for each epoch
+            for i, batch in enumerate(batch_generator()):
                 try:
                     # Determine if this is an accumulation step or optimizer step
                     is_accumulation_step = (i + 1) % accumulation_steps != 0
-                    # Call the trainer's train_step method with proper batch format
+                    # Call the trainer's train_step method
                     result = trainer.train_step(batch, accumulate_gradients=is_accumulation_step)
                     if result is not None:
-                        batch_loss = result.get('total_loss', 0.0)
-                        batch_accuracy = result.get('accuracy', 0.0)
-                        batch_candle_accuracy = result.get('candle_accuracy', 0.0)
-                        batch_trend_loss = result.get('trend_loss', 0.0)
-                        batch_candle_loss = result.get('candle_loss', 0.0)
+                        # MEMORY FIX: Detach all tensor values to break computation graph
+                        batch_loss = float(result.get('total_loss', 0.0))
+                        batch_accuracy = float(result.get('accuracy', 0.0))
+                        batch_candle_accuracy = float(result.get('candle_accuracy', 0.0))
+                        batch_trend_loss = float(result.get('trend_loss', 0.0))
+                        batch_candle_loss = float(result.get('candle_loss', 0.0))
                         batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
                         epoch_loss += batch_loss
                         epoch_accuracy += batch_accuracy
                         num_batches += 1
-                        # Log first batch and every 5th batch for debugging
+                        # Log first batch and every 5th batch
                        if (i + 1) == 1 or (i + 1) % 5 == 0:
-                            # Format denormalized losses if available
                             denorm_str = ""
                             if batch_candle_loss_denorm:
+                                # RMSE values now, much more reasonable
                                 denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
-                                denorm_str = f", Real Price Error: {', '.join(denorm_values)}"
+                                denorm_str = f", Real Price RMSE: {', '.join(denorm_values)}"
-                            logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
+                            logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
                     else:
                         logger.warning(f" Batch {i + 1} returned None result - skipping")
-                    # CRITICAL: Clear CUDA cache after EVERY batch to prevent memory accumulation
-                    # NOTE: We do NOT delete batch tensors here because they are reused across epochs
-                    # Deleting them would cause "At least one timeframe must be provided" error on epoch 2+
+                    # MEMORY FIX: Explicit cleanup after EVERY batch
+                    del batch
+                    del result
                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()
-                    # After optimizer step (not accumulation), force garbage collection
+                    # After optimizer step, aggressive cleanup
                     if not is_accumulation_step:
-                        import gc
                         gc.collect()
                         if torch.cuda.is_available():
                             torch.cuda.synchronize()
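
The heart of the memory fix above is swapping an up-front list of cloned batches for a generator that yields one converted batch at a time, paired with per-batch `del` and cache clearing. The following is a minimal, self-contained sketch of that pattern, not the adapter's actual code: `make_batch`, `toy_step`, and `fake_samples` are hypothetical stand-ins for `_convert_annotation_to_transformer_batch`, `trainer.train_step`, and the annotation list.

    import gc
    import torch

    def make_batch(sample):
        # Hypothetical stand-in for _convert_annotation_to_transformer_batch
        return {'price_data': torch.randn(1, 500, 5), 'action': torch.tensor([sample['label']])}

    def batch_generator(samples):
        """Yield batches one at a time instead of materializing them all in a list."""
        for sample in samples:
            batch = make_batch(sample)
            for _ in range(sample.get('repetitions', 1)):
                yield batch  # no list is kept, so only one batch is alive at a time

    def toy_step(batch):
        # Hypothetical stand-in for trainer.train_step; returns plain Python floats
        return {'total_loss': float(batch['price_data'].mean())}

    fake_samples = [{'label': i % 3, 'repetitions': 2} for i in range(4)]
    for i, batch in enumerate(batch_generator(fake_samples)):
        result = toy_step(batch)
        del batch, result              # drop the only references before the next iteration
        if torch.cuda.is_available():
            torch.cuda.empty_cache()   # release cached blocks back to the GPU allocator
    gc.collect()

Because nothing outside the loop holds a reference to a converted batch, peak memory stays roughly constant in the number of annotations instead of growing with `repetitions * len(training_data)`.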
@@ -1891,13 +1860,27 @@ class RealTrainingAdapter:
                 except torch.cuda.OutOfMemoryError as oom_error:
                     logger.error(f" CUDA OOM in batch {i + 1}: {oom_error}")
                     # Aggressive memory cleanup on OOM
+                    if 'batch' in locals():
+                        del batch
+                    if 'result' in locals():
+                        del result
+                    gc.collect()
                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()
                         torch.cuda.synchronize()
-                    # Reset optimizer state to prevent corruption
+                    # Reset optimizer state
                     trainer.optimizer.zero_grad(set_to_none=True)
                     logger.warning(f" Skipping batch {i + 1} due to OOM, optimizer state reset")
                     continue
+                except Exception as e:
+                    logger.error(f" Error in batch {i + 1}: {e}")
+                    # Cleanup on error
+                    if 'batch' in locals():
+                        del batch
+                    if 'result' in locals():
+                        del result
+                    gc.collect()
+                    continue
                 except Exception as e:
                     logger.error(f" Error in batch {i + 1}: {e}")
                     import traceback
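
The enlarged OOM handler follows a common PyTorch recovery pattern: drop local references to the failed batch, flush the CUDA cache, and discard any half-accumulated gradients before skipping to the next batch. Below is a minimal sketch of that pattern, assuming an arbitrary model and optimizer rather than the project's trainer.

    import gc
    import torch

    def guarded_step(model, optimizer, batch):
        """Run one training step; on CUDA OOM, clean up and signal the caller to skip."""
        try:
            loss = model(batch).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            return float(loss)
        except torch.cuda.OutOfMemoryError:
            del batch                              # free the tensors that triggered the failure
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            optimizer.zero_grad(set_to_none=True)  # half-accumulated gradients are unreliable
            return None

In the adapter, the `if 'result' in locals()` guard additionally covers the case where `train_step` raised before `result` was ever bound, which would otherwise turn the cleanup itself into a `NameError`.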
@@ -1973,30 +1956,23 @@ class RealTrainingAdapter:
             except Exception as e:
                 logger.warning(f" Failed to save checkpoint: {e}")
-            # Clear CUDA cache after each epoch
+            # MEMORY FIX: Aggressive epoch-level cleanup
+            gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+                torch.cuda.synchronize()
             logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
         session.final_loss = session.current_loss
         session.accuracy = avg_accuracy
-        # Cleanup: Delete batch tensors after all epochs are complete
-        logger.info(" Cleaning up batch data...")
-        for batch in grouped_batches:
-            for key in list(batch.keys()):
-                if isinstance(batch[key], torch.Tensor):
-                    del batch[key]
-            batch.clear()
-        grouped_batches.clear()
-        converted_batches.clear()
-        # Final memory cleanup
+        # MEMORY FIX: Final cleanup (no batch lists to clean since we used generator)
+        logger.info(" Final memory cleanup...")
+        gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-        import gc
-        gc.collect()
+            torch.cuda.synchronize()
         # Log best checkpoint info
         try:
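
One side effect of the change above is worth spelling out: with a physical batch size of 1 and `accumulation_steps` lowered from 5 to 2, each optimizer step now accumulates gradients over 2 samples instead of 5. Here is a minimal sketch of gradient accumulation with the new setting, using a hypothetical toy model; whether and how the loss is scaled in the real code happens inside `trainer.train_step`, which this diff does not show.

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 1)                       # hypothetical tiny model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    accumulation_steps = 2                         # effective batch size = 1 * 2 = 2

    for i, x in enumerate(torch.randn(8, 1, 10)):  # eight single-sample batches
        loss = model(x).pow(2).mean()
        (loss / accumulation_steps).backward()     # scale so accumulated grads average, not sum
        if (i + 1) % accumulation_steps == 0:      # mirrors is_accumulation_step above
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)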


@@ -1307,12 +1307,15 @@ class TradingTransformerTrainer:
                 candle_losses_detail[tf] = tf_loss.item()
                 # ALSO calculate denormalized loss for better interpretability
+                # Use RMSE (Root Mean Square Error) instead of MSE for realistic values
                 if tf in norm_params:
                     with torch.no_grad():
                         pred_denorm = self.denormalize_candle(pred_candle, norm_params[tf])
                         target_denorm = self.denormalize_candle(target_candle, norm_params[tf])
-                        denorm_loss = self.price_criterion(pred_denorm, target_denorm)
-                        candle_losses_denorm[tf] = denorm_loss.item()
+                        # Use RMSE instead of MSE to get interpretable dollar values
+                        mse = torch.mean((pred_denorm - target_denorm) ** 2)
+                        rmse = torch.sqrt(mse + 1e-8) # Add epsilon for numerical stability
+                        candle_losses_denorm[tf] = rmse.item()
         # Average loss across available timeframes
         if timeframe_losses:
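
The second file only changes the logged, denormalized candle loss, but the readability difference is large: MSE on dollar prices is reported in squared dollars, while RMSE is back in the price's own units, so a consistent $10 error reads as roughly 10 rather than 100. A small illustration with made-up prices (the values are hypothetical; the formula matches the diff):

    import torch

    pred_denorm = torch.tensor([3012.0, 3015.5, 3008.0])    # hypothetical predicted closes ($)
    target_denorm = torch.tensor([3002.0, 3005.5, 2998.0])  # hypothetical actual closes ($)

    mse = torch.mean((pred_denorm - target_denorm) ** 2)    # 100.0 "squared dollars"
    rmse = torch.sqrt(mse + 1e-8)                           # ~10.0 dollars, matches the real error
    print(f"MSE={mse.item():.2f}  RMSE={rmse.item():.2f}")  # MSE=100.00  RMSE=10.00

Because the computation still runs under `torch.no_grad()`, only the reported `candle_losses_denorm` values change; the loss the gradients flow through is untouched.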