fix T training memory usage (due for more improvement)
@@ -336,7 +336,9 @@ class RealTrainingAdapter:
 # Get training config
 training_config = test_case.get('training_config', {})
 timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d'])
-candles_per_timeframe = training_config.get('candles_per_timeframe', 600)  # 600 candles per batch
+# Reduce sequence length to avoid OOM - 200 candles is more reasonable
+# With 5 timeframes, this gives 1000 total positions vs 3000 with 600 candles
+candles_per_timeframe = training_config.get('candles_per_timeframe', 200)  # 200 candles per batch

 # Determine secondary symbol based on primary symbol
 # ETH/SOL -> BTC, BTC -> ETH
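The arithmetic behind the new default: with the five timeframes the comment assumes, each sample carries 5 x 200 = 1,000 sequence positions instead of 5 x 600 = 3,000, and self-attention memory grows with the square of that count. A rough, illustrative estimate of the effect (the head count and element size are assumptions, not values from this repository):

    def attention_scores_mb(seq_len: int, num_heads: int = 8, bytes_per_elem: int = 4) -> float:
        """Memory for one layer's attention score matrix: heads * seq_len^2 * bytes."""
        return num_heads * seq_len * seq_len * bytes_per_elem / (1024 ** 2)

    for candles in (600, 200):
        positions = 5 * candles  # 5 timeframes concatenated into one sequence
        print(f"{candles} candles/timeframe -> {positions} positions, "
              f"~{attention_scores_mb(positions):.0f} MB of scores per attention layer")

On these assumptions the per-layer score matrices shrink from roughly 275 MB to about 31 MB per sample, which is where most of the OOM headroom comes from.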
@@ -1183,8 +1185,13 @@ class RealTrainingAdapter:
 timeframes = market_state.get('timeframes', {})
 secondary_timeframes = market_state.get('secondary_timeframes', {})

-# Target sequence length for all timeframes
-target_seq_len = 600
+# Target sequence length - use actual data length (typically 200 candles)
+# Find the first available timeframe to determine sequence length
+target_seq_len = 200  # Default
+for tf_data in timeframes.values():
+    if tf_data and 'close' in tf_data and len(tf_data['close']) > 0:
+        target_seq_len = min(len(tf_data['close']), 200)  # Cap at 200 to avoid OOM
+        break

 # Extract each timeframe (returns None if not available)
 price_data_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None
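The same selection logic, pulled out as a standalone helper so the behaviour is easy to test in isolation; the function name and signature are illustrative, not part of the adapter:

    from typing import Dict

    def pick_target_seq_len(timeframes: Dict[str, dict], cap: int = 200, default: int = 200) -> int:
        """Use the candle count of the first timeframe with close prices, capped to avoid OOM."""
        for tf_data in timeframes.values():
            if tf_data and 'close' in tf_data and len(tf_data['close']) > 0:
                return min(len(tf_data['close']), cap)
        return default

    # Example: a 180-candle series yields 180; a 600-candle series is capped at 200.
    assert pick_target_seq_len({'1m': {'close': [0.0] * 180}}) == 180
    assert pick_target_seq_len({'1h': {'close': [0.0] * 600}}) == 200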
@@ -1219,8 +1226,8 @@ class RealTrainingAdapter:
 return None

 # Create placeholder COB data (zeros if not available)
-# COB data shape: [1, 600, 100] to match new sequence length
-cob_data = torch.zeros(1, 600, 100, dtype=torch.float32)
+# COB data shape: [1, target_seq_len, 100] to match sequence length
+cob_data = torch.zeros(1, target_seq_len, 100, dtype=torch.float32)

 # Create technical indicators from reference timeframe
 tech_features = []
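If real COB snapshots were available instead of the zero placeholder, they would still need to match [1, target_seq_len, 100]. A minimal sketch of padding or truncating to that shape, assuming a [T, features] input layout (the helper name is hypothetical):

    import torch

    def fit_cob_to_seq_len(cob: torch.Tensor, target_seq_len: int) -> torch.Tensor:
        """Zero-pad or truncate a [T, features] COB tensor to [1, target_seq_len, features]."""
        out = torch.zeros(1, target_seq_len, cob.shape[-1], dtype=torch.float32)
        t = min(cob.shape[0], target_seq_len)
        out[0, :t] = cob[:t]
        return out

    # Example: 150 snapshots of 100 features padded up to a 200-step sequence.
    assert fit_cob_to_seq_len(torch.randn(150, 100), 200).shape == (1, 200, 100)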
@@ -1487,9 +1494,10 @@ class RealTrainingAdapter:

 logger.info(f" Converted {len(training_data)} samples to {len(converted_batches)} training batches")

 # Group single-sample batches into mini-batches for efficient training
-# Small batch size (5) for better gradient updates with limited training data
-mini_batch_size = 5  # Small batches work better with ~255 samples
+# Use batch size of 1 to avoid OOM with large sequence lengths
+# With 5 timeframes * 600 candles = 3000 sequence positions per sample,
+# even batch_size=5 causes 15,000 positions which is too large for GPU
+mini_batch_size = 1  # Process one sample at a time to avoid OOM

 def _combine_batches(batch_list: List[Dict[str, 'torch.Tensor']]) -> Dict[str, 'torch.Tensor']:
     combined: Dict[str, 'torch.Tensor'] = {}
@@ -1521,7 +1529,10 @@ class RealTrainingAdapter:

 logger.info(f" Grouped into {len(grouped_batches)} mini-batches (target size {mini_batch_size})")

-# Train using train_step for each mini-batch
+# Train using train_step for each mini-batch with gradient accumulation
+# Accumulate gradients over multiple batches to simulate larger batch size
+accumulation_steps = 5  # Accumulate 5 batches before optimizer step
+
 for epoch in range(session.total_epochs):
     epoch_loss = 0.0
     epoch_accuracy = 0.0
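The intent of accumulation_steps = 5 combined with mini_batch_size = 1 is to recover the old effective batch of 5 samples without ever holding five samples' activations in memory at once: gradients from five consecutive train_step calls are summed and the optimizer only steps on the fifth. The hunks below wire this into the adapter; here is the standard pattern in isolation, with a toy model and loss standing in for the project's transformer trainer:

    import torch
    from torch import nn

    model = nn.Linear(16, 3)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    accumulation_steps = 5

    batches = [(torch.randn(1, 16), torch.randint(0, 3, (1,))) for _ in range(20)]

    optimizer.zero_grad(set_to_none=True)
    for i, (x, y) in enumerate(batches):
        is_accumulation_step = (i + 1) % accumulation_steps != 0
        loss = criterion(model(x), y) / accumulation_steps  # scale so 5 accumulated steps ~ one batch of 5
        loss.backward()                                     # gradients add up across calls
        if not is_accumulation_step:
            optimizer.step()                                # one optimizer update per 5 samples
            optimizer.zero_grad(set_to_none=True)

Whether the project's trainer scales the loss inside train_step when accumulate_gradients=True is not visible in this diff; without some scaling, the accumulated update behaves like a roughly five times larger gradient per step.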
@@ -1529,8 +1540,11 @@ class RealTrainingAdapter:

 for i, batch in enumerate(grouped_batches):
     try:
+        # Determine if this is an accumulation step or optimizer step
+        is_accumulation_step = (i + 1) % accumulation_steps != 0
+
         # Call the trainer's train_step method with proper batch format
-        result = trainer.train_step(batch)
+        result = trainer.train_step(batch, accumulate_gradients=is_accumulation_step)

         if result is not None:
             batch_loss = result.get('total_loss', 0.0)
@@ -1539,14 +1553,14 @@ class RealTrainingAdapter:
             epoch_accuracy += batch_accuracy
             num_batches += 1

-            # Log first batch and every 100th batch for debugging
-            if (i + 1) == 1 or (i + 1) % 100 == 0:
-                logger.info(f" Batch {i + 1}/{len(converted_batches)}, Loss: {batch_loss:.6f}, Accuracy: {batch_accuracy:.4f}")
+            # Log first batch and every 10th batch for debugging
+            if (i + 1) == 1 or (i + 1) % 10 == 0:
+                logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Accuracy: {batch_accuracy:.4f}")
         else:
             logger.warning(f" Batch {i + 1} returned None result - skipping")

-        # Clear CUDA cache periodically to prevent memory leak
-        if torch.cuda.is_available() and (i + 1) % 5 == 0:
+        # Clear CUDA cache after optimizer step (not accumulation step)
+        if torch.cuda.is_available() and not is_accumulation_step:
             torch.cuda.empty_cache()

     except Exception as e:
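To confirm that batch size 1 plus accumulation actually lowers peak usage, allocator statistics can be logged around the optimizer step. A small helper using standard torch.cuda calls; the tag format and placement are suggestions, not part of this commit:

    import torch

    def log_cuda_memory(tag: str) -> None:
        """Print current, peak, and reserved allocator usage in MB; no-op without CUDA."""
        if not torch.cuda.is_available():
            return
        allocated = torch.cuda.memory_allocated() / (1024 ** 2)
        peak = torch.cuda.max_memory_allocated() / (1024 ** 2)
        reserved = torch.cuda.memory_reserved() / (1024 ** 2)
        print(f"{tag}: allocated={allocated:.0f} MB, peak={peak:.0f} MB, reserved={reserved:.0f} MB")

Calling it right after torch.cuda.empty_cache() in the loop above would show whether reserved memory is actually released between optimizer steps.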