fix mem leak and train loss
@@ -1766,124 +1766,93 @@ class RealTrainingAdapter:
         import torch

-        # Convert all training samples to transformer format
-        converted_batches = []
-        for i, data in enumerate(training_data):
-            batch = self._convert_annotation_to_transformer_batch(data)
-            if batch is not None:
-                # Repeat based on repetitions parameter
-                # IMPORTANT: Clone each batch to avoid in-place operation issues when reusing tensors
-                repetitions = data.get('repetitions', 1)
-                for _ in range(repetitions):
-                    # Clone all tensors in the batch to ensure independence
-                    cloned_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
-                                    for k, v in batch.items()}
-                    converted_batches.append(cloned_batch)
-            else:
-                logger.warning(f" Failed to convert sample {i+1}")
+        # MEMORY FIX: Use generator instead of list to avoid accumulating all batches in memory
+        def batch_generator():
+            """Generate batches on-the-fly to avoid memory accumulation"""
+            for i, data in enumerate(training_data):
+                batch = self._convert_annotation_to_transformer_batch(data)
+                if batch is not None:
+                    # Repeat based on repetitions parameter
+                    repetitions = data.get('repetitions', 1)
+                    for _ in range(repetitions):
+                        # Yield batch directly without storing
+                        yield batch
+                else:
+                    logger.warning(f" Failed to convert sample {i+1}")

-        if not converted_batches:
+        # Count total batches for logging
+        total_batches = sum(data.get('repetitions', 1) for data in training_data
+                            if self._convert_annotation_to_transformer_batch(data) is not None)
+
+        if total_batches == 0:
             raise Exception("No valid training batches after conversion")

-        logger.info(f" Converted {len(training_data)} samples to {len(converted_batches)} training batches")
+        logger.info(f" Will generate {total_batches} training batches on-the-fly (memory efficient)")

-        # Use batch size of 1 to avoid OOM with large sequence lengths
-        # With 5 timeframes * 100 candles = 500 sequence positions per sample
-        # Batch size of 1 ensures we don't exceed GPU memory (8GB)
-        mini_batch_size = 1  # Process one sample at a time to avoid OOM
+        # MEMORY FIX: Process batches directly from generator, no grouping needed
+        # Batch size of 1 (single sample) to avoid OOM
+        logger.info(f" Processing batches individually (batch_size=1) for memory efficiency")

-        def _combine_batches(batch_list: List[Dict[str, 'torch.Tensor']]) -> Dict[str, 'torch.Tensor']:
-            combined: Dict[str, 'torch.Tensor'] = {}
-            keys = batch_list[0].keys()
-            for key in keys:
-                values = [b[key] for b in batch_list if b[key] is not None]
-                # Skip keys where all values are None
-                if not values:
-                    combined[key] = None
-                    continue
-
-                # Special handling for non-tensor keys (like norm_params which is a dict)
-                if key == 'norm_params':
-                    # Keep norm_params as a list of dicts (one per sample in batch)
-                    combined[key] = values
-                    continue
-
-                # For tensors, concatenate them
-                try:
-                    combined[key] = torch.cat(values, dim=0)
-                except (RuntimeError, TypeError) as concat_error:
-                    # If concatenation fails (e.g., not a tensor), keep as list
-                    logger.debug(f"Could not concatenate key '{key}', keeping as list: {concat_error}")
-                    combined[key] = values
-            return combined
-
-        grouped_batches: List[Dict[str, torch.Tensor]] = []
-        current_group: List[Dict[str, torch.Tensor]] = []
-
-        for batch in converted_batches:
-            current_group.append(batch)
-            if len(current_group) >= mini_batch_size:
-                grouped_batches.append(_combine_batches(current_group))
-                current_group = []
-
-        if current_group:
-            grouped_batches.append(_combine_batches(current_group))
-
-        logger.info(f" Grouped into {len(grouped_batches)} mini-batches (target size {mini_batch_size})")
-
-        # Train using train_step for each mini-batch with gradient accumulation
-        # Accumulate gradients over multiple batches to simulate larger batch size
-        accumulation_steps = 5  # Accumulate 5 batches before optimizer step
+        # MEMORY FIX: Train using generator with aggressive memory cleanup
+        # Reduced accumulation steps from 5 to 2 for less memory usage
+        accumulation_steps = 2  # Accumulate 2 batches before optimizer step
+
+        import gc

         for epoch in range(session.total_epochs):
             epoch_loss = 0.0
             epoch_accuracy = 0.0
             num_batches = 0

-            # Clear CUDA cache before epoch
+            # MEMORY FIX: Aggressive cleanup before epoch
+            gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+                torch.cuda.synchronize()

-            for i, batch in enumerate(grouped_batches):
+            # Generate batches fresh for each epoch
+            for i, batch in enumerate(batch_generator()):
                 try:
                     # Determine if this is an accumulation step or optimizer step
                     is_accumulation_step = (i + 1) % accumulation_steps != 0

-                    # Call the trainer's train_step method with proper batch format
+                    # Call the trainer's train_step method
                     result = trainer.train_step(batch, accumulate_gradients=is_accumulation_step)

                     if result is not None:
-                        batch_loss = result.get('total_loss', 0.0)
-                        batch_accuracy = result.get('accuracy', 0.0)
-                        batch_candle_accuracy = result.get('candle_accuracy', 0.0)
-                        batch_trend_loss = result.get('trend_loss', 0.0)
-                        batch_candle_loss = result.get('candle_loss', 0.0)
+                        # MEMORY FIX: Detach all tensor values to break computation graph
+                        batch_loss = float(result.get('total_loss', 0.0))
+                        batch_accuracy = float(result.get('accuracy', 0.0))
+                        batch_candle_accuracy = float(result.get('candle_accuracy', 0.0))
+                        batch_trend_loss = float(result.get('trend_loss', 0.0))
+                        batch_candle_loss = float(result.get('candle_loss', 0.0))
                         batch_candle_loss_denorm = result.get('candle_loss_denorm', {})

                         epoch_loss += batch_loss
                         epoch_accuracy += batch_accuracy
                         num_batches += 1

-                        # Log first batch and every 5th batch for debugging
+                        # Log first batch and every 5th batch
                         if (i + 1) == 1 or (i + 1) % 5 == 0:
-                            # Format denormalized losses if available
                             denorm_str = ""
                             if batch_candle_loss_denorm:
+                                # RMSE values now, much more reasonable
                                 denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
-                                denorm_str = f", Real Price Error: {', '.join(denorm_values)}"
+                                denorm_str = f", Real Price RMSE: {', '.join(denorm_values)}"

-                            logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
+                            logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
                     else:
                         logger.warning(f" Batch {i + 1} returned None result - skipping")

-                    # CRITICAL: Clear CUDA cache after EVERY batch to prevent memory accumulation
-                    # NOTE: We do NOT delete batch tensors here because they are reused across epochs
-                    # Deleting them would cause "At least one timeframe must be provided" error on epoch 2+
+                    # MEMORY FIX: Explicit cleanup after EVERY batch
+                    del batch
+                    del result

                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()

-                    # After optimizer step (not accumulation), force garbage collection
+                    # After optimizer step, aggressive cleanup
                     if not is_accumulation_step:
-                        import gc
                         gc.collect()
                         if torch.cuda.is_available():
                             torch.cuda.synchronize()
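The hunk above replaces an up-front list of pre-converted, cloned batches with a generator that yields batches on demand, trains with gradient accumulation, and drops references after every step. Below is a minimal, self-contained sketch of that pattern; the toy nn.Linear model, the sample dicts, and the train_epoch helper are illustrative stand-ins, not the adapter's actual trainer or batch format.

import gc
import torch
import torch.nn as nn

def batch_generator(samples):
    """Yield one batch dict per (sample, repetition) without building a list."""
    for sample in samples:
        batch = {'x': sample['x'].unsqueeze(0), 'y': sample['y'].unsqueeze(0)}
        for _ in range(sample.get('repetitions', 1)):
            yield batch

def train_epoch(model, optimizer, samples, accumulation_steps=2):
    criterion = nn.MSELoss()
    optimizer.zero_grad(set_to_none=True)
    for i, batch in enumerate(batch_generator(samples)):
        loss = criterion(model(batch['x']), batch['y'])
        # Scale so the accumulated gradient matches the larger effective batch
        (loss / accumulation_steps).backward()
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
        # Drop this iteration's references so the tensors can be freed promptly
        del batch, loss
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    gc.collect()

model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
samples = [{'x': torch.randn(4), 'y': torch.randn(1), 'repetitions': 2} for _ in range(5)]
train_epoch(model, optimizer, samples)

Because the generator is recreated for every epoch, no batch outlives the iteration that consumed it, so memory use stays flat across epochs instead of growing.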
@@ -1891,13 +1860,27 @@ class RealTrainingAdapter:
                 except torch.cuda.OutOfMemoryError as oom_error:
                     logger.error(f" CUDA OOM in batch {i + 1}: {oom_error}")
                     # Aggressive memory cleanup on OOM
+                    if 'batch' in locals():
+                        del batch
+                    if 'result' in locals():
+                        del result
+                    gc.collect()
                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()
                         torch.cuda.synchronize()
-                    # Reset optimizer state to prevent corruption
+                    # Reset optimizer state
                     trainer.optimizer.zero_grad(set_to_none=True)
                     logger.warning(f" Skipping batch {i + 1} due to OOM, optimizer state reset")
                     continue
+                except Exception as e:
+                    logger.error(f" Error in batch {i + 1}: {e}")
+                    # Cleanup on error
+                    if 'batch' in locals():
+                        del batch
+                    if 'result' in locals():
+                        del result
+                    gc.collect()
+                    continue
                 except Exception as e:
                     logger.error(f" Error in batch {i + 1}: {e}")
                     import traceback
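A compact sketch of the OOM-recovery idea in the hunk above: drop the local references to the offending tensors, return cached blocks to the allocator, and clear any partially accumulated gradients so the next batch starts from a clean optimizer state. The trainer object with train_step and optimizer attributes mirrors how the adapter uses its trainer here; the helper itself is hypothetical.

import gc
import torch

def safe_train_step(trainer, batch):
    """Run one training step; on CUDA OOM free memory, reset grads, and skip."""
    try:
        return trainer.train_step(batch)
    except torch.cuda.OutOfMemoryError:
        del batch  # drop this frame's reference so the tensors can be freed
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()   # return cached blocks to the CUDA allocator
            torch.cuda.synchronize()   # wait for pending kernels before continuing
        trainer.optimizer.zero_grad(set_to_none=True)  # discard partial gradients
        return None  # caller treats None as a skipped batch, like the `continue` above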
@@ -1973,30 +1956,23 @@ class RealTrainingAdapter:
             except Exception as e:
                 logger.warning(f" Failed to save checkpoint: {e}")

-            # Clear CUDA cache after each epoch
+            # MEMORY FIX: Aggressive epoch-level cleanup
+            gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+                torch.cuda.synchronize()

             logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")

         session.final_loss = session.current_loss
         session.accuracy = avg_accuracy

-        # Cleanup: Delete batch tensors after all epochs are complete
-        logger.info(" Cleaning up batch data...")
-        for batch in grouped_batches:
-            for key in list(batch.keys()):
-                if isinstance(batch[key], torch.Tensor):
-                    del batch[key]
-            batch.clear()
-        grouped_batches.clear()
-        converted_batches.clear()
+        # MEMORY FIX: Final cleanup (no batch lists to clean since we used generator)
+        logger.info(" Final memory cleanup...")
+        gc.collect()

-        # Final memory cleanup
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-            import gc
-            gc.collect()
+            torch.cuda.synchronize()

         # Log best checkpoint info
         try:
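The same cleanup sequence, collect Python garbage and then release and synchronize the CUDA cache, now appears before each epoch, after each epoch, and once after training. A small hypothetical helper could express it once:

import gc
import torch

def release_memory():
    """Collect Python garbage, then release cached CUDA memory (no-op without a GPU)."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # hand cached blocks back to the CUDA allocator
        torch.cuda.synchronize()   # ensure queued kernels have finished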
@@ -1307,12 +1307,15 @@ class TradingTransformerTrainer:
                     candle_losses_detail[tf] = tf_loss.item()

                     # ALSO calculate denormalized loss for better interpretability
+                    # Use RMSE (Root Mean Square Error) instead of MSE for realistic values
                     if tf in norm_params:
                         with torch.no_grad():
                             pred_denorm = self.denormalize_candle(pred_candle, norm_params[tf])
                             target_denorm = self.denormalize_candle(target_candle, norm_params[tf])
-                            denorm_loss = self.price_criterion(pred_denorm, target_denorm)
-                            candle_losses_denorm[tf] = denorm_loss.item()
+                            # Use RMSE instead of MSE to get interpretable dollar values
+                            mse = torch.mean((pred_denorm - target_denorm) ** 2)
+                            rmse = torch.sqrt(mse + 1e-8)  # Add epsilon for numerical stability
+                            candle_losses_denorm[tf] = rmse.item()

                 # Average loss across available timeframes
                 if timeframe_losses:
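The last hunk replaces the denormalized MSE, whose squared-dollar units produced very large and hard-to-read numbers, with an RMSE so the logged value is an error in actual price units. A standalone illustration of that calculation, assuming simple min-max price normalization; the real denormalize_candle parameters are not shown in this diff.

import torch

def denormalized_rmse(pred, target, norm):
    """RMSE in original price units for values normalized by min-max scaling."""
    scale = norm['price_max'] - norm['price_min']
    pred_denorm = pred * scale + norm['price_min']
    target_denorm = target * scale + norm['price_min']
    mse = torch.mean((pred_denorm - target_denorm) ** 2)
    return torch.sqrt(mse + 1e-8)  # epsilon for numerical stability at zero error

# Example: normalized predictions that are off by about $50 on a $90k-$100k range
pred = torch.tensor([0.500, 0.520, 0.490, 0.510])
target = torch.tensor([0.505, 0.515, 0.495, 0.505])
print(denormalized_rmse(pred, target, {'price_min': 90_000.0, 'price_max': 100_000.0}))

The printed value is roughly 50.0, i.e. an average error of about $50, which is the kind of figure the "Real Price RMSE" log line above now reports.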