fix T training memory usage (due for more improvement)
@@ -57,6 +57,9 @@ class TradingTransformerConfig:
     use_deep_attention: bool = True  # Deeper attention mechanisms
     use_residual_connections: bool = True  # Enhanced residual connections
     use_layer_norm_variants: bool = True  # Advanced normalization

+    # Memory optimization
+    use_gradient_checkpointing: bool = True  # Trade compute for memory (saves ~30% memory)
+
 class PositionalEncoding(nn.Module):
     """Sinusoidal positional encoding for transformer"""
@@ -638,17 +641,17 @@ class AdvancedTradingTransformer(nn.Module):
             stacked_tfs = torch.stack(timeframe_encodings, dim=1)  # [batch, num_tfs, seq_len, d_model]
             num_tfs = len(timeframe_encodings)

-            # Reshape for cross-timeframe attention
-            # [batch, num_tfs, seq_len, d_model] -> [batch, num_tfs * seq_len, d_model]
-            cross_tf_input = stacked_tfs.reshape(batch_size, num_tfs * seq_len, self.config.d_model)
+            # MEMORY EFFICIENT: Process timeframes with shared weights
+            # Reshape to process all timeframes in parallel: [batch*num_tfs, seq_len, d_model]
+            # This avoids creating huge concatenated sequences while still processing efficiently
+            batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, self.config.d_model)

-            # Apply cross-timeframe attention layers
-            # This allows the model to see patterns ACROSS timeframes simultaneously
+            # Apply attention layers (shared across timeframes)
             for layer in self.cross_timeframe_layers:
-                cross_tf_input = layer(cross_tf_input)
+                batched_tfs = layer(batched_tfs)

-            # Reshape back: [batch, num_tfs * seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
-            cross_tf_output = cross_tf_input.reshape(batch_size, num_tfs, seq_len, self.config.d_model)
+            # Reshape back: [batch*num_tfs, seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
+            cross_tf_output = batched_tfs.reshape(batch_size, num_tfs, seq_len, self.config.d_model)

             # Average across timeframes to get unified representation
             # [batch, num_tfs, seq_len, d_model] -> [batch, seq_len, d_model]
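Why this helps, sketched with illustrative numbers (4 timeframes, a 200-step window and 8 heads are assumptions, not values from this commit): attention scores grow with the square of the attended sequence length, so folding timeframes into the batch dimension keeps every score matrix at seq_len x seq_len instead of (num_tfs * seq_len) x (num_tfs * seq_len).

# Rough attention-memory comparison (illustrative numbers only)
num_tfs, seq_len, n_heads, batch = 4, 200, 8, 1

# Old layout: one concatenated sequence of num_tfs * seq_len tokens
concat_scores = batch * n_heads * (num_tfs * seq_len) ** 2   # 5,120,000 score entries

# New layout: timeframes folded into the batch dimension
batched_scores = batch * num_tfs * n_heads * seq_len ** 2    # 1,280,000 score entries

print(concat_scores / batched_scores)  # 4.0 -> num_tfs-fold fewer entries per attention layer

The trade-off, as the removed comment notes, is that these layers no longer attend across timeframes; the averaging step that follows is what fuses the per-timeframe representations.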
@@ -706,10 +709,18 @@ class AdvancedTradingTransformer(nn.Module):
         else:
             x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)

-        # Apply transformer layers
+        # Apply transformer layers with optional gradient checkpointing
         regime_probs_history = []
         for layer in self.layers:
-            layer_output = layer(x, mask)
+            if self.training and self.config.use_gradient_checkpointing:
+                # Use gradient checkpointing to save memory during training
+                # Trades compute for memory (recomputes activations during backward pass)
+                layer_output = torch.utils.checkpoint.checkpoint(
+                    layer, x, mask, use_reentrant=False
+                )
+            else:
+                layer_output = layer(x, mask)
+
             x = layer_output['output']
             if layer_output['regime_probs'] is not None:
                 regime_probs_history.append(layer_output['regime_probs'])
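A minimal standalone sketch of the same checkpointing call (the toy TransformerEncoderLayer and shapes are assumptions for illustration, not code from this repo): activations inside the checkpointed call are dropped after the forward pass and recomputed during backward, which is where the memory saving comes from.

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
x = torch.randn(2, 128, 64, requires_grad=True)

# Forward under checkpointing: intermediate activations are not stored
out = checkpoint(layer, x, use_reentrant=False)

# Backward re-runs the layer's forward pass to rebuild those activations
out.sum().backward()

Checkpointing only pays off during training, which is why the diff guards on self.training as well as the config flag.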
@@ -1107,6 +1118,11 @@ class TradingTransformerTrainer:
         # Move model to device
         self.model.to(self.device)

+        # Mixed precision training disabled - causes dtype mismatches
+        # Can be re-enabled if needed, but requires careful dtype management
+        self.use_amp = False
+        self.scaler = None
+
         # Optimizer with warmup
         self.optimizer = optim.AdamW(
             model.parameters(),
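The init above leaves AMP disabled but notes it can be re-enabled. For reference, a standalone sketch of the standard torch.cuda.amp recipe the later hunks are wired for (the toy model, loss and clip value are assumptions, not code from this repo; requires a CUDA device), showing the scale -> backward -> unscale -> clip -> step -> update ordering:

import torch

model = torch.nn.Linear(16, 3).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 16, device='cuda')
target = torch.randint(0, 3, (8,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = torch.nn.functional.cross_entropy(model(x), target)

scaler.scale(loss).backward()   # backward on the scaled loss
scaler.unscale_(optimizer)      # unscale before clipping so the threshold is meaningful
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)          # skips the update if any gradient overflowed
scaler.update()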
@@ -1136,37 +1152,47 @@ class TradingTransformerTrainer:
             'learning_rates': []
         }

-    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
-        """Single training step"""
+    def train_step(self, batch: Dict[str, torch.Tensor], accumulate_gradients: bool = False) -> Dict[str, float]:
+        """Single training step with optional gradient accumulation
+
+        Args:
+            batch: Training batch
+            accumulate_gradients: If True, don't zero gradients or step optimizer (for gradient accumulation)
+        """
         try:
             self.model.train()
-            self.optimizer.zero_grad()
+
+            # Only zero gradients if not accumulating
+            if not accumulate_gradients:
+                self.optimizer.zero_grad()

             # Move batch to device WITHOUT cloning to avoid version tracking issues
             # The detach().clone() was causing gradient computation errors
             batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}

-            # Forward pass with multi-timeframe data
-            outputs = self.model(
-                price_data_1s=batch.get('price_data_1s'),
-                price_data_1m=batch.get('price_data_1m'),
-                price_data_1h=batch.get('price_data_1h'),
-                price_data_1d=batch.get('price_data_1d'),
-                btc_data_1m=batch.get('btc_data_1m'),
-                cob_data=batch['cob_data'],
-                tech_data=batch['tech_data'],
-                market_data=batch['market_data'],
-                position_state=batch.get('position_state'),
-                price_data=batch.get('price_data')  # Legacy fallback
-            )
-
-            # Calculate losses
-            action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
-            price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])
-
-            # Start with base losses - avoid inplace operations on computation graph
-            total_loss = action_loss + 0.1 * price_loss  # Weight auxiliary task
+            # Use automatic mixed precision (FP16) for memory efficiency
+            with torch.cuda.amp.autocast(enabled=self.use_amp):
+                # Forward pass with multi-timeframe data
+                outputs = self.model(
+                    price_data_1s=batch.get('price_data_1s'),
+                    price_data_1m=batch.get('price_data_1m'),
+                    price_data_1h=batch.get('price_data_1h'),
+                    price_data_1d=batch.get('price_data_1d'),
+                    btc_data_1m=batch.get('btc_data_1m'),
+                    cob_data=batch['cob_data'],
+                    tech_data=batch['tech_data'],
+                    market_data=batch['market_data'],
+                    position_state=batch.get('position_state'),
+                    price_data=batch.get('price_data')  # Legacy fallback
+                )
+
+                # Calculate losses
+                action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
+                price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])
+
+                # Start with base losses - avoid inplace operations on computation graph
+                total_loss = action_loss + 0.1 * price_loss  # Weight auxiliary task

             # Add confidence loss if available
             if 'confidence' in outputs and 'trade_success' in batch:
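One way a caller could drive the new accumulate_gradients flag (the surrounding loop is not part of this commit; trainer, micro_batches and the helper name are assumptions): every micro-batch accumulates gradients and the caller finishes the update itself, since a non-accumulating call would zero gradients at the top of the step and discard what was just accumulated.

import torch

def train_accumulated(trainer, micro_batches):
    """Hypothetical driver: accumulate over several micro-batches, then step once."""
    trainer.optimizer.zero_grad()
    metrics = None
    for batch in micro_batches:
        # accumulate_gradients=True: no zero_grad, no clipping, no optimizer step
        metrics = trainer.train_step(batch, accumulate_gradients=True)
    # Caller completes the update exactly once for the whole window
    torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), trainer.config.max_grad_norm)
    trainer.optimizer.step()
    trainer.scheduler.step()
    return metrics

Note that train_step does not divide the loss by the number of micro-batches, so the accumulated gradient is a sum rather than a mean and the learning rate may need adjusting accordingly.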
@@ -1199,9 +1225,12 @@ class TradingTransformerTrainer:
                 # Use addition instead of += to avoid inplace operation
                 total_loss = total_loss + 0.1 * confidence_loss

-            # Backward pass
+            # Backward pass with mixed precision scaling
             try:
-                total_loss.backward()
+                if self.use_amp:
+                    self.scaler.scale(total_loss).backward()
+                else:
+                    total_loss.backward()
             except RuntimeError as e:
                 if "inplace operation" in str(e):
                     logger.error(f"Inplace operation error during backward pass: {e}")
@@ -1216,12 +1245,23 @@ class TradingTransformerTrainer:
                 else:
                     raise

-            # Gradient clipping
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
-
-            # Optimizer step
-            self.optimizer.step()
-            self.scheduler.step()
+            # Only clip gradients and step optimizer if not accumulating
+            if not accumulate_gradients:
+                if self.use_amp:
+                    # Unscale gradients before clipping
+                    self.scaler.unscale_(self.optimizer)
+                    # Gradient clipping
+                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
+                    # Optimizer step with scaling
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                else:
+                    # Gradient clipping
+                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
+                    # Optimizer step
+                    self.optimizer.step()
+
+                self.scheduler.step()

             # Calculate accuracy without gradients
             with torch.no_grad():