Fix memory leak and training loss
@@ -1766,124 +1766,93 @@ class RealTrainingAdapter:
 
         import torch
 
-        # Convert all training samples to transformer format
-        converted_batches = []
-        for i, data in enumerate(training_data):
-            batch = self._convert_annotation_to_transformer_batch(data)
-            if batch is not None:
-                # Repeat based on repetitions parameter
-                # IMPORTANT: Clone each batch to avoid in-place operation issues when reusing tensors
-                repetitions = data.get('repetitions', 1)
-                for _ in range(repetitions):
-                    # Clone all tensors in the batch to ensure independence
-                    cloned_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
-                                    for k, v in batch.items()}
-                    converted_batches.append(cloned_batch)
-            else:
-                logger.warning(f" Failed to convert sample {i+1}")
+        # MEMORY FIX: Use generator instead of list to avoid accumulating all batches in memory
+        def batch_generator():
+            """Generate batches on-the-fly to avoid memory accumulation"""
+            for i, data in enumerate(training_data):
+                batch = self._convert_annotation_to_transformer_batch(data)
+                if batch is not None:
+                    # Repeat based on repetitions parameter
+                    repetitions = data.get('repetitions', 1)
+                    for _ in range(repetitions):
+                        # Yield batch directly without storing
+                        yield batch
+                else:
+                    logger.warning(f" Failed to convert sample {i+1}")
 
-        if not converted_batches:
+        # Count total batches for logging
+        total_batches = sum(data.get('repetitions', 1) for data in training_data
+                            if self._convert_annotation_to_transformer_batch(data) is not None)
+
+        if total_batches == 0:
             raise Exception("No valid training batches after conversion")
 
-        logger.info(f" Converted {len(training_data)} samples to {len(converted_batches)} training batches")
+        logger.info(f" Will generate {total_batches} training batches on-the-fly (memory efficient)")
 
-        # Use batch size of 1 to avoid OOM with large sequence lengths
-        # With 5 timeframes * 100 candles = 500 sequence positions per sample
-        # Batch size of 1 ensures we don't exceed GPU memory (8GB)
-        mini_batch_size = 1 # Process one sample at a time to avoid OOM
+        # MEMORY FIX: Process batches directly from generator, no grouping needed
+        # Batch size of 1 (single sample) to avoid OOM
+        logger.info(f" Processing batches individually (batch_size=1) for memory efficiency")
 
-        def _combine_batches(batch_list: List[Dict[str, 'torch.Tensor']]) -> Dict[str, 'torch.Tensor']:
-            combined: Dict[str, 'torch.Tensor'] = {}
-            keys = batch_list[0].keys()
-            for key in keys:
-                values = [b[key] for b in batch_list if b[key] is not None]
-                # Skip keys where all values are None
-                if not values:
-                    combined[key] = None
-                    continue
-
-                # Special handling for non-tensor keys (like norm_params which is a dict)
-                if key == 'norm_params':
-                    # Keep norm_params as a list of dicts (one per sample in batch)
-                    combined[key] = values
-                    continue
-
-                # For tensors, concatenate them
-                try:
-                    combined[key] = torch.cat(values, dim=0)
-                except (RuntimeError, TypeError) as concat_error:
-                    # If concatenation fails (e.g., not a tensor), keep as list
-                    logger.debug(f"Could not concatenate key '{key}', keeping as list: {concat_error}")
-                    combined[key] = values
-            return combined
-
-        grouped_batches: List[Dict[str, torch.Tensor]] = []
-        current_group: List[Dict[str, torch.Tensor]] = []
-
-        for batch in converted_batches:
-            current_group.append(batch)
-            if len(current_group) >= mini_batch_size:
-                grouped_batches.append(_combine_batches(current_group))
-                current_group = []
-
-        if current_group:
-            grouped_batches.append(_combine_batches(current_group))
-
-        logger.info(f" Grouped into {len(grouped_batches)} mini-batches (target size {mini_batch_size})")
-
-        # Train using train_step for each mini-batch with gradient accumulation
-        # Accumulate gradients over multiple batches to simulate larger batch size
-        accumulation_steps = 5 # Accumulate 5 batches before optimizer step
+        # MEMORY FIX: Train using generator with aggressive memory cleanup
+        # Reduced accumulation steps from 5 to 2 for less memory usage
+        accumulation_steps = 2 # Accumulate 2 batches before optimizer step
 
+        import gc
+
         for epoch in range(session.total_epochs):
             epoch_loss = 0.0
             epoch_accuracy = 0.0
             num_batches = 0
 
-            # Clear CUDA cache before epoch
+            # MEMORY FIX: Aggressive cleanup before epoch
+            gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+                torch.cuda.synchronize()
 
-            for i, batch in enumerate(grouped_batches):
+            # Generate batches fresh for each epoch
+            for i, batch in enumerate(batch_generator()):
                 try:
                     # Determine if this is an accumulation step or optimizer step
                     is_accumulation_step = (i + 1) % accumulation_steps != 0
 
-                    # Call the trainer's train_step method with proper batch format
+                    # Call the trainer's train_step method
                     result = trainer.train_step(batch, accumulate_gradients=is_accumulation_step)
 
                     if result is not None:
-                        batch_loss = result.get('total_loss', 0.0)
-                        batch_accuracy = result.get('accuracy', 0.0)
-                        batch_candle_accuracy = result.get('candle_accuracy', 0.0)
-                        batch_trend_loss = result.get('trend_loss', 0.0)
-                        batch_candle_loss = result.get('candle_loss', 0.0)
+                        # MEMORY FIX: Detach all tensor values to break computation graph
+                        batch_loss = float(result.get('total_loss', 0.0))
+                        batch_accuracy = float(result.get('accuracy', 0.0))
+                        batch_candle_accuracy = float(result.get('candle_accuracy', 0.0))
+                        batch_trend_loss = float(result.get('trend_loss', 0.0))
+                        batch_candle_loss = float(result.get('candle_loss', 0.0))
                         batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
 
                         epoch_loss += batch_loss
                         epoch_accuracy += batch_accuracy
                         num_batches += 1
 
-                        # Log first batch and every 5th batch for debugging
+                        # Log first batch and every 5th batch
                        if (i + 1) == 1 or (i + 1) % 5 == 0:
                            # Format denormalized losses if available
                            denorm_str = ""
                            if batch_candle_loss_denorm:
+                                # RMSE values now, much more reasonable
                                denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
-                                denorm_str = f", Real Price Error: {', '.join(denorm_values)}"
+                                denorm_str = f", Real Price RMSE: {', '.join(denorm_values)}"
 
-                            logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
+                            logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
                    else:
                        logger.warning(f" Batch {i + 1} returned None result - skipping")
 
-                    # CRITICAL: Clear CUDA cache after EVERY batch to prevent memory accumulation
-                    # NOTE: We do NOT delete batch tensors here because they are reused across epochs
-                    # Deleting them would cause "At least one timeframe must be provided" error on epoch 2+
+                    # MEMORY FIX: Explicit cleanup after EVERY batch
+                    del batch
+                    del result
 
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
 
-                    # After optimizer step (not accumulation), force garbage collection
+                    # After optimizer step, aggressive cleanup
                    if not is_accumulation_step:
-                        import gc
                        gc.collect()
+                        if torch.cuda.is_available():
+                            torch.cuda.synchronize()
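
The heart of the memory fix above is swapping the eagerly built (and per-repetition cloned) converted_batches list for a generator that yields one batch at a time, so only the batch currently being trained on stays resident. Below is a minimal standalone sketch of that pattern; convert_sample() and the sample dicts are illustrative stand-ins for _convert_annotation_to_transformer_batch() and the adapter's training_data, not the real API.

from typing import Dict, Iterator, List

def convert_sample(sample: Dict) -> Dict:
    # Stand-in for _convert_annotation_to_transformer_batch(); may return None on failure
    return {"features": sample["features"]}

def batch_stream(training_data: List[Dict]) -> Iterator[Dict]:
    """Yield each converted batch `repetitions` times without keeping a list alive."""
    for sample in training_data:
        batch = convert_sample(sample)
        if batch is None:
            continue  # the real code logs a warning here
        for _ in range(sample.get("repetitions", 1)):
            yield batch  # only one batch is resident at a time

if __name__ == "__main__":
    data = [{"features": [1, 2, 3], "repetitions": 2},
            {"features": [4, 5, 6], "repetitions": 1}]
    total = sum(d.get("repetitions", 1) for d in data)  # rough equivalent of total_batches above
    print(f"will generate {total} batches")
    for i, batch in enumerate(batch_stream(data), start=1):
        print(i, batch)

One trade-off of this pattern is that the samples are re-converted on every epoch (and once more for the total_batches count), trading CPU time for the memory the cloned list used to hold.
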
@@ -1891,13 +1860,27 @@ class RealTrainingAdapter:
 
                 except torch.cuda.OutOfMemoryError as oom_error:
                     logger.error(f" CUDA OOM in batch {i + 1}: {oom_error}")
+                    # Aggressive memory cleanup on OOM
+                    if 'batch' in locals():
+                        del batch
+                    if 'result' in locals():
+                        del result
+                    gc.collect()
                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()
                         torch.cuda.synchronize()
-                    # Reset optimizer state to prevent corruption
+                    # Reset optimizer state
                     trainer.optimizer.zero_grad(set_to_none=True)
                     logger.warning(f" Skipping batch {i + 1} due to OOM, optimizer state reset")
                     continue
+                except Exception as e:
+                    logger.error(f" Error in batch {i + 1}: {e}")
+                    # Cleanup on error
+                    if 'batch' in locals():
+                        del batch
+                    if 'result' in locals():
+                        del result
+                    gc.collect()
+                    continue
                 except Exception as e:
                     logger.error(f" Error in batch {i + 1}: {e}")
                     import traceback
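
On the training-loss side, the loop leans on gradient accumulation: with accumulation_steps = 2, every odd-numbered batch only accumulates gradients and every even-numbered one triggers the optimizer step, while zero_grad(set_to_none=True) drops the gradient tensors entirely (also in the OOM recovery path). The trainer itself is not part of this diff, so the following is only a hedged sketch of what train_step(batch, accumulate_gradients=...) is assumed to do; the model, optimizer, and loss are toy placeholders.

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
accumulation_steps = 2  # mirrors the value chosen in the diff

def train_step(batch, accumulate_gradients: bool) -> dict:
    x, y = batch
    loss = loss_fn(model(x), y) / accumulation_steps  # scale so summed grads match a larger batch
    loss.backward()                                   # gradients accumulate in .grad between calls
    if not accumulate_gradients:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)         # frees the grad tensors, as in the OOM path
    # Return plain floats so no tensor (or autograd graph) outlives the step,
    # which is what the float(...) conversions in the diff are doing.
    return {"total_loss": float(loss.detach()) * accumulation_steps}

batches = [(torch.randn(1, 8), torch.randn(1, 1)) for _ in range(4)]
for i, batch in enumerate(batches):
    is_accumulation_step = (i + 1) % accumulation_steps != 0
    result = train_step(batch, accumulate_gradients=is_accumulation_step)
    print(f"batch {i + 1}: loss={result['total_loss']:.4f}")
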
@@ -1973,30 +1956,23 @@ class RealTrainingAdapter:
 
             except Exception as e:
                 logger.warning(f" Failed to save checkpoint: {e}")
 
-            # Clear CUDA cache after each epoch
+            # MEMORY FIX: Aggressive epoch-level cleanup
+            gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+                torch.cuda.synchronize()
 
             logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
 
         session.final_loss = session.current_loss
         session.accuracy = avg_accuracy
 
-        # Cleanup: Delete batch tensors after all epochs are complete
-        logger.info(" Cleaning up batch data...")
-        for batch in grouped_batches:
-            for key in list(batch.keys()):
-                if isinstance(batch[key], torch.Tensor):
-                    del batch[key]
-            batch.clear()
-        grouped_batches.clear()
-        converted_batches.clear()
-
-        # Final memory cleanup
+        # MEMORY FIX: Final cleanup (no batch lists to clean since we used generator)
         logger.info(" Final memory cleanup...")
+        gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-        import gc
-        gc.collect()
+            torch.cuda.synchronize()
 
         # Log best checkpoint info
         try:
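
The same cleanup sequence (gc.collect(), then torch.cuda.empty_cache(), sometimes followed by torch.cuda.synchronize()) now appears per batch, per epoch, in the OOM handler, and at the end of training. As a sketch only (release_memory() is not a helper that exists in this codebase), it could be factored out along these lines:

import gc
import torch

def release_memory(sync: bool = False) -> None:
    """Drop unreachable Python objects, then release cached CUDA blocks."""
    gc.collect()                      # frees dead Python objects and the tensors they own
    if torch.cuda.is_available():
        torch.cuda.empty_cache()      # returns unused cached blocks to the CUDA driver
        if sync:
            torch.cuda.synchronize()  # wait for queued kernels before continuing

# e.g. between epochs or after handling torch.cuda.OutOfMemoryError
release_memory(sync=True)

Note that empty_cache() cannot free tensors that are still referenced; the del batch / del result calls and the float(...) conversions are what actually make that memory reclaimable.
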