Training normalization fix

Dobromir Popov
2025-11-12 14:36:28 +02:00
parent 4c04503f3e
commit a7a22334fb
5 changed files with 800 additions and 50 deletions


@@ -443,6 +443,8 @@ class AdvancedTradingTransformer(nn.Module):
self.uncertainty_estimator = UncertaintyEstimation(config.d_model)
# Enhanced price prediction head (auxiliary task)
+# Predicts price change ratio (future_price - current_price) / current_price
+# Use Tanh to constrain to [-1, 1] range (max 100% change up/down)
self.price_head = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.GELU(),
@@ -450,7 +452,8 @@ class AdvancedTradingTransformer(nn.Module):
nn.Linear(config.d_model // 2, config.d_model // 4),
nn.GELU(),
nn.Dropout(config.dropout),
-nn.Linear(config.d_model // 4, 1)
+nn.Linear(config.d_model // 4, 1),
+nn.Tanh()  # Constrain to [-1, 1] range for price change ratio
)
# Additional specialized heads for 46M model
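The price head above now predicts a bounded change ratio rather than an absolute price. A minimal sketch of turning the Tanh output back into a price target (the sample values are hypothetical, not from this repo):

    # r = (future_price - current_price) / current_price, with Tanh keeping r in [-1, 1]
    import torch
    current_price = torch.tensor([2650.0])           # hypothetical reference price
    ratio = torch.tensor([0.0125])                   # example price_head output
    predicted_price = current_price * (1.0 + ratio)  # -> 2683.125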
@@ -473,6 +476,7 @@ class AdvancedTradingTransformer(nn.Module):
# NEW: Next candle OHLCV prediction heads for each timeframe (1s, 1m, 1h, 1d)
# Each timeframe predicts: [open, high, low, close, volume] = 5 values
# Note: self.timeframes already defined above in input projections
+# CRITICAL: Outputs are constrained to [0, 1] range using Sigmoid since inputs are normalized
self.next_candle_heads = nn.ModuleDict({
tf: nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
@@ -481,11 +485,13 @@ class AdvancedTradingTransformer(nn.Module):
nn.Linear(config.d_model // 2, config.d_model // 4),
nn.GELU(),
nn.Dropout(config.dropout),
-nn.Linear(config.d_model // 4, 5)  # OHLCV: [open, high, low, close, volume]
+nn.Linear(config.d_model // 4, 5),  # OHLCV: [open, high, low, close, volume]
+nn.Sigmoid()  # Constrain to [0, 1] to match normalized input range
) for tf in self.timeframes
})
# BTC next candle prediction head
+# CRITICAL: Outputs are constrained to [0, 1] range using Sigmoid since inputs are normalized
self.btc_next_candle_head = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.GELU(),
@@ -493,7 +499,8 @@ class AdvancedTradingTransformer(nn.Module):
nn.Linear(config.d_model // 2, config.d_model // 4),
nn.GELU(),
nn.Dropout(config.dropout),
-nn.Linear(config.d_model // 4, 5)  # OHLCV for BTC
+nn.Linear(config.d_model // 4, 5),  # OHLCV for BTC
+nn.Sigmoid()  # Constrain to [0, 1] to match normalized input range
)
# NEW: Next pivot point prediction heads for L1-L5 levels
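The Sigmoid heads assume the OHLCV inputs were min-max normalized into [0, 1]. A sketch of that forward normalization, assuming per-window bounds matching the norm_params keys used in the trainer below (normalize_candle is a hypothetical helper, the inverse of the denormalize_candle method added in this commit):

    # Min-max normalize an OHLCV tensor [..., 5] into the [0, 1] space the heads predict in
    def normalize_candle(ohlcv, price_min, price_max, volume_min, volume_max):
        norm = ohlcv.clone()
        if price_max > price_min:
            norm[..., :4] = (norm[..., :4] - price_min) / (price_max - price_min)
        if volume_max > volume_min:
            norm[..., 4] = (norm[..., 4] - volume_min) / (volume_max - volume_min)
        return norm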
@@ -1153,6 +1160,54 @@ class TradingTransformerTrainer:
'learning_rates': []
}
+@staticmethod
+def denormalize_prices(normalized_values: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
+    """
+    Denormalize price predictions back to real price space
+
+    Args:
+        normalized_values: Tensor of normalized values in [0, 1] range
+        norm_params: Dict with 'price_min' and 'price_max' keys
+
+    Returns:
+        Denormalized tensor in original price space
+    """
+    price_min = norm_params.get('price_min', 0.0)
+    price_max = norm_params.get('price_max', 1.0)
+    if price_max > price_min:
+        return normalized_values * (price_max - price_min) + price_min
+    else:
+        return normalized_values
+
+@staticmethod
+def denormalize_candle(normalized_candle: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
+    """
+    Denormalize a full OHLCV candle back to real values
+
+    Args:
+        normalized_candle: Tensor of shape [..., 5] with normalized OHLCV
+        norm_params: Dict with normalization parameters
+
+    Returns:
+        Denormalized OHLCV tensor
+    """
+    denorm = normalized_candle.clone()
+
+    # Denormalize OHLC (first 4 values)
+    price_min = norm_params.get('price_min', 0.0)
+    price_max = norm_params.get('price_max', 1.0)
+    if price_max > price_min:
+        denorm[..., :4] = denorm[..., :4] * (price_max - price_min) + price_min
+
+    # Denormalize volume (5th value)
+    volume_min = norm_params.get('volume_min', 0.0)
+    volume_max = norm_params.get('volume_max', 1.0)
+    if volume_max > volume_min:
+        denorm[..., 4] = denorm[..., 4] * (volume_max - volume_min) + volume_min
+
+    return denorm
def train_step(self, batch: Dict[str, torch.Tensor], accumulate_gradients: bool = False) -> Dict[str, float]:
"""Single training step with optional gradient accumulation
@@ -1217,8 +1272,50 @@ class TradingTransformerTrainer:
trend_loss = self.price_criterion(trend_pred, trend_target)
logger.debug(f"Trend loss: {trend_loss.item():.6f} (pred={trend_pred[0].tolist()}, target={trend_target[0].tolist()})")
+# NEW: Next candle prediction loss for each timeframe
+# This trains the model to predict full OHLCV for the next candle on each timeframe
+candle_loss = torch.tensor(0.0, device=self.device)
+candle_losses_detail = {}  # Track per-timeframe losses (normalized space)
+candle_losses_denorm = {}  # Track per-timeframe losses (denormalized/real space)
+if 'next_candles' in outputs:
+    timeframe_losses = []
+    # Get normalization parameters if available
+    norm_params = batch.get('norm_params', {})
+    # Calculate loss for each timeframe that has target data
+    for tf in ['1s', '1m', '1h', '1d']:
+        future_key = f'future_candle_{tf}'
+        if tf in outputs['next_candles'] and future_key in batch:
+            pred_candle = outputs['next_candles'][tf]  # [batch, 5] - predicted OHLCV (normalized)
+            target_candle = batch[future_key]  # [batch, 5] - actual OHLCV (normalized)
+            if target_candle is not None and pred_candle.shape == target_candle.shape:
+                # MSE loss on normalized values (used for backprop)
+                tf_loss = self.price_criterion(pred_candle, target_candle)
+                timeframe_losses.append(tf_loss)
+                candle_losses_detail[tf] = tf_loss.item()
+                # ALSO calculate denormalized loss for better interpretability
+                if tf in norm_params:
+                    with torch.no_grad():
+                        pred_denorm = self.denormalize_candle(pred_candle, norm_params[tf])
+                        target_denorm = self.denormalize_candle(target_candle, norm_params[tf])
+                        denorm_loss = self.price_criterion(pred_denorm, target_denorm)
+                        candle_losses_denorm[tf] = denorm_loss.item()
+    # Average loss across available timeframes
+    if timeframe_losses:
+        candle_loss = torch.stack(timeframe_losses).mean()
+        if candle_losses_denorm:
+            logger.debug(f"Candle losses (normalized): {candle_losses_detail}")
+            logger.debug(f"Candle losses (real prices): {candle_losses_denorm}")
# Start with base losses - avoid inplace operations on computation graph
-total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss  # Weight auxiliary tasks
+# Weight: action=1.0, price=0.1, trend=0.05, candle=0.15
+total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss + 0.15 * candle_loss
# CRITICAL FIX: Scale loss for gradient accumulation
# This prevents gradient explosion when accumulating over multiple batches
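The scaling referred to here is presumably the standard divide-by-N pattern; a minimal sketch with an assumed accumulation_steps parameter:

    # Dividing by the number of accumulated micro-batches keeps the effective
    # gradient equal to one averaged large-batch step instead of a growing sum
    accumulation_steps = 4                  # hypothetical
    loss = total_loss / accumulation_steps
    loss.backward()                         # gradients accumulate in the .grad buffers
    # optimizer.step() and optimizer.zero_grad() then run once every accumulation_steps batches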
@@ -1322,7 +1419,9 @@ class TradingTransformerTrainer:
'total_loss': total_loss.item(),
'action_loss': action_loss.item(),
'price_loss': price_loss.item(),
-'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,  # NEW
+'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,
+'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0,
+'candle_loss_denorm': candle_losses_denorm,  # Dict of denormalized losses per timeframe
'accuracy': accuracy.item(),
'candle_accuracy': candle_accuracy,
'learning_rate': self.scheduler.get_last_lr()[0]
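With the new candle_loss_denorm entry, a caller can turn the per-timeframe real-space error into something readable, assuming price_criterion is an MSE-style loss (trainer and batch are illustrative names):

    import math
    metrics = trainer.train_step(batch)
    for tf, mse in metrics['candle_loss_denorm'].items():
        # sqrt of a real-space MSE is roughly the average per-field candle error in quote currency
        print(f"{tf}: ~{math.sqrt(mse):.2f} price-units RMSE")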
@@ -1330,7 +1429,7 @@ class TradingTransformerTrainer:
# CRITICAL: Delete large tensors to free memory immediately
# This prevents memory accumulation across batches
-del outputs, total_loss, action_loss, price_loss, predictions, accuracy
+del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, predictions, accuracy
if torch.cuda.is_available():
torch.cuda.empty_cache()
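The delete-then-empty-cache pattern extends naturally to the caller's loop; a sketch under the assumption of a standard DataLoader (loader and trainer are assumed names):

    import torch
    for batch in loader:
        metrics = trainer.train_step(batch)  # train_step frees its large intermediates
        del batch                            # drop the caller's reference as well
        if torch.cuda.is_available():
            torch.cuda.empty_cache()         # return cached blocks to the CUDA allocator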