reduce memory usage

This commit is contained in:
Dobromir Popov
2025-11-13 17:34:31 +02:00
parent b0b24f36b2
commit 70e8ede8d3
3 changed files with 163 additions and 24 deletions


@@ -1126,11 +1126,21 @@ class TradingTransformerTrainer:
        # Move model to device
        self.model.to(self.device)

        # MEMORY OPTIMIZATION: Enable gradient checkpointing if configured
        # This trades 20% compute for 30-40% memory savings
        if config.use_gradient_checkpointing:
            logger.info("Enabling gradient checkpointing for memory efficiency")
            self._enable_gradient_checkpointing()

        # Mixed precision training disabled - causes dtype mismatches
        # Can be re-enabled if needed, but requires careful dtype management
        self.use_amp = False
        self.scaler = None

        # GRADIENT ACCUMULATION: Track accumulation state
        self.gradient_accumulation_steps = 0
        self.current_accumulation_step = 0

        # Optimizer with warmup
        self.optimizer = optim.AdamW(
            model.parameters(),
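For context on the compute/memory trade-off mentioned in the comment above, here is a minimal, self-contained sketch of what torch.utils.checkpoint does; the toy block below is illustrative and is not a layer from this model:

import torch
from torch.utils.checkpoint import checkpoint

# Illustrative block; stands in for one transformer sub-layer.
block = torch.nn.Sequential(
    torch.nn.Linear(512, 2048),
    torch.nn.GELU(),
    torch.nn.Linear(2048, 512),
)

x = torch.randn(8, 128, 512, requires_grad=True)

# Normal forward: every intermediate activation is kept for backward.
y_plain = block(x)

# Checkpointed forward: only the block's inputs/outputs are stored;
# intermediate activations are recomputed during backward, trading
# extra compute for lower peak memory.
y_ckpt = checkpoint(block, x, use_reentrant=False)
y_ckpt.sum().backward()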
@@ -1160,6 +1170,53 @@ class TradingTransformerTrainer:
            'learning_rates': []
        }

    def _enable_gradient_checkpointing(self):
        """Enable gradient checkpointing on transformer layers to save memory"""
        try:
            # Apply checkpointing to each transformer layer
            for layer in self.model.layers:
                if hasattr(layer, 'attention'):
                    # Wrap attention in checkpoint.
                    # Bind the original forward as a default argument so each
                    # wrapper keeps its own layer's forward (avoids the
                    # late-binding closure pitfall when defining functions in a loop).
                    original_forward = layer.attention.forward
                    def checkpointed_attention_forward(*args, _orig=original_forward, **kwargs):
                        return torch.utils.checkpoint.checkpoint(
                            _orig, *args, use_reentrant=False, **kwargs
                        )
                    layer.attention.forward = checkpointed_attention_forward

                if hasattr(layer, 'feed_forward'):
                    # Wrap feed-forward in checkpoint
                    original_ff_forward = layer.feed_forward.forward
                    def checkpointed_ff_forward(*args, _orig=original_ff_forward, **kwargs):
                        return torch.utils.checkpoint.checkpoint(
                            _orig, *args, use_reentrant=False, **kwargs
                        )
                    layer.feed_forward.forward = checkpointed_ff_forward

            logger.info("Gradient checkpointing enabled on all transformer layers")
        except Exception as e:
            logger.warning(f"Failed to enable gradient checkpointing: {e}")
    def set_gradient_accumulation_steps(self, steps: int):
        """
        Set the number of gradient accumulation steps

        Args:
            steps: Number of batches to accumulate gradients over before optimizer step.
                   For example, steps=5 means process 5 batches, then update weights.
        """
        self.gradient_accumulation_steps = steps
        self.current_accumulation_step = 0
        logger.info(f"Gradient accumulation enabled: {steps} steps")

    def reset_gradient_accumulation(self):
        """Reset gradient accumulation counter"""
        self.current_accumulation_step = 0

    @staticmethod
    def denormalize_prices(normalized_values: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
        """
@@ -1214,19 +1271,47 @@ class TradingTransformerTrainer:
        Args:
            batch: Training batch
            accumulate_gradients: If True, don't zero gradients or step optimizer (for gradient accumulation).
                This flag is DEPRECATED - use gradient_accumulation_steps instead.

        Returns:
            Dictionary with loss and accuracy metrics
        """
        try:
            self.model.train()

            # Only zero gradients if not accumulating
            # Use set_to_none=True for better memory efficiency
            if not accumulate_gradients:

            # GRADIENT ACCUMULATION: Determine if this is an accumulation step
            # If gradient_accumulation_steps is set, use automatic accumulation
            # Otherwise, fall back to the manual accumulate_gradients flag
            if self.gradient_accumulation_steps > 0:
                is_accumulation_step = (self.current_accumulation_step < self.gradient_accumulation_steps - 1)
                self.current_accumulation_step += 1

                # Reset counter after a full accumulation cycle
                if self.current_accumulation_step >= self.gradient_accumulation_steps:
                    self.current_accumulation_step = 0
            else:
                is_accumulation_step = accumulate_gradients

            # Only zero gradients at the start of an accumulation cycle
            # (first micro-batch of the cycle, or every step when not accumulating),
            # so the gradients collected over the cycle are not wiped before the final step.
            # Use set_to_none=True for better memory efficiency (saves ~5% memory)
            if (self.gradient_accumulation_steps > 0 and self.current_accumulation_step == 1) or \
               (self.gradient_accumulation_steps == 0 and not accumulate_gradients):
                self.optimizer.zero_grad(set_to_none=True)
            # Move batch to device WITHOUT cloning to avoid version tracking issues
            # The detach().clone() was causing gradient computation errors
            batch = {k: v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}

            # Move batch to device and DELETE original CPU tensors to prevent memory leak
            # CRITICAL: Iterate over a snapshot of the items so keys can be deleted
            # from the dict while their tensors are moved to the GPU
            batch_gpu = {}
            for k, v in list(batch.items()):
                if isinstance(v, torch.Tensor):
                    # Move to device (creates GPU copy)
                    batch_gpu[k] = v.to(self.device, non_blocking=True)
                    # Drop the dict's reference to the CPU tensor immediately
                    del batch[k]
                else:
                    batch_gpu[k] = v

            # Replace batch with the GPU version
            batch = batch_gpu
            del batch_gpu

            # Use automatic mixed precision (FP16) for memory efficiency
            # Support both CUDA and ROCm (AMD) devices
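As a side note, non_blocking=True only gives a truly asynchronous host-to-device copy when the source CPU tensor is in pinned (page-locked) memory; a small general-PyTorch illustration, with names that are hypothetical rather than from this codebase:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pageable CPU tensor: .to(..., non_blocking=True) may still synchronize.
pageable = torch.randn(1024, 1024)

# Pinned CPU tensor: the copy can overlap with GPU compute.
pinned = torch.randn(1024, 1024).pin_memory() if device.type == "cuda" else pageable

gpu_copy = pinned.to(device, non_blocking=True)

# Dropping the CPU-side references lets Python free the host buffers
# once nothing else points at them - the same idea as the batch loop above.
del pinned, pageable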
@@ -1330,8 +1415,11 @@ class TradingTransformerTrainer:
            # CRITICAL FIX: Scale loss for gradient accumulation
            # This prevents gradient explosion when accumulating over multiple batches
            if accumulate_gradients:
            # Assume accumulation over 5 steps (should match training loop)
            # The loss is averaged over accumulation steps so gradients sum correctly
            if self.gradient_accumulation_steps > 0:
                total_loss = total_loss / self.gradient_accumulation_steps
            elif accumulate_gradients:
                # Legacy fallback - assume 5 steps if not specified
                total_loss = total_loss / 5.0

            # Add confidence loss if available
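To see why dividing each micro-batch loss by the number of accumulation steps is the right scaling, here is a small self-contained check (toy model and data, not from this repo) showing that N accumulated backward passes on loss/N reproduce the gradient of the mean loss over the combined batch:

import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
data = torch.randn(10, 4)
target = torch.randn(10, 1)
loss_fn = torch.nn.MSELoss()

# Reference: one backward pass over the full batch.
model.zero_grad()
loss_fn(model(data), target).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 5 micro-batches of 2 samples, each loss scaled by 1/5.
steps = 5
model.zero_grad()
for micro_x, micro_y in zip(data.chunk(steps), target.chunk(steps)):
    (loss_fn(model(micro_x), micro_y) / steps).backward()

assert torch.allclose(model.weight.grad, full_grad, atol=1e-6)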
@@ -1388,8 +1476,8 @@ class TradingTransformerTrainer:
                else:
                    raise

            # Only clip gradients and step optimizer if not accumulating
            if not accumulate_gradients:
            # Only clip gradients and step optimizer at the end of accumulation cycle
            if not is_accumulation_step:
                if self.use_amp:
                    # Unscale gradients before clipping
                    self.scaler.unscale_(self.optimizer)
@@ -1405,6 +1493,10 @@ class TradingTransformerTrainer:
                    self.optimizer.step()
                    self.scheduler.step()

                # Log gradient accumulation completion
                if self.gradient_accumulation_steps > 0:
                    logger.debug(f"Gradient accumulation cycle complete ({self.gradient_accumulation_steps} steps)")

            # Calculate accuracy without gradients
            # PRIMARY: Next candle OHLCV prediction accuracy (realistic values)
@@ -1487,6 +1579,14 @@ class TradingTransformerTrainer:
            # CRITICAL: Delete large tensors to free memory immediately
            # This prevents memory accumulation across batches
            del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, action_predictions

            # Delete batch tensors (GPU copies)
            for key in list(batch.keys()):
                if isinstance(batch[key], torch.Tensor):
                    del batch[key]
            del batch

            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
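A quick way to sanity-check that the deletions plus empty_cache() actually lower resident GPU memory is to compare PyTorch's allocator counters before and after; a hedged sketch with an arbitrary scratch tensor, not tied to this trainer:

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")

    before = torch.cuda.memory_allocated(device)
    scratch = torch.randn(2048, 2048, device=device)   # ~16 MB of FP32
    during = torch.cuda.memory_allocated(device)

    del scratch
    torch.cuda.empty_cache()   # returns cached blocks to the driver
    after = torch.cuda.memory_allocated(device)

    print(f"allocated: before={before} during={during} after={after}")
    # memory_allocated() drops as soon as the tensor is deleted;
    # empty_cache() mainly shrinks memory_reserved(), which is what
    # other processes and nvidia-smi observe.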