fix T training memory usage (due for more improvement)

commit 76e3bb6a61 (parent 738c7cb854)
Author: Dobromir Popov
Date: 2025-11-06 15:54:26 +02:00
3 changed files with 114 additions and 72 deletions


@@ -336,7 +336,9 @@ class RealTrainingAdapter:
# Get training config
training_config = test_case.get('training_config', {})
timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d'])
- candles_per_timeframe = training_config.get('candles_per_timeframe', 600)  # 600 candles per batch
+ # Reduce sequence length to avoid OOM - 200 candles is more reasonable
+ # With 5 timeframes, this gives 1000 total positions vs 3000 with 600 candles
+ candles_per_timeframe = training_config.get('candles_per_timeframe', 200)  # 200 candles per batch
# Determine secondary symbol based on primary symbol
# ETH/SOL -> BTC, BTC -> ETH
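For a sense of why the default changed, the dominant cost is self-attention, whose score matrix grows with the square of the total sequence length. A rough estimate using the numbers from the comment above (the head count and fp32 scores are illustrative assumptions, not values taken from this repository):

# Rough per-layer, per-sample attention score memory (assumed: 8 heads, fp32 scores).
def attention_score_bytes(seq_len: int, num_heads: int = 8, bytes_per_element: int = 4) -> int:
    return num_heads * seq_len * seq_len * bytes_per_element

old_positions = 5 * 600   # 3000 positions with the previous default
new_positions = 5 * 200   # 1000 positions with the new default

print(attention_score_bytes(old_positions) / 1e6)  # ~288 MB per layer per sample
print(attention_score_bytes(new_positions) / 1e6)  # ~32 MB per layer per sample, a 9x reduction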
@@ -1183,8 +1185,13 @@ class RealTrainingAdapter:
timeframes = market_state.get('timeframes', {})
secondary_timeframes = market_state.get('secondary_timeframes', {})
- # Target sequence length for all timeframes
- target_seq_len = 600
+ # Target sequence length - use actual data length (typically 200 candles)
+ # Find the first available timeframe to determine sequence length
+ target_seq_len = 200  # Default
+ for tf_data in timeframes.values():
+     if tf_data and 'close' in tf_data and len(tf_data['close']) > 0:
+         target_seq_len = min(len(tf_data['close']), 200)  # Cap at 200 to avoid OOM
+         break
# Extract each timeframe (returns None if not available)
price_data_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None
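_extract_timeframe_data itself is outside this diff; whatever it does internally, every timeframe has to come back with exactly target_seq_len rows. A minimal pad-or-truncate sketch of that contract (the function name, NumPy usage, and left-padding strategy are assumptions for illustration, not the repository's code):

import numpy as np

def pad_or_truncate_closes(closes: list, target_seq_len: int) -> np.ndarray:
    # Hypothetical illustration: keep the newest candles, or left-pad with the oldest value.
    arr = np.asarray(closes, dtype=np.float32)
    if len(arr) >= target_seq_len:
        return arr[-target_seq_len:]
    pad = np.full(target_seq_len - len(arr), arr[0], dtype=np.float32)
    return np.concatenate([pad, arr])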
@@ -1219,8 +1226,8 @@ class RealTrainingAdapter:
return None
# Create placeholder COB data (zeros if not available)
- # COB data shape: [1, 600, 100] to match new sequence length
- cob_data = torch.zeros(1, 600, 100, dtype=torch.float32)
+ # COB data shape: [1, target_seq_len, 100] to match sequence length
+ cob_data = torch.zeros(1, target_seq_len, 100, dtype=torch.float32)
# Create technical indicators from reference timeframe
tech_features = []
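A quick shape check makes the placeholder's contract explicit: its sequence dimension must match whatever target_seq_len the price tensors were built with (the 100-wide feature axis comes from the diff; the OHLCV width of 5 is an assumed example):

import torch

target_seq_len = 200
cob_data = torch.zeros(1, target_seq_len, 100, dtype=torch.float32)     # [batch, seq, cob_features]
price_data_1m = torch.zeros(1, target_seq_len, 5, dtype=torch.float32)  # assumed [batch, seq, OHLCV]
assert cob_data.shape[1] == price_data_1m.shape[1], "sequence dimensions must line up"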
@@ -1487,9 +1494,10 @@ class RealTrainingAdapter:
logger.info(f" Converted {len(training_data)} samples to {len(converted_batches)} training batches")
# Group single-sample batches into mini-batches for efficient training
- # Small batch size (5) for better gradient updates with limited training data
- mini_batch_size = 5  # Small batches work better with ~255 samples
+ # Use batch size of 1 to avoid OOM with large sequence lengths
+ # With 5 timeframes * 600 candles = 3000 sequence positions per sample,
+ # even batch_size=5 causes 15,000 positions which is too large for GPU
+ mini_batch_size = 1  # Process one sample at a time to avoid OOM
def _combine_batches(batch_list: List[Dict[str, 'torch.Tensor']]) -> Dict[str, 'torch.Tensor']:
combined: Dict[str, 'torch.Tensor'] = {}
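The body of _combine_batches sits outside this hunk; the straightforward way to merge single-sample batches is to concatenate each key along the batch dimension. A sketch under the assumption that every sample carries the same keys and tensor shapes:

from typing import Dict, List
import torch

def combine_batches_sketch(batch_list: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    combined: Dict[str, torch.Tensor] = {}
    for key in batch_list[0]:
        tensors = [b[key] for b in batch_list if b.get(key) is not None]
        combined[key] = torch.cat(tensors, dim=0)  # stack samples along the batch dimension
    return combined

With mini_batch_size = 1 this degenerates to returning the lone sample unchanged, so the grouping machinery stays in place for when larger batches fit in memory again.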
@@ -1521,7 +1529,10 @@ class RealTrainingAdapter:
logger.info(f" Grouped into {len(grouped_batches)} mini-batches (target size {mini_batch_size})")
- # Train using train_step for each mini-batch
+ # Train using train_step for each mini-batch with gradient accumulation
+ # Accumulate gradients over multiple batches to simulate larger batch size
+ accumulation_steps = 5  # Accumulate 5 batches before optimizer step
for epoch in range(session.total_epochs):
epoch_loss = 0.0
epoch_accuracy = 0.0
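Design note: the two new constants together keep the gradient signal the same as before while cutting peak memory, since only one sample's activations are resident at a time:

mini_batch_size = 1        # one sample forward/backward per call
accumulation_steps = 5     # optimizer step every 5 calls
effective_batch_size = mini_batch_size * accumulation_steps  # 5, matching the old mini_batch_size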
@@ -1529,8 +1540,11 @@ class RealTrainingAdapter:
for i, batch in enumerate(grouped_batches):
try:
+ # Determine if this is an accumulation step or optimizer step
+ is_accumulation_step = (i + 1) % accumulation_steps != 0
# Call the trainer's train_step method with proper batch format
- result = trainer.train_step(batch)
+ result = trainer.train_step(batch, accumulate_gradients=is_accumulation_step)
if result is not None:
batch_loss = result.get('total_loss', 0.0)
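The trainer's train_step is not part of this commit, so the accumulate_gradients contract is worth spelling out. A minimal sketch of what the adapter now expects from it (class, attribute, and batch-key names are assumptions for illustration, not the repository's trainer):

import torch
import torch.nn.functional as F

class TrainerSketch:
    def __init__(self, model: torch.nn.Module, lr: float = 1e-4, accumulation_steps: int = 5):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        self.accumulation_steps = accumulation_steps

    def train_step(self, batch: dict, accumulate_gradients: bool = False) -> dict:
        preds = self.model(batch['inputs'])                 # assumed batch layout
        loss = F.mse_loss(preds, batch['targets'])
        (loss / self.accumulation_steps).backward()         # scale so 5 accumulated backwards equal one batch of 5
        if not accumulate_gradients:                        # every 5th call: apply and reset
            self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)      # set_to_none frees gradient memory between steps
        return {'total_loss': loss.item()}

Dividing the loss by accumulation_steps before backward keeps the accumulated gradient equal to the mean over five samples, so the learning rate does not need retuning relative to the old mini_batch_size of 5.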
@@ -1539,14 +1553,14 @@ class RealTrainingAdapter:
epoch_accuracy += batch_accuracy
num_batches += 1
- # Log first batch and every 100th batch for debugging
- if (i + 1) == 1 or (i + 1) % 100 == 0:
-     logger.info(f" Batch {i + 1}/{len(converted_batches)}, Loss: {batch_loss:.6f}, Accuracy: {batch_accuracy:.4f}")
+ # Log first batch and every 10th batch for debugging
+ if (i + 1) == 1 or (i + 1) % 10 == 0:
+     logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Accuracy: {batch_accuracy:.4f}")
else:
logger.warning(f" Batch {i + 1} returned None result - skipping")
- # Clear CUDA cache periodically to prevent memory leak
- if torch.cuda.is_available() and (i + 1) % 5 == 0:
+ # Clear CUDA cache after optimizer step (not accumulation step)
+ if torch.cuda.is_available() and not is_accumulation_step:
torch.cuda.empty_cache()
except Exception as e:
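If memory pressure still needs watching while these values are tuned further, allocator statistics can be logged next to the cache clear. This is an optional addition that is not in the commit; it reuses logger and is_accumulation_step from the loop above:

if torch.cuda.is_available() and not is_accumulation_step:
    torch.cuda.empty_cache()
    allocated_mb = torch.cuda.memory_allocated() / 1024 ** 2   # tensors currently held
    reserved_mb = torch.cuda.memory_reserved() / 1024 ** 2     # allocator's cached pool
    logger.debug(f"CUDA memory after optimizer step: {allocated_mb:.0f} MB allocated, {reserved_mb:.0f} MB reserved")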