training fix
@@ -1797,16 +1797,25 @@ class RealTrainingAdapter:
             combined: Dict[str, 'torch.Tensor'] = {}
             keys = batch_list[0].keys()
             for key in keys:
-                tensors = [b[key] for b in batch_list if b[key] is not None]
+                values = [b[key] for b in batch_list if b[key] is not None]
                 # Skip keys where all values are None
-                if not tensors:
+                if not values:
                     combined[key] = None
                     continue
+
+                # Special handling for non-tensor keys (like norm_params which is a dict)
+                if key == 'norm_params':
+                    # Keep norm_params as a list of dicts (one per sample in batch)
+                    combined[key] = values
+                    continue
+
+                # For tensors, concatenate them
                 try:
-                    combined[key] = torch.cat(tensors, dim=0)
-                except RuntimeError as concat_error:
-                    logger.error(f"Failed to concatenate key '{key}' for mini-batch: {concat_error}")
-                    raise
+                    combined[key] = torch.cat(values, dim=0)
+                except (RuntimeError, TypeError) as concat_error:
+                    # If concatenation fails (e.g., not a tensor), keep as list
+                    logger.debug(f"Could not concatenate key '{key}', keeping as list: {concat_error}")
+                    combined[key] = values
             return combined
 
         grouped_batches: List[Dict[str, torch.Tensor]] = []
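For context, a minimal runnable sketch of the combining behavior after this change (the helper name combine_batches and the sample data are illustrative assumptions, not taken from the repo): tensor values concatenate along dim 0, norm_params stays a list of per-sample dicts, and anything else that fails torch.cat falls back to a list instead of raising.

# Sketch only: illustrates the post-fix combining behavior under assumed inputs.
from typing import Any, Dict, List
import torch

def combine_batches(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    combined: Dict[str, Any] = {}
    for key in batch_list[0].keys():
        values = [b[key] for b in batch_list if b[key] is not None]
        if not values:
            combined[key] = None          # all samples had None for this key
            continue
        if key == 'norm_params':
            combined[key] = values        # keep one dict per sample
            continue
        try:
            combined[key] = torch.cat(values, dim=0)
        except (RuntimeError, TypeError):
            combined[key] = values        # non-tensor value: keep as list
    return combined

batches = [
    {'price': torch.randn(1, 4), 'norm_params': {'min': 0.0, 'max': 1.0}},
    {'price': torch.randn(1, 4), 'norm_params': {'min': 0.2, 'max': 1.1}},
]
out = combine_batches(batches)
assert out['price'].shape == (2, 4)          # tensors concatenated on dim 0
assert isinstance(out['norm_params'], list)  # norm_params kept per-sample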
@@ -1944,11 +1953,11 @@ class RealTrainingAdapter:
                         'epoch': epoch + 1,
                         'learning_rate': float(trainer.scheduler.get_last_lr()[0])
                     },
                     training_metadata={
                         'num_samples': len(training_data),
                         'num_batches': num_batches,
-                        'training_id': training_id
+                        'training_id': session.training_id
                     },
                     file_path=checkpoint_path,
                     performance_score=float(avg_accuracy),  # Use accuracy as score
                     is_active=True
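The training_id -> session.training_id change reads as a scoping fix: the bare name is presumably not defined at the checkpoint-save call site, while the id is carried on the session object. A hypothetical reduction of the pattern (all names illustrative):

# Hypothetical reduction of the scoping fix; names are illustrative.
from dataclasses import dataclass

@dataclass
class TrainingSession:
    training_id: str

def build_metadata(session: TrainingSession) -> dict:
    # Before the fix this read a bare `training_id`, which would raise
    # NameError here; the id must come from the session object.
    return {'training_id': session.training_id}

print(build_metadata(TrainingSession(training_id='run-001')))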
@@ -1282,7 +1282,13 @@ class TradingTransformerTrainer:
         timeframe_losses = []
 
         # Get normalization parameters if available
-        norm_params = batch.get('norm_params', {})
+        # norm_params may be a dict or a list of dicts (one per sample in batch)
+        norm_params_raw = batch.get('norm_params', {})
+        if isinstance(norm_params_raw, list) and len(norm_params_raw) > 0:
+            # If it's a list, use the first one (batch size is typically 1)
+            norm_params = norm_params_raw[0]
+        else:
+            norm_params = norm_params_raw if isinstance(norm_params_raw, dict) else {}
 
         # Calculate loss for each timeframe that has target data
         for tf in ['1s', '1m', '1h', '1d']:
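A standalone sketch of the dict-or-list handling introduced here, with illustrative inputs (the function name resolve_norm_params is an assumption, not from the repo):

# Sketch only: mirrors the dict-or-list normalization of norm_params.
from typing import Dict, List, Union

def resolve_norm_params(raw: Union[Dict, List[Dict], None]) -> Dict:
    # A non-empty list means one dict per sample; with batch size 1, take the first.
    if isinstance(raw, list) and len(raw) > 0:
        return raw[0]
    # Fall back to an empty dict for anything that is not a dict.
    return raw if isinstance(raw, dict) else {}

assert resolve_norm_params([{'min': 0.0}]) == {'min': 0.0}  # list of dicts
assert resolve_norm_params({'min': 0.0}) == {'min': 0.0}    # plain dict
assert resolve_norm_params(None) == {}                      # missing
assert resolve_norm_params([]) == {}                        # empty list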