training fix

Dobromir Popov
2025-11-12 15:05:44 +02:00
parent 6f3f862edd
commit 352dc9cbeb
2 changed files with 27 additions and 12 deletions


@@ -1797,16 +1797,25 @@ class RealTrainingAdapter:
             combined: Dict[str, 'torch.Tensor'] = {}
             keys = batch_list[0].keys()
             for key in keys:
-                tensors = [b[key] for b in batch_list if b[key] is not None]
+                values = [b[key] for b in batch_list if b[key] is not None]
                 # Skip keys where all values are None
-                if not tensors:
+                if not values:
                     combined[key] = None
                     continue
+                # Special handling for non-tensor keys (like norm_params which is a dict)
+                if key == 'norm_params':
+                    # Keep norm_params as a list of dicts (one per sample in batch)
+                    combined[key] = values
+                    continue
+                # For tensors, concatenate them
                 try:
-                    combined[key] = torch.cat(tensors, dim=0)
-                except RuntimeError as concat_error:
-                    logger.error(f"Failed to concatenate key '{key}' for mini-batch: {concat_error}")
-                    raise
+                    combined[key] = torch.cat(values, dim=0)
+                except (RuntimeError, TypeError) as concat_error:
+                    # If concatenation fails (e.g., not a tensor), keep as list
+                    logger.debug(f"Could not concatenate key '{key}', keeping as list: {concat_error}")
+                    combined[key] = values
             return combined
 
         grouped_batches: List[Dict[str, torch.Tensor]] = []
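
The collation change above stops assuming every batched value is a tensor: 'norm_params' is kept as a list of per-sample dicts, and anything else torch.cat rejects is kept as a list instead of aborting the whole mini-batch. A minimal standalone sketch of that behavior, assuming batch samples are plain dicts (the combine_batch name and the price_data key are illustrative, not taken from the repo):

from typing import Any, Dict, List

import torch


def combine_batch(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge per-sample dicts into one batch dict.

    Tensors are concatenated along dim 0; 'norm_params' (and any other
    value that cannot be concatenated) stays a list, one entry per sample.
    """
    combined: Dict[str, Any] = {}
    for key in batch_list[0].keys():
        values = [b[key] for b in batch_list if b[key] is not None]
        if not values:
            combined[key] = None
            continue
        if key == 'norm_params':
            combined[key] = values  # list of per-sample dicts
            continue
        try:
            combined[key] = torch.cat(values, dim=0)
        except (RuntimeError, TypeError):
            combined[key] = values  # non-tensor value: keep as list
    return combined


samples = [
    {'price_data': torch.randn(1, 60, 5), 'norm_params': {'mean': 0.0, 'std': 1.0}},
    {'price_data': torch.randn(1, 60, 5), 'norm_params': {'mean': 0.1, 'std': 0.9}},
]
batch = combine_batch(samples)
assert batch['price_data'].shape == (2, 60, 5)
assert isinstance(batch['norm_params'], list) and len(batch['norm_params']) == 2

Demoting the failure from logger.error plus re-raise to logger.debug plus fallback means a bad key now degrades to a list instead of killing training, which matches the commit's intent of tolerating non-tensor metadata in batches.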
@@ -1944,11 +1953,11 @@ class RealTrainingAdapter:
                         'epoch': epoch + 1,
                         'learning_rate': float(trainer.scheduler.get_last_lr()[0])
                     },
                     training_metadata={
                         'num_samples': len(training_data),
                         'num_batches': num_batches,
-                        'training_id': training_id
+                        'training_id': session.training_id
                     },
                     file_path=checkpoint_path,
                     performance_score=float(avg_accuracy),  # Use accuracy as score
                     is_active=True
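
This one-line fix reads the id off the session object rather than a bare training_id name, which suggests the checkpoint metadata is assembled in a scope where only session reliably carries the id. A hypothetical sketch of the pattern (TrainingSession and build_training_metadata are assumptions for illustration, not the repo's API):

from dataclasses import dataclass


@dataclass
class TrainingSession:
    training_id: str


def build_training_metadata(session: TrainingSession,
                            num_samples: int,
                            num_batches: int) -> dict:
    # Inside a helper like this, a bare `training_id` would be undefined
    # (or stale); the session object is the authoritative source.
    return {
        'num_samples': num_samples,
        'num_batches': num_batches,
        'training_id': session.training_id,
    }


session = TrainingSession(training_id='annotation-train-001')
meta = build_training_metadata(session, num_samples=120, num_batches=12)
assert meta['training_id'] == 'annotation-train-001'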


@@ -1282,7 +1282,13 @@ class TradingTransformerTrainer:
         timeframe_losses = []
 
         # Get normalization parameters if available
-        norm_params = batch.get('norm_params', {})
+        # norm_params may be a dict or a list of dicts (one per sample in batch)
+        norm_params_raw = batch.get('norm_params', {})
+        if isinstance(norm_params_raw, list) and len(norm_params_raw) > 0:
+            # If it's a list, use the first one (batch size is typically 1)
+            norm_params = norm_params_raw[0]
+        else:
+            norm_params = norm_params_raw if isinstance(norm_params_raw, dict) else {}
 
         # Calculate loss for each timeframe that has target data
         for tf in ['1s', '1m', '1h', '1d']:
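
On the consuming side, the trainer now tolerates norm_params arriving either as a single dict or as the per-sample list produced by the collation change above, falling back to an empty dict for anything else. A minimal sketch of that resolution logic, with an extra type check on the first list entry (resolve_norm_params is an assumed helper name, not from the repo):

from typing import Any, Dict


def resolve_norm_params(raw: Any) -> Dict[str, Any]:
    """Accept a dict or a list of per-sample dicts; always return a dict."""
    if isinstance(raw, list) and len(raw) > 0:
        # Collation may deliver one dict per sample; with batch size 1,
        # the first entry covers the whole batch.
        return raw[0] if isinstance(raw[0], dict) else {}
    return raw if isinstance(raw, dict) else {}


assert resolve_norm_params({'mean': 0.0}) == {'mean': 0.0}
assert resolve_norm_params([{'mean': 0.0}, {'mean': 0.1}]) == {'mean': 0.0}
assert resolve_norm_params(None) == {}

Note the trade-off the inline comment acknowledges: taking only the first list entry silently discards per-sample parameters whenever a batch holds more than one sample, which is safe only while batch size stays at 1.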