training fix

Dobromir Popov
2025-11-12 15:05:44 +02:00
parent 6f3f862edd
commit 352dc9cbeb
2 changed files with 27 additions and 12 deletions


@@ -1797,16 +1797,25 @@ class RealTrainingAdapter:
             combined: Dict[str, 'torch.Tensor'] = {}
             keys = batch_list[0].keys()
             for key in keys:
-                tensors = [b[key] for b in batch_list if b[key] is not None]
+                values = [b[key] for b in batch_list if b[key] is not None]
                 # Skip keys where all values are None
-                if not tensors:
+                if not values:
                     combined[key] = None
                     continue
+
+                # Special handling for non-tensor keys (like norm_params which is a dict)
+                if key == 'norm_params':
+                    # Keep norm_params as a list of dicts (one per sample in batch)
+                    combined[key] = values
+                    continue
+
+                # For tensors, concatenate them
                 try:
-                    combined[key] = torch.cat(tensors, dim=0)
-                except RuntimeError as concat_error:
-                    logger.error(f"Failed to concatenate key '{key}' for mini-batch: {concat_error}")
-                    raise
+                    combined[key] = torch.cat(values, dim=0)
+                except (RuntimeError, TypeError) as concat_error:
+                    # If concatenation fails (e.g., not a tensor), keep as list
+                    logger.debug(f"Could not concatenate key '{key}', keeping as list: {concat_error}")
+                    combined[key] = values
             return combined

         grouped_batches: List[Dict[str, torch.Tensor]] = []
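
Read as a whole, the new collate path concatenates tensor values along the batch dimension but passes non-tensor metadata such as norm_params through as a per-sample list instead of raising. Below is a minimal standalone sketch of that behavior, assuming shape-(1, N) sample tensors; the helper name combine_mini_batch and the toy data are illustrative, not from the repository:

import logging
from typing import Dict, List

import torch

logger = logging.getLogger(__name__)

def combine_mini_batch(batch_list: List[Dict[str, object]]) -> Dict[str, object]:
    """Merge per-sample dicts: concatenate tensors, keep metadata as lists."""
    combined: Dict[str, object] = {}
    for key in batch_list[0].keys():
        values = [b[key] for b in batch_list if b[key] is not None]
        if not values:
            combined[key] = None          # every sample had None for this key
            continue
        if key == 'norm_params':
            combined[key] = values        # one dict per sample, left unbatched
            continue
        try:
            combined[key] = torch.cat(values, dim=0)
        except (RuntimeError, TypeError):
            combined[key] = values        # non-tensor value: fall back to a list
    return combined

# Two samples with a batch dimension of 1, as the adapter produces them.
samples = [
    {'price': torch.zeros(1, 4), 'norm_params': {'min': 0.0, 'max': 1.0}},
    {'price': torch.ones(1, 4), 'norm_params': {'min': 0.5, 'max': 2.0}},
]
out = combine_mini_batch(samples)
print(out['price'].shape)   # torch.Size([2, 4])
print(out['norm_params'])   # [{'min': 0.0, 'max': 1.0}, {'min': 0.5, 'max': 2.0}]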
@@ -1944,11 +1953,11 @@ class RealTrainingAdapter:
                     'epoch': epoch + 1,
                     'learning_rate': float(trainer.scheduler.get_last_lr()[0])
                 },
-                training_metadata={
-                    'num_samples': len(training_data),
-                    'num_batches': num_batches,
-                    'training_id': training_id
-                },
+                training_metadata={
+                    'num_samples': len(training_data),
+                    'num_batches': num_batches,
+                    'training_id': session.training_id
+                },
                 file_path=checkpoint_path,
                 performance_score=float(avg_accuracy),  # Use accuracy as score
                 is_active=True
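
The only change in this hunk is sourcing the training id from the session object rather than a bare training_id name, which keeps the checkpoint metadata tied to the session that produced it (and avoids a NameError if no such local exists in that scope). A hedged sketch of the pattern; the TrainingSession shape and the values are assumptions, not the repository's actual class:

from dataclasses import dataclass

@dataclass
class TrainingSession:
    training_id: str

session = TrainingSession(training_id="tr-2025-11-12-001")

# Build checkpoint metadata from the session itself rather than a local
# variable that may be undefined or stale at checkpoint-save time.
training_metadata = {
    'num_samples': 128,   # illustrative values
    'num_batches': 16,
    'training_id': session.training_id,
}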


@@ -1282,7 +1282,13 @@ class TradingTransformerTrainer:
         timeframe_losses = []

         # Get normalization parameters if available
-        norm_params = batch.get('norm_params', {})
+        # norm_params may be a dict or a list of dicts (one per sample in batch)
+        norm_params_raw = batch.get('norm_params', {})
+        if isinstance(norm_params_raw, list) and len(norm_params_raw) > 0:
+            # If it's a list, use the first one (batch size is typically 1)
+            norm_params = norm_params_raw[0]
+        else:
+            norm_params = norm_params_raw if isinstance(norm_params_raw, dict) else {}

         # Calculate loss for each timeframe that has target data
         for tf in ['1s', '1m', '1h', '1d']:
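
Since the adapter now delivers norm_params as a list of per-sample dicts, the trainer has to accept both the old dict form and the new list form. A minimal sketch of that resolution step; the function name resolve_norm_params is illustrative, not from the repository:

from typing import Any, Dict

def resolve_norm_params(raw: Any) -> Dict[str, Any]:
    """Accept a dict, a list of dicts, or anything else; always return a dict."""
    if isinstance(raw, list) and len(raw) > 0:
        # List of per-sample dicts: with batch size 1, the first entry applies.
        return raw[0]
    return raw if isinstance(raw, dict) else {}

print(resolve_norm_params({'1m': {'min': 0.0}}))     # {'1m': {'min': 0.0}}
print(resolve_norm_params([{'1m': {'min': 0.0}}]))   # {'1m': {'min': 0.0}}
print(resolve_norm_params(None))                     # {}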