training fix

This commit is contained in:
Dobromir Popov
2025-11-12 15:05:44 +02:00
parent 6f3f862edd
commit 352dc9cbeb
2 changed files with 27 additions and 12 deletions

View File

@@ -1282,7 +1282,13 @@ class TradingTransformerTrainer:
timeframe_losses = []
# Get normalization parameters if available
norm_params = batch.get('norm_params', {})
# norm_params may be a dict or a list of dicts (one per sample in batch)
norm_params_raw = batch.get('norm_params', {})
if isinstance(norm_params_raw, list) and len(norm_params_raw) > 0:
# If it's a list, use the first one (batch size is typically 1)
norm_params = norm_params_raw[0]
else:
norm_params = norm_params_raw if isinstance(norm_params_raw, dict) else {}
# Calculate loss for each timeframe that has target data
for tf in ['1s', '1m', '1h', '1d']: