fix T training memory usage (due for more improvement)
This commit is contained in:
@@ -336,7 +336,9 @@ class RealTrainingAdapter:
|
|||||||
# Get training config
|
# Get training config
|
||||||
training_config = test_case.get('training_config', {})
|
training_config = test_case.get('training_config', {})
|
||||||
timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d'])
|
timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d'])
|
||||||
candles_per_timeframe = training_config.get('candles_per_timeframe', 600) # 600 candles per batch
|
# Reduce sequence length to avoid OOM - 200 candles is more reasonable
|
||||||
|
# With 5 timeframes, this gives 1000 total positions vs 3000 with 600 candles
|
||||||
|
candles_per_timeframe = training_config.get('candles_per_timeframe', 200) # 200 candles per batch
|
||||||
|
|
||||||
# Determine secondary symbol based on primary symbol
|
# Determine secondary symbol based on primary symbol
|
||||||
# ETH/SOL -> BTC, BTC -> ETH
|
# ETH/SOL -> BTC, BTC -> ETH
|
||||||
@@ -1183,8 +1185,13 @@ class RealTrainingAdapter:
|
|||||||
timeframes = market_state.get('timeframes', {})
|
timeframes = market_state.get('timeframes', {})
|
||||||
secondary_timeframes = market_state.get('secondary_timeframes', {})
|
secondary_timeframes = market_state.get('secondary_timeframes', {})
|
||||||
|
|
||||||
# Target sequence length for all timeframes
|
# Target sequence length - use actual data length (typically 200 candles)
|
||||||
target_seq_len = 600
|
# Find the first available timeframe to determine sequence length
|
||||||
|
target_seq_len = 200 # Default
|
||||||
|
for tf_data in timeframes.values():
|
||||||
|
if tf_data and 'close' in tf_data and len(tf_data['close']) > 0:
|
||||||
|
target_seq_len = min(len(tf_data['close']), 200) # Cap at 200 to avoid OOM
|
||||||
|
break
|
||||||
|
|
||||||
# Extract each timeframe (returns None if not available)
|
# Extract each timeframe (returns None if not available)
|
||||||
price_data_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None
|
price_data_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None
|
||||||
@@ -1219,8 +1226,8 @@ class RealTrainingAdapter:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Create placeholder COB data (zeros if not available)
|
# Create placeholder COB data (zeros if not available)
|
||||||
# COB data shape: [1, 600, 100] to match new sequence length
|
# COB data shape: [1, target_seq_len, 100] to match sequence length
|
||||||
cob_data = torch.zeros(1, 600, 100, dtype=torch.float32)
|
cob_data = torch.zeros(1, target_seq_len, 100, dtype=torch.float32)
|
||||||
|
|
||||||
# Create technical indicators from reference timeframe
|
# Create technical indicators from reference timeframe
|
||||||
tech_features = []
|
tech_features = []
|
||||||
@@ -1487,9 +1494,10 @@ class RealTrainingAdapter:
|
|||||||
|
|
||||||
logger.info(f" Converted {len(training_data)} samples to {len(converted_batches)} training batches")
|
logger.info(f" Converted {len(training_data)} samples to {len(converted_batches)} training batches")
|
||||||
|
|
||||||
# Group single-sample batches into mini-batches for efficient training
|
# Use batch size of 1 to avoid OOM with large sequence lengths
|
||||||
# Small batch size (5) for better gradient updates with limited training data
|
# With 5 timeframes * 600 candles = 3000 sequence positions per sample,
|
||||||
mini_batch_size = 5 # Small batches work better with ~255 samples
|
# even batch_size=5 causes 15,000 positions which is too large for GPU
|
||||||
|
mini_batch_size = 1 # Process one sample at a time to avoid OOM
|
||||||
|
|
||||||
def _combine_batches(batch_list: List[Dict[str, 'torch.Tensor']]) -> Dict[str, 'torch.Tensor']:
|
def _combine_batches(batch_list: List[Dict[str, 'torch.Tensor']]) -> Dict[str, 'torch.Tensor']:
|
||||||
combined: Dict[str, 'torch.Tensor'] = {}
|
combined: Dict[str, 'torch.Tensor'] = {}
|
||||||
@@ -1521,7 +1529,10 @@ class RealTrainingAdapter:
|
|||||||
|
|
||||||
logger.info(f" Grouped into {len(grouped_batches)} mini-batches (target size {mini_batch_size})")
|
logger.info(f" Grouped into {len(grouped_batches)} mini-batches (target size {mini_batch_size})")
|
||||||
|
|
||||||
# Train using train_step for each mini-batch
|
# Train using train_step for each mini-batch with gradient accumulation
|
||||||
|
# Accumulate gradients over multiple batches to simulate larger batch size
|
||||||
|
accumulation_steps = 5 # Accumulate 5 batches before optimizer step
|
||||||
|
|
||||||
for epoch in range(session.total_epochs):
|
for epoch in range(session.total_epochs):
|
||||||
epoch_loss = 0.0
|
epoch_loss = 0.0
|
||||||
epoch_accuracy = 0.0
|
epoch_accuracy = 0.0
|
||||||
@@ -1529,8 +1540,11 @@ class RealTrainingAdapter:
|
|||||||
|
|
||||||
for i, batch in enumerate(grouped_batches):
|
for i, batch in enumerate(grouped_batches):
|
||||||
try:
|
try:
|
||||||
|
# Determine if this is an accumulation step or optimizer step
|
||||||
|
is_accumulation_step = (i + 1) % accumulation_steps != 0
|
||||||
|
|
||||||
# Call the trainer's train_step method with proper batch format
|
# Call the trainer's train_step method with proper batch format
|
||||||
result = trainer.train_step(batch)
|
result = trainer.train_step(batch, accumulate_gradients=is_accumulation_step)
|
||||||
|
|
||||||
if result is not None:
|
if result is not None:
|
||||||
batch_loss = result.get('total_loss', 0.0)
|
batch_loss = result.get('total_loss', 0.0)
|
||||||
@@ -1539,14 +1553,14 @@ class RealTrainingAdapter:
|
|||||||
epoch_accuracy += batch_accuracy
|
epoch_accuracy += batch_accuracy
|
||||||
num_batches += 1
|
num_batches += 1
|
||||||
|
|
||||||
# Log first batch and every 100th batch for debugging
|
# Log first batch and every 10th batch for debugging
|
||||||
if (i + 1) == 1 or (i + 1) % 100 == 0:
|
if (i + 1) == 1 or (i + 1) % 10 == 0:
|
||||||
logger.info(f" Batch {i + 1}/{len(converted_batches)}, Loss: {batch_loss:.6f}, Accuracy: {batch_accuracy:.4f}")
|
logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Accuracy: {batch_accuracy:.4f}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f" Batch {i + 1} returned None result - skipping")
|
logger.warning(f" Batch {i + 1} returned None result - skipping")
|
||||||
|
|
||||||
# Clear CUDA cache periodically to prevent memory leak
|
# Clear CUDA cache after optimizer step (not accumulation step)
|
||||||
if torch.cuda.is_available() and (i + 1) % 5 == 0:
|
if torch.cuda.is_available() and not is_accumulation_step:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -58,6 +58,9 @@ class TradingTransformerConfig:
|
|||||||
use_residual_connections: bool = True # Enhanced residual connections
|
use_residual_connections: bool = True # Enhanced residual connections
|
||||||
use_layer_norm_variants: bool = True # Advanced normalization
|
use_layer_norm_variants: bool = True # Advanced normalization
|
||||||
|
|
||||||
|
# Memory optimization
|
||||||
|
use_gradient_checkpointing: bool = True # Trade compute for memory (saves ~30% memory)
|
||||||
|
|
||||||
class PositionalEncoding(nn.Module):
|
class PositionalEncoding(nn.Module):
|
||||||
"""Sinusoidal positional encoding for transformer"""
|
"""Sinusoidal positional encoding for transformer"""
|
||||||
|
|
||||||
@@ -638,17 +641,17 @@ class AdvancedTradingTransformer(nn.Module):
|
|||||||
stacked_tfs = torch.stack(timeframe_encodings, dim=1) # [batch, num_tfs, seq_len, d_model]
|
stacked_tfs = torch.stack(timeframe_encodings, dim=1) # [batch, num_tfs, seq_len, d_model]
|
||||||
num_tfs = len(timeframe_encodings)
|
num_tfs = len(timeframe_encodings)
|
||||||
|
|
||||||
# Reshape for cross-timeframe attention
|
# MEMORY EFFICIENT: Process timeframes with shared weights
|
||||||
# [batch, num_tfs, seq_len, d_model] -> [batch, num_tfs * seq_len, d_model]
|
# Reshape to process all timeframes in parallel: [batch*num_tfs, seq_len, d_model]
|
||||||
cross_tf_input = stacked_tfs.reshape(batch_size, num_tfs * seq_len, self.config.d_model)
|
# This avoids creating huge concatenated sequences while still processing efficiently
|
||||||
|
batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, self.config.d_model)
|
||||||
|
|
||||||
# Apply cross-timeframe attention layers
|
# Apply attention layers (shared across timeframes)
|
||||||
# This allows the model to see patterns ACROSS timeframes simultaneously
|
|
||||||
for layer in self.cross_timeframe_layers:
|
for layer in self.cross_timeframe_layers:
|
||||||
cross_tf_input = layer(cross_tf_input)
|
batched_tfs = layer(batched_tfs)
|
||||||
|
|
||||||
# Reshape back: [batch, num_tfs * seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
|
# Reshape back: [batch*num_tfs, seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
|
||||||
cross_tf_output = cross_tf_input.reshape(batch_size, num_tfs, seq_len, self.config.d_model)
|
cross_tf_output = batched_tfs.reshape(batch_size, num_tfs, seq_len, self.config.d_model)
|
||||||
|
|
||||||
# Average across timeframes to get unified representation
|
# Average across timeframes to get unified representation
|
||||||
# [batch, num_tfs, seq_len, d_model] -> [batch, seq_len, d_model]
|
# [batch, num_tfs, seq_len, d_model] -> [batch, seq_len, d_model]
|
||||||
@@ -706,10 +709,18 @@ class AdvancedTradingTransformer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)
|
x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)
|
||||||
|
|
||||||
# Apply transformer layers
|
# Apply transformer layers with optional gradient checkpointing
|
||||||
regime_probs_history = []
|
regime_probs_history = []
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
|
if self.training and self.config.use_gradient_checkpointing:
|
||||||
|
# Use gradient checkpointing to save memory during training
|
||||||
|
# Trades compute for memory (recomputes activations during backward pass)
|
||||||
|
layer_output = torch.utils.checkpoint.checkpoint(
|
||||||
|
layer, x, mask, use_reentrant=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
layer_output = layer(x, mask)
|
layer_output = layer(x, mask)
|
||||||
|
|
||||||
x = layer_output['output']
|
x = layer_output['output']
|
||||||
if layer_output['regime_probs'] is not None:
|
if layer_output['regime_probs'] is not None:
|
||||||
regime_probs_history.append(layer_output['regime_probs'])
|
regime_probs_history.append(layer_output['regime_probs'])
|
||||||
@@ -1107,6 +1118,11 @@ class TradingTransformerTrainer:
|
|||||||
# Move model to device
|
# Move model to device
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
# Mixed precision training disabled - causes dtype mismatches
|
||||||
|
# Can be re-enabled if needed, but requires careful dtype management
|
||||||
|
self.use_amp = False
|
||||||
|
self.scaler = None
|
||||||
|
|
||||||
# Optimizer with warmup
|
# Optimizer with warmup
|
||||||
self.optimizer = optim.AdamW(
|
self.optimizer = optim.AdamW(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
@@ -1136,10 +1152,18 @@ class TradingTransformerTrainer:
|
|||||||
'learning_rates': []
|
'learning_rates': []
|
||||||
}
|
}
|
||||||
|
|
||||||
def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
|
def train_step(self, batch: Dict[str, torch.Tensor], accumulate_gradients: bool = False) -> Dict[str, float]:
|
||||||
"""Single training step"""
|
"""Single training step with optional gradient accumulation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: Training batch
|
||||||
|
accumulate_gradients: If True, don't zero gradients or step optimizer (for gradient accumulation)
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
self.model.train()
|
self.model.train()
|
||||||
|
|
||||||
|
# Only zero gradients if not accumulating
|
||||||
|
if not accumulate_gradients:
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
# Move batch to device WITHOUT cloning to avoid version tracking issues
|
# Move batch to device WITHOUT cloning to avoid version tracking issues
|
||||||
@@ -1147,6 +1171,8 @@ class TradingTransformerTrainer:
|
|||||||
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
||||||
for k, v in batch.items()}
|
for k, v in batch.items()}
|
||||||
|
|
||||||
|
# Use automatic mixed precision (FP16) for memory efficiency
|
||||||
|
with torch.cuda.amp.autocast(enabled=self.use_amp):
|
||||||
# Forward pass with multi-timeframe data
|
# Forward pass with multi-timeframe data
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
price_data_1s=batch.get('price_data_1s'),
|
price_data_1s=batch.get('price_data_1s'),
|
||||||
@@ -1199,8 +1225,11 @@ class TradingTransformerTrainer:
|
|||||||
# Use addition instead of += to avoid inplace operation
|
# Use addition instead of += to avoid inplace operation
|
||||||
total_loss = total_loss + 0.1 * confidence_loss
|
total_loss = total_loss + 0.1 * confidence_loss
|
||||||
|
|
||||||
# Backward pass
|
# Backward pass with mixed precision scaling
|
||||||
try:
|
try:
|
||||||
|
if self.use_amp:
|
||||||
|
self.scaler.scale(total_loss).backward()
|
||||||
|
else:
|
||||||
total_loss.backward()
|
total_loss.backward()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if "inplace operation" in str(e):
|
if "inplace operation" in str(e):
|
||||||
@@ -1216,11 +1245,22 @@ class TradingTransformerTrainer:
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# Only clip gradients and step optimizer if not accumulating
|
||||||
|
if not accumulate_gradients:
|
||||||
|
if self.use_amp:
|
||||||
|
# Unscale gradients before clipping
|
||||||
|
self.scaler.unscale_(self.optimizer)
|
||||||
|
# Gradient clipping
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
|
||||||
|
# Optimizer step with scaling
|
||||||
|
self.scaler.step(self.optimizer)
|
||||||
|
self.scaler.update()
|
||||||
|
else:
|
||||||
# Gradient clipping
|
# Gradient clipping
|
||||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
|
||||||
|
|
||||||
# Optimizer step
|
# Optimizer step
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
|
|
||||||
# Calculate accuracy without gradients
|
# Calculate accuracy without gradients
|
||||||
|
|||||||
@@ -836,22 +836,10 @@ class TradingOrchestrator:
|
|||||||
try:
|
try:
|
||||||
from NN.models.dqn_agent import DQNAgent
|
from NN.models.dqn_agent import DQNAgent
|
||||||
|
|
||||||
# Determine actual state size from BaseDataInput
|
# Use known state size instead of building data (which triggers massive API calls)
|
||||||
try:
|
# The state size is determined by BaseDataInput structure and doesn't change
|
||||||
base_data = self.data_provider.build_base_data_input(self.symbol)
|
actual_state_size = 7850 # Known size from BaseDataInput.get_feature_vector()
|
||||||
if base_data:
|
logger.info(f"Using known state size: {actual_state_size}")
|
||||||
actual_state_size = len(base_data.get_feature_vector())
|
|
||||||
logger.info(f"Detected actual state size: {actual_state_size}")
|
|
||||||
else:
|
|
||||||
actual_state_size = 7850 # Fallback based on error message
|
|
||||||
logger.warning(
|
|
||||||
f"Could not determine state size, using fallback: {actual_state_size}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
actual_state_size = 7850 # Fallback based on error message
|
|
||||||
logger.warning(
|
|
||||||
f"Error determining state size: {e}, using fallback: {actual_state_size}"
|
|
||||||
)
|
|
||||||
|
|
||||||
action_size = self.config.rl.get("action_space", 3)
|
action_size = self.config.rl.get("action_space", 3)
|
||||||
self.rl_agent = DQNAgent(
|
self.rl_agent = DQNAgent(
|
||||||
|
|||||||
Reference in New Issue
Block a user