reduce memory usage

This commit is contained in:
Dobromir Popov
2025-11-13 17:34:31 +02:00
parent b0b24f36b2
commit 70e8ede8d3
3 changed files with 163 additions and 24 deletions

View File

@@ -1808,9 +1808,26 @@ class RealTrainingAdapter:
logger.info(f" Converted {len(cached_batches)} batches, cleared source data")
def batch_generator():
"""Yield pre-converted batches (no recreation)"""
"""
Yield pre-converted batches with proper memory management
CRITICAL: Each batch must be cloned and detached to prevent:
1. GPU memory accumulation across epochs
2. Computation graph retention
3. Version tracking issues
"""
for batch in cached_batches:
yield batch
# Clone and detach each tensor in the batch
# This creates a fresh copy without gradient history
cloned_batch = {}
for key, value in batch.items():
if isinstance(value, torch.Tensor):
# detach() removes from computation graph
# clone() creates new memory (prevents aliasing)
cloned_batch[key] = value.detach().clone()
else:
cloned_batch[key] = value
yield cloned_batch
total_batches = len(cached_batches)
@@ -1823,9 +1840,18 @@ class RealTrainingAdapter:
# Batch size of 1 (single sample) to avoid OOM
logger.info(f" Processing batches individually (batch_size=1) for memory efficiency")
# MEMORY FIX: Train using generator with aggressive memory cleanup
# Reduced accumulation steps from 5 to 2 for less memory usage
accumulation_steps = 2 # Accumulate 2 batches before optimizer step
# MEMORY OPTIMIZATION: Configure gradient accumulation
# Process samples one at a time, accumulate gradients over multiple samples
# This reduces peak memory by ~50% compared to batching
accumulation_steps = max(2, min(5, total_batches)) # 2-5 steps based on data size
logger.info(f" Gradient accumulation: {accumulation_steps} steps")
logger.info(f" Effective batch size: {accumulation_steps} (processed as {accumulation_steps} × batch_size=1)")
# Configure trainer for gradient accumulation
if hasattr(trainer, 'set_gradient_accumulation_steps'):
trainer.set_gradient_accumulation_steps(accumulation_steps)
logger.info(f" Trainer configured for automatic gradient accumulation")
import gc
@@ -1840,14 +1866,16 @@ class RealTrainingAdapter:
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Reset gradient accumulation counter at start of epoch
if hasattr(trainer, 'reset_gradient_accumulation'):
trainer.reset_gradient_accumulation()
# Generate batches fresh for each epoch
for i, batch in enumerate(batch_generator()):
try:
# Determine if this is an accumulation step or optimizer step
is_accumulation_step = (i + 1) % accumulation_steps != 0
# Call the trainer's train_step method
result = trainer.train_step(batch, accumulate_gradients=is_accumulation_step)
# Trainer now handles gradient accumulation automatically
result = trainer.train_step(batch)
if result is not None:
# MEMORY FIX: Detach all tensor values to break computation graph
@@ -1885,18 +1913,29 @@ class RealTrainingAdapter:
logger.warning(f" Batch {i + 1} returned None result - skipping")
# MEMORY FIX: Explicit cleanup after EVERY batch
# Don't delete batch (it's from cache, reused)
# Delete result dict to free memory
if 'result' in locals():
del result
# Delete the cloned batch (it's a fresh copy, safe to delete)
if 'batch' in locals():
for key in list(batch.keys()):
if isinstance(batch[key], torch.Tensor):
del batch[key]
del batch
# Clear CUDA cache after every batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
# After optimizer step, aggressive cleanup + memory check
if not is_accumulation_step:
# After optimizer step, aggressive cleanup
# Check if this was an optimizer step (not accumulation)
is_optimizer_step = ((i + 1) % accumulation_steps == 0)
if is_optimizer_step:
gc.collect()
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
# CRITICAL: Check memory limit
memory_usage = memory_guard.check_memory(raise_on_limit=True)

View File

@@ -1126,11 +1126,21 @@ class TradingTransformerTrainer:
# Move model to device
self.model.to(self.device)
# MEMORY OPTIMIZATION: Enable gradient checkpointing if configured
# This trades 20% compute for 30-40% memory savings
if config.use_gradient_checkpointing:
logger.info("Enabling gradient checkpointing for memory efficiency")
self._enable_gradient_checkpointing()
# Mixed precision training disabled - causes dtype mismatches
# Can be re-enabled if needed, but requires careful dtype management
self.use_amp = False
self.scaler = None
# GRADIENT ACCUMULATION: Track accumulation state
self.gradient_accumulation_steps = 0
self.current_accumulation_step = 0
# Optimizer with warmup
self.optimizer = optim.AdamW(
model.parameters(),
@@ -1160,6 +1170,53 @@ class TradingTransformerTrainer:
'learning_rates': []
}
def _enable_gradient_checkpointing(self):
"""Enable gradient checkpointing on transformer layers to save memory"""
try:
# Apply checkpointing to each transformer layer
for layer in self.model.layers:
if hasattr(layer, 'attention'):
# Wrap attention in checkpoint
original_forward = layer.attention.forward
def checkpointed_attention_forward(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(
original_forward, *args, **kwargs, use_reentrant=False
)
layer.attention.forward = checkpointed_attention_forward
if hasattr(layer, 'feed_forward'):
# Wrap feed-forward in checkpoint
original_ff_forward = layer.feed_forward.forward
def checkpointed_ff_forward(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(
original_ff_forward, *args, **kwargs, use_reentrant=False
)
layer.feed_forward.forward = checkpointed_ff_forward
logger.info("Gradient checkpointing enabled on all transformer layers")
except Exception as e:
logger.warning(f"Failed to enable gradient checkpointing: {e}")
def set_gradient_accumulation_steps(self, steps: int):
"""
Set the number of gradient accumulation steps
Args:
steps: Number of batches to accumulate gradients over before optimizer step
For example, steps=5 means process 5 batches, then update weights
"""
self.gradient_accumulation_steps = steps
self.current_accumulation_step = 0
logger.info(f"Gradient accumulation enabled: {steps} steps")
def reset_gradient_accumulation(self):
"""Reset gradient accumulation counter"""
self.current_accumulation_step = 0
@staticmethod
def denormalize_prices(normalized_values: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
"""
@@ -1214,19 +1271,47 @@ class TradingTransformerTrainer:
Args:
batch: Training batch
accumulate_gradients: If True, don't zero gradients or step optimizer (for gradient accumulation)
This is DEPRECATED - use gradient_accumulation_steps instead
Returns:
Dictionary with loss and accuracy metrics
"""
try:
self.model.train()
# Only zero gradients if not accumulating
# Use set_to_none=True for better memory efficiency
if not accumulate_gradients:
# GRADIENT ACCUMULATION: Determine if this is an accumulation step
# If gradient_accumulation_steps is set, use automatic accumulation
# Otherwise, fall back to manual accumulate_gradients flag
if self.gradient_accumulation_steps > 0:
is_accumulation_step = (self.current_accumulation_step < self.gradient_accumulation_steps - 1)
self.current_accumulation_step += 1
# Reset counter after full accumulation cycle
if self.current_accumulation_step >= self.gradient_accumulation_steps:
self.current_accumulation_step = 0
else:
is_accumulation_step = accumulate_gradients
# Only zero gradients at the start of accumulation cycle
# Use set_to_none=True for better memory efficiency (saves ~5% memory)
if not is_accumulation_step or self.current_accumulation_step == 1:
self.optimizer.zero_grad(set_to_none=True)
# Move batch to device WITHOUT cloning to avoid version tracking issues
# The detach().clone() was causing gradient computation errors
batch = {k: v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()}
# Move batch to device and DELETE original CPU tensors to prevent memory leak
# CRITICAL: Store original keys to delete CPU tensors after moving to GPU
batch_gpu = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
# Move to device (creates GPU copy)
batch_gpu[k] = v.to(self.device, non_blocking=True)
# Delete CPU tensor immediately to free memory
del batch[k]
else:
batch_gpu[k] = v
# Replace batch with GPU version
batch = batch_gpu
del batch_gpu
# Use automatic mixed precision (FP16) for memory efficiency
# Support both CUDA and ROCm (AMD) devices
@@ -1330,8 +1415,11 @@ class TradingTransformerTrainer:
# CRITICAL FIX: Scale loss for gradient accumulation
# This prevents gradient explosion when accumulating over multiple batches
if accumulate_gradients:
# Assume accumulation over 5 steps (should match training loop)
# The loss is averaged over accumulation steps so gradients sum correctly
if self.gradient_accumulation_steps > 0:
total_loss = total_loss / self.gradient_accumulation_steps
elif accumulate_gradients:
# Legacy fallback - assume 5 steps if not specified
total_loss = total_loss / 5.0
# Add confidence loss if available
@@ -1388,8 +1476,8 @@ class TradingTransformerTrainer:
else:
raise
# Only clip gradients and step optimizer if not accumulating
if not accumulate_gradients:
# Only clip gradients and step optimizer at the end of accumulation cycle
if not is_accumulation_step:
if self.use_amp:
# Unscale gradients before clipping
self.scaler.unscale_(self.optimizer)
@@ -1406,6 +1494,10 @@ class TradingTransformerTrainer:
self.scheduler.step()
# Log gradient accumulation completion
if self.gradient_accumulation_steps > 0:
logger.debug(f"Gradient accumulation cycle complete ({self.gradient_accumulation_steps} steps)")
# Calculate accuracy without gradients
# PRIMARY: Next candle OHLCV prediction accuracy (realistic values)
with torch.no_grad():
@@ -1487,6 +1579,14 @@ class TradingTransformerTrainer:
# CRITICAL: Delete large tensors to free memory immediately
# This prevents memory accumulation across batches
del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, action_predictions
# Delete batch tensors (GPU copies)
for key in list(batch.keys()):
if isinstance(batch[key], torch.Tensor):
del batch[key]
del batch
# Clear CUDA cache
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@@ -29,7 +29,7 @@ class MemoryGuard:
def __init__(self,
max_memory_gb: float = 50.0,
warning_threshold: float = 0.85,
check_interval: float = 5.0,
check_interval: float = 2.0,
auto_cleanup: bool = True):
"""
Initialize Memory Guard