training fix
@@ -1797,16 +1797,25 @@ class RealTrainingAdapter:
             combined: Dict[str, 'torch.Tensor'] = {}
             keys = batch_list[0].keys()
             for key in keys:
-                tensors = [b[key] for b in batch_list if b[key] is not None]
-                if not tensors:
+                values = [b[key] for b in batch_list if b[key] is not None]
+                # Skip keys where all values are None
+                if not values:
                     combined[key] = None
                     continue
 
+                # Special handling for non-tensor keys (like norm_params which is a dict)
+                if key == 'norm_params':
+                    # Keep norm_params as a list of dicts (one per sample in batch)
+                    combined[key] = values
+                    continue
+
+                # For tensors, concatenate them
                 try:
-                    combined[key] = torch.cat(tensors, dim=0)
-                except RuntimeError as concat_error:
-                    logger.error(f"Failed to concatenate key '{key}' for mini-batch: {concat_error}")
-                    raise
+                    combined[key] = torch.cat(values, dim=0)
+                except (RuntimeError, TypeError) as concat_error:
+                    # If concatenation fails (e.g., not a tensor), keep as list
+                    logger.debug(f"Could not concatenate key '{key}', keeping as list: {concat_error}")
+                    combined[key] = values
             return combined
 
         grouped_batches: List[Dict[str, torch.Tensor]] = []
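For reference, the combining logic introduced above behaves roughly as in the sketch below. This is an illustrative standalone version, not the adapter's actual method: the function name combine_batches and the sample key price_data_1m are assumptions made only for the example.

import torch

def combine_batches(batch_list):
    # Merge per-sample batch dicts into one mini-batch dict (sketch of the new logic).
    combined = {}
    for key in batch_list[0].keys():
        values = [b[key] for b in batch_list if b[key] is not None]
        if not values:
            combined[key] = None            # every sample had None for this key
            continue
        if key == 'norm_params':
            combined[key] = values          # keep as a list of per-sample dicts
            continue
        try:
            combined[key] = torch.cat(values, dim=0)   # stack tensors along the batch dim
        except (RuntimeError, TypeError):
            combined[key] = values          # non-tensor values fall back to a plain list
    return combined

# Two single-sample batches merge into one mini-batch:
b1 = {'price_data_1m': torch.zeros(1, 60, 5), 'norm_params': {'1m': {'min': 0.0, 'max': 1.0}}}
b2 = {'price_data_1m': torch.ones(1, 60, 5), 'norm_params': {'1m': {'min': 0.0, 'max': 1.0}}}
merged = combine_batches([b1, b2])
print(merged['price_data_1m'].shape)   # torch.Size([2, 60, 5])
print(len(merged['norm_params']))      # 2 (one dict per sample)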
@@ -1944,11 +1953,11 @@ class RealTrainingAdapter:
                             'epoch': epoch + 1,
                             'learning_rate': float(trainer.scheduler.get_last_lr()[0])
                         },
-                        training_metadata={
-                            'num_samples': len(training_data),
-                            'num_batches': num_batches,
-                            'training_id': training_id
-                        },
+                        training_metadata={
+                            'num_samples': len(training_data),
+                            'num_batches': num_batches,
+                            'training_id': session.training_id
+                        },
                         file_path=checkpoint_path,
                         performance_score=float(avg_accuracy),  # Use accuracy as score
                         is_active=True
@@ -1282,7 +1282,13 @@ class TradingTransformerTrainer:
             timeframe_losses = []
 
             # Get normalization parameters if available
-            norm_params = batch.get('norm_params', {})
+            # norm_params may be a dict or a list of dicts (one per sample in batch)
+            norm_params_raw = batch.get('norm_params', {})
+            if isinstance(norm_params_raw, list) and len(norm_params_raw) > 0:
+                # If it's a list, use the first one (batch size is typically 1)
+                norm_params = norm_params_raw[0]
+            else:
+                norm_params = norm_params_raw if isinstance(norm_params_raw, dict) else {}
 
             # Calculate loss for each timeframe that has target data
             for tf in ['1s', '1m', '1h', '1d']:
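The trainer-side change can be exercised in isolation as sketched below; the wrapper function resolve_norm_params is hypothetical and exists only to demonstrate the dict-or-list handling added above.

def resolve_norm_params(batch):
    # 'norm_params' may arrive as a single dict or as a list of per-sample dicts
    # (the batch combiner above now produces a list); with the usual batch size
    # of 1, the first entry is used.
    norm_params_raw = batch.get('norm_params', {})
    if isinstance(norm_params_raw, list) and len(norm_params_raw) > 0:
        return norm_params_raw[0]
    return norm_params_raw if isinstance(norm_params_raw, dict) else {}

# Both shapes resolve to the same per-timeframe dict:
assert resolve_norm_params({'norm_params': [{'1m': {'min': 0, 'max': 1}}]}) == {'1m': {'min': 0, 'max': 1}}
assert resolve_norm_params({'norm_params': {'1m': {'min': 0, 'max': 1}}}) == {'1m': {'min': 0, 'max': 1}}
assert resolve_norm_params({}) == {}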