diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py
index d83ec3d..40f277c 100644
--- a/ANNOTATE/core/real_training_adapter.py
+++ b/ANNOTATE/core/real_training_adapter.py
@@ -1797,16 +1797,25 @@ class RealTrainingAdapter:
             combined: Dict[str, 'torch.Tensor'] = {}
             keys = batch_list[0].keys()
             for key in keys:
-                tensors = [b[key] for b in batch_list if b[key] is not None]
+                values = [b[key] for b in batch_list if b[key] is not None]
                 # Skip keys where all values are None
-                if not tensors:
+                if not values:
                     combined[key] = None
                     continue
+
+                # Special handling for non-tensor keys (like norm_params which is a dict)
+                if key == 'norm_params':
+                    # Keep norm_params as a list of dicts (one per sample in batch)
+                    combined[key] = values
+                    continue
+
+                # For tensors, concatenate them
                 try:
-                    combined[key] = torch.cat(tensors, dim=0)
-                except RuntimeError as concat_error:
-                    logger.error(f"Failed to concatenate key '{key}' for mini-batch: {concat_error}")
-                    raise
+                    combined[key] = torch.cat(values, dim=0)
+                except (RuntimeError, TypeError) as concat_error:
+                    # If concatenation fails (e.g., not a tensor), keep as list
+                    logger.debug(f"Could not concatenate key '{key}', keeping as list: {concat_error}")
+                    combined[key] = values
             return combined
 
         grouped_batches: List[Dict[str, torch.Tensor]] = []
@@ -1944,11 +1953,11 @@ class RealTrainingAdapter:
                 'epoch': epoch + 1,
                 'learning_rate': float(trainer.scheduler.get_last_lr()[0])
             },
-            training_metadata={
-                'num_samples': len(training_data),
-                'num_batches': num_batches,
-                'training_id': training_id
-            },
+            training_metadata={
+                'num_samples': len(training_data),
+                'num_batches': num_batches,
+                'training_id': session.training_id
+            },
             file_path=checkpoint_path,
             performance_score=float(avg_accuracy),  # Use accuracy as score
             is_active=True
diff --git a/NN/models/advanced_transformer_trading.py b/NN/models/advanced_transformer_trading.py
index 53efca8..0d8684e 100644
--- a/NN/models/advanced_transformer_trading.py
+++ b/NN/models/advanced_transformer_trading.py
@@ -1282,7 +1282,13 @@ class TradingTransformerTrainer:
         timeframe_losses = []
 
         # Get normalization parameters if available
-        norm_params = batch.get('norm_params', {})
+        # norm_params may be a dict or a list of dicts (one per sample in batch)
+        norm_params_raw = batch.get('norm_params', {})
+        if isinstance(norm_params_raw, list) and len(norm_params_raw) > 0:
+            # If it's a list, use the first one (batch size is typically 1)
+            norm_params = norm_params_raw[0]
+        else:
+            norm_params = norm_params_raw if isinstance(norm_params_raw, dict) else {}
 
         # Calculate loss for each timeframe that has target data
         for tf in ['1s', '1m', '1h', '1d']:
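
For reference, a minimal runnable sketch of the collation contract the two hunks establish between producer and consumer: tensor values are concatenated along dim 0, while norm_params is carried through as a per-sample list of dicts that the trainer unwraps. The standalone combine_batches helper and the sample dicts below are illustrative stand-ins, not the repository's actual API.

    import torch

    def combine_batches(batch_list):
        """Merge per-sample batch dicts: concatenate tensors along dim 0,
        keep non-tensor payloads (e.g. norm_params dicts) as lists."""
        combined = {}
        for key in batch_list[0].keys():
            values = [b[key] for b in batch_list if b[key] is not None]
            if not values:
                combined[key] = None
                continue
            if key == 'norm_params':
                # One dict per sample, preserved for later denormalization
                combined[key] = values
                continue
            try:
                combined[key] = torch.cat(values, dim=0)
            except (RuntimeError, TypeError):
                combined[key] = values  # non-tensor values fall back to a list
        return combined

    # Illustrative input: two samples, each with a tensor and a norm_params dict
    samples = [
        {'price': torch.randn(1, 4), 'norm_params': {'1m': {'min': 0.0, 'max': 1.0}}},
        {'price': torch.randn(1, 4), 'norm_params': {'1m': {'min': 0.5, 'max': 2.0}}},
    ]
    batch = combine_batches(samples)
    assert batch['price'].shape == (2, 4)          # tensors concatenated
    assert isinstance(batch['norm_params'], list)  # dicts kept per-sample

    # Consumer side, mirroring the trainer hunk: unwrap list-or-dict safely
    raw = batch.get('norm_params', {})
    if isinstance(raw, list) and len(raw) > 0:
        norm_params = raw[0]  # batch size is typically 1 in this code path
    else:
        norm_params = raw if isinstance(raw, dict) else {}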