training fix
This commit is contained in:
@@ -1282,7 +1282,13 @@ class TradingTransformerTrainer:
|
||||
timeframe_losses = []
|
||||
|
||||
# Get normalization parameters if available
|
||||
norm_params = batch.get('norm_params', {})
|
||||
# norm_params may be a dict or a list of dicts (one per sample in batch)
|
||||
norm_params_raw = batch.get('norm_params', {})
|
||||
if isinstance(norm_params_raw, list) and len(norm_params_raw) > 0:
|
||||
# If it's a list, use the first one (batch size is typically 1)
|
||||
norm_params = norm_params_raw[0]
|
||||
else:
|
||||
norm_params = norm_params_raw if isinstance(norm_params_raw, dict) else {}
|
||||
|
||||
# Calculate loss for each timeframe that has target data
|
||||
for tf in ['1s', '1m', '1h', '1d']:
|
||||
|
||||
Reference in New Issue
Block a user