diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py
index a22962a..f14fdd6 100644
--- a/ANNOTATE/core/real_training_adapter.py
+++ b/ANNOTATE/core/real_training_adapter.py
@@ -1766,124 +1766,93 @@ class RealTrainingAdapter:
         import torch

-        # Convert all training samples to transformer format
-        converted_batches = []
-        for i, data in enumerate(training_data):
-            batch = self._convert_annotation_to_transformer_batch(data)
-            if batch is not None:
-                # Repeat based on repetitions parameter
-                # IMPORTANT: Clone each batch to avoid in-place operation issues when reusing tensors
-                repetitions = data.get('repetitions', 1)
-                for _ in range(repetitions):
-                    # Clone all tensors in the batch to ensure independence
-                    cloned_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
-                                    for k, v in batch.items()}
-                    converted_batches.append(cloned_batch)
-            else:
-                logger.warning(f" Failed to convert sample {i+1}")
+        # MEMORY FIX: Use generator instead of list to avoid accumulating all batches in memory
+        def batch_generator():
+            """Generate batches on-the-fly to avoid memory accumulation"""
+            for i, data in enumerate(training_data):
+                batch = self._convert_annotation_to_transformer_batch(data)
+                if batch is not None:
+                    # Repeat based on repetitions parameter
+                    repetitions = data.get('repetitions', 1)
+                    for _ in range(repetitions):
+                        # Yield the same batch object for each repetition; train_step must not mutate it in place
+                        yield batch
+                else:
+                    logger.warning(f" Failed to convert sample {i+1}")

-        if not converted_batches:
+        # Count total batches for logging (note: this runs the conversion a second time; results are discarded)
+        total_batches = sum(data.get('repetitions', 1) for data in training_data
+                            if self._convert_annotation_to_transformer_batch(data) is not None)
+
+        if total_batches == 0:
             raise Exception("No valid training batches after conversion")

-        logger.info(f" Converted {len(training_data)} samples to {len(converted_batches)} training batches")
+        logger.info(f" Will generate {total_batches} training batches on-the-fly (memory efficient)")

-        # Use batch size of 1 to avoid OOM with large sequence lengths
-        # With 5 timeframes * 100 candles = 500 sequence positions per sample
-        # Batch size of 1 ensures we don't exceed GPU memory (8GB)
-        mini_batch_size = 1  # Process one sample at a time to avoid OOM
+        # MEMORY FIX: Process batches directly from generator, no grouping needed
+        # Batch size of 1 (single sample) to avoid OOM
+        logger.info(f" Processing batches individually (batch_size=1) for memory efficiency")

-        def _combine_batches(batch_list: List[Dict[str, 'torch.Tensor']]) -> Dict[str, 'torch.Tensor']:
-            combined: Dict[str, 'torch.Tensor'] = {}
-            keys = batch_list[0].keys()
-            for key in keys:
-                values = [b[key] for b in batch_list if b[key] is not None]
-                # Skip keys where all values are None
-                if not values:
-                    combined[key] = None
-                    continue
-
-                # Special handling for non-tensor keys (like norm_params which is a dict)
-                if key == 'norm_params':
-                    # Keep norm_params as a list of dicts (one per sample in batch)
-                    combined[key] = values
-                    continue
-
-                # For tensors, concatenate them
-                try:
-                    combined[key] = torch.cat(values, dim=0)
-                except (RuntimeError, TypeError) as concat_error:
-                    # If concatenation fails (e.g., not a tensor), keep as list
-                    logger.debug(f"Could not concatenate key '{key}', keeping as list: {concat_error}")
-                    combined[key] = values
-            return combined
-
-        grouped_batches: List[Dict[str, torch.Tensor]] = []
-        current_group: List[Dict[str, torch.Tensor]] = []
-
-        for batch in converted_batches:
-            current_group.append(batch)
-            if len(current_group) >= mini_batch_size:
-                grouped_batches.append(_combine_batches(current_group))
-                current_group = []
-
-        if current_group:
-            grouped_batches.append(_combine_batches(current_group))
-
-        logger.info(f" Grouped into {len(grouped_batches)} mini-batches (target size {mini_batch_size})")
-
-        # Train using train_step for each mini-batch with gradient accumulation
-        # Accumulate gradients over multiple batches to simulate larger batch size
-        accumulation_steps = 5  # Accumulate 5 batches before optimizer step
+        # MEMORY FIX: Train using generator with aggressive memory cleanup
+        # Reduced accumulation steps from 5 to 2 for less memory usage
+        accumulation_steps = 2  # Accumulate 2 batches before optimizer step
+
+        import gc

         for epoch in range(session.total_epochs):
             epoch_loss = 0.0
             epoch_accuracy = 0.0
             num_batches = 0

-            # Clear CUDA cache before epoch
+            # MEMORY FIX: Aggressive cleanup before epoch
+            gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+                torch.cuda.synchronize()

-            for i, batch in enumerate(grouped_batches):
+            # Generate batches fresh for each epoch
+            for i, batch in enumerate(batch_generator()):
                 try:
                     # Determine if this is an accumulation step or optimizer step
                     is_accumulation_step = (i + 1) % accumulation_steps != 0

-                    # Call the trainer's train_step method with proper batch format
+                    # Call the trainer's train_step method
                     result = trainer.train_step(batch, accumulate_gradients=is_accumulation_step)

                     if result is not None:
-                        batch_loss = result.get('total_loss', 0.0)
-                        batch_accuracy = result.get('accuracy', 0.0)
-                        batch_candle_accuracy = result.get('candle_accuracy', 0.0)
-                        batch_trend_loss = result.get('trend_loss', 0.0)
-                        batch_candle_loss = result.get('candle_loss', 0.0)
+                        # MEMORY FIX: Convert to plain Python floats so no tensors (or their graphs) are kept alive
+                        batch_loss = float(result.get('total_loss', 0.0))
+                        batch_accuracy = float(result.get('accuracy', 0.0))
+                        batch_candle_accuracy = float(result.get('candle_accuracy', 0.0))
+                        batch_trend_loss = float(result.get('trend_loss', 0.0))
+                        batch_candle_loss = float(result.get('candle_loss', 0.0))
                         batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
+
                         epoch_loss += batch_loss
                         epoch_accuracy += batch_accuracy
                         num_batches += 1

-                        # Log first batch and every 5th batch for debugging
+                        # Log first batch and every 5th batch
                         if (i + 1) == 1 or (i + 1) % 5 == 0:
-                            # Format denormalized losses if available
                             denorm_str = ""
                             if batch_candle_loss_denorm:
+                                # RMSE values now, much more reasonable
                                 denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
-                                denorm_str = f", Real Price Error: {', '.join(denorm_values)}"
+                                denorm_str = f", Real Price RMSE: {', '.join(denorm_values)}"

-                            logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
+                            logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
                     else:
                         logger.warning(f" Batch {i + 1} returned None result - skipping")

-                    # CRITICAL: Clear CUDA cache after EVERY batch to prevent memory accumulation
-                    # NOTE: We do NOT delete batch tensors here because they are reused across epochs
-                    # Deleting them would cause "At least one timeframe must be provided" error on epoch 2+
+                    # MEMORY FIX: Explicit cleanup after EVERY batch (safe now: batches are regenerated each epoch)
+                    del batch
+                    del result
+
                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()

-                    # After optimizer step (not accumulation), force garbage collection
+                    # After optimizer step, aggressive cleanup
                     if not is_accumulation_step:
-                        import gc
                         gc.collect()
                         if torch.cuda.is_available():
                             torch.cuda.synchronize()
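For review context, this hunk's pattern reduced to its essentials (a minimal sketch, not the adapter's code: `convert_fn` and `train_step_fn` are hypothetical stand-ins, and gradient scaling is shown explicitly here, whereas the real `trainer.train_step` manages accumulation internally via `accumulate_gradients`):

```python
import torch

def batch_stream(samples, convert_fn):
    """Yield converted batches one at a time; no list of batches is ever built."""
    for sample in samples:
        batch = convert_fn(sample)
        if batch is None:
            continue
        for _ in range(sample.get('repetitions', 1)):
            # The same dict is re-yielded per repetition, so the consumer
            # must treat it as read-only (no in-place tensor ops).
            yield batch

def run_epoch(stream, train_step_fn, optimizer, accumulation_steps=2):
    """Consume the stream with gradient accumulation; peak memory is one batch."""
    for i, batch in enumerate(stream):
        loss = train_step_fn(batch) / accumulation_steps  # scale so grads average out
        loss.backward()
        if (i + 1) % accumulation_steps == 0:  # optimizer step every N batches
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
        del batch  # drop the consumer's reference as soon as the step is done
```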
@@ -1891,13 +1860,24 @@ class RealTrainingAdapter:
                 except torch.cuda.OutOfMemoryError as oom_error:
                     logger.error(f" CUDA OOM in batch {i + 1}: {oom_error}")
                     # Aggressive memory cleanup on OOM
+                    if 'batch' in locals():
+                        del batch
+                    if 'result' in locals():
+                        del result
+                    gc.collect()
                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()
                         torch.cuda.synchronize()
-                    # Reset optimizer state to prevent corruption
+                    # Reset optimizer state
                     trainer.optimizer.zero_grad(set_to_none=True)
                     logger.warning(f" Skipping batch {i + 1} due to OOM, optimizer state reset")
                     continue
                 except Exception as e:
                     logger.error(f" Error in batch {i + 1}: {e}")
+                    # MEMORY FIX: Cleanup on error as well (a second except Exception block would be unreachable)
+                    if 'batch' in locals():
+                        del batch
+                    if 'result' in locals():
+                        del result
+                    gc.collect()
                     import traceback
@@ -1973,30 +1956,23 @@ class RealTrainingAdapter:
             except Exception as e:
                 logger.warning(f" Failed to save checkpoint: {e}")

-            # Clear CUDA cache after each epoch
+            # MEMORY FIX: Aggressive epoch-level cleanup
+            gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+                torch.cuda.synchronize()

             logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")

         session.final_loss = session.current_loss
         session.accuracy = avg_accuracy

-        # Cleanup: Delete batch tensors after all epochs are complete
-        logger.info(" Cleaning up batch data...")
-        for batch in grouped_batches:
-            for key in list(batch.keys()):
-                if isinstance(batch[key], torch.Tensor):
-                    del batch[key]
-            batch.clear()
-        grouped_batches.clear()
-        converted_batches.clear()
-
-        # Final memory cleanup
+        # MEMORY FIX: Final cleanup (no batch lists to clean since we used a generator)
+        logger.info(" Final memory cleanup...")
+        gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-        import gc
-        gc.collect()
+            torch.cuda.synchronize()

         # Log best checkpoint info
         try:
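The OOM handler above follows a standard recovery sequence; a minimal sketch of the same idea, assuming a generic `trainer` object with a torch optimizer (`safe_train_step` is a hypothetical helper, not a function in this repo):

```python
import gc
import torch

def safe_train_step(trainer, batch):
    """Run one training step; on CUDA OOM, free memory and signal the caller to skip."""
    try:
        return trainer.train_step(batch)
    except torch.cuda.OutOfMemoryError:
        # Order matters: cached blocks can only be returned to the allocator
        # once no live Python reference still points at their tensors.
        del batch
        gc.collect()                # collect reference cycles holding tensors
        torch.cuda.empty_cache()    # release cached blocks back to the driver
        torch.cuda.synchronize()    # make sure the frees have actually completed
        # A half-accumulated gradient from a failed step is unusable
        trainer.optimizer.zero_grad(set_to_none=True)
        return None                 # caller treats None as "skip this batch"
```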
diff --git a/NN/models/advanced_transformer_trading.py b/NN/models/advanced_transformer_trading.py
index 665ea37..dfa9379 100644
--- a/NN/models/advanced_transformer_trading.py
+++ b/NN/models/advanced_transformer_trading.py
@@ -1307,12 +1307,15 @@ class TradingTransformerTrainer:
                     candle_losses_detail[tf] = tf_loss.item()

                     # ALSO calculate denormalized loss for better interpretability
+                    # Use RMSE (Root Mean Square Error) instead of MSE for realistic values
                     if tf in norm_params:
                         with torch.no_grad():
                             pred_denorm = self.denormalize_candle(pred_candle, norm_params[tf])
                             target_denorm = self.denormalize_candle(target_candle, norm_params[tf])
-                            denorm_loss = self.price_criterion(pred_denorm, target_denorm)
-                            candle_losses_denorm[tf] = denorm_loss.item()
+                            # Use RMSE instead of MSE to get interpretable dollar values
+                            mse = torch.mean((pred_denorm - target_denorm) ** 2)
+                            rmse = torch.sqrt(mse + 1e-8)  # Add epsilon for numerical stability
+                            candle_losses_denorm[tf] = rmse.item()

                 # Average loss across available timeframes
                 if timeframe_losses:
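A quick numeric illustration of why this hunk switches the denormalized metric from MSE to RMSE (values are made up; only the units argument matters):

```python
import torch

# Hypothetical denormalized OHLC predictions vs targets, each off by $2
pred = torch.tensor([100.0, 101.5, 99.8, 100.9])
target = torch.tensor([102.0, 103.5, 101.8, 102.9])

mse = torch.mean((pred - target) ** 2)
rmse = torch.sqrt(mse + 1e-8)  # same epsilon as the patch, guards sqrt at 0

print(f"MSE:  {mse.item():.2f}")   # 4.00 -- squared dollars, misleading at a glance
print(f"RMSE: {rmse.item():.2f}")  # 2.00 -- dollars, matches the actual $2 error
```

With a $50 error the gap is starker: MSE reports 2500.0 while RMSE reports 50.0, which is why the log line now labels the value "Real Price RMSE".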