Training normalization fix
@@ -443,6 +443,8 @@ class AdvancedTradingTransformer(nn.Module):
         self.uncertainty_estimator = UncertaintyEstimation(config.d_model)
 
         # Enhanced price prediction head (auxiliary task)
+        # Predicts price change ratio (future_price - current_price) / current_price
+        # Use Tanh to constrain to [-1, 1] range (max 100% change up/down)
         self.price_head = nn.Sequential(
             nn.Linear(config.d_model, config.d_model // 2),
             nn.GELU(),
@@ -450,7 +452,8 @@ class AdvancedTradingTransformer(nn.Module):
             nn.Linear(config.d_model // 2, config.d_model // 4),
             nn.GELU(),
             nn.Dropout(config.dropout),
-            nn.Linear(config.d_model // 4, 1)
+            nn.Linear(config.d_model // 4, 1),
+            nn.Tanh()  # Constrain to [-1, 1] range for price change ratio
         )
 
         # Additional specialized heads for 46M model
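Note on consuming the new head: because price_head now emits a change ratio in [-1, 1] rather than a raw price, callers have to invert the ratio definition from the comment above. A minimal sketch (variable names are illustrative, not from the repo):

    # ratio: Tanh-constrained output of price_head, in [-1, 1]
    # current_price: last known close for the symbol
    # From ratio = (future_price - current_price) / current_price:
    predicted_price = current_price * (1.0 + ratio)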
@@ -473,6 +476,7 @@ class AdvancedTradingTransformer(nn.Module):
         # NEW: Next candle OHLCV prediction heads for each timeframe (1s, 1m, 1h, 1d)
         # Each timeframe predicts: [open, high, low, close, volume] = 5 values
         # Note: self.timeframes already defined above in input projections
+        # CRITICAL: Outputs are constrained to [0, 1] range using Sigmoid since inputs are normalized
         self.next_candle_heads = nn.ModuleDict({
             tf: nn.Sequential(
                 nn.Linear(config.d_model, config.d_model // 2),
@@ -481,11 +485,13 @@ class AdvancedTradingTransformer(nn.Module):
                 nn.Linear(config.d_model // 2, config.d_model // 4),
                 nn.GELU(),
                 nn.Dropout(config.dropout),
-                nn.Linear(config.d_model // 4, 5)  # OHLCV: [open, high, low, close, volume]
+                nn.Linear(config.d_model // 4, 5),  # OHLCV: [open, high, low, close, volume]
+                nn.Sigmoid()  # Constrain to [0, 1] to match normalized input range
             ) for tf in self.timeframes
         })
 
         # BTC next candle prediction head
+        # CRITICAL: Outputs are constrained to [0, 1] range using Sigmoid since inputs are normalized
         self.btc_next_candle_head = nn.Sequential(
             nn.Linear(config.d_model, config.d_model // 2),
             nn.GELU(),
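The Sigmoid on these heads assumes the OHLCV targets really are min-max normalized into [0, 1]. A sketch of that convention (price_min/price_max/volume_min/volume_max here are hypothetical per-timeframe statistics, not taken from this diff):

    # Min-max normalize a raw candle so targets match the head's output range
    norm_ohlc = (ohlc - price_min) / (price_max - price_min)          # O/H/L/C -> [0, 1]
    norm_volume = (volume - volume_min) / (volume_max - volume_min)   # volume -> [0, 1]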
@@ -493,7 +499,8 @@ class AdvancedTradingTransformer(nn.Module):
             nn.Linear(config.d_model // 2, config.d_model // 4),
             nn.GELU(),
             nn.Dropout(config.dropout),
-            nn.Linear(config.d_model // 4, 5)  # OHLCV for BTC
+            nn.Linear(config.d_model // 4, 5),  # OHLCV for BTC
+            nn.Sigmoid()  # Constrain to [0, 1] to match normalized input range
         )
 
         # NEW: Next pivot point prediction heads for L1-L5 levels
@@ -1153,6 +1160,54 @@ class TradingTransformerTrainer:
             'learning_rates': []
         }
 
+    @staticmethod
+    def denormalize_prices(normalized_values: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
+        """
+        Denormalize price predictions back to real price space
+
+        Args:
+            normalized_values: Tensor of normalized values in [0, 1] range
+            norm_params: Dict with 'price_min' and 'price_max' keys
+
+        Returns:
+            Denormalized tensor in original price space
+        """
+        price_min = norm_params.get('price_min', 0.0)
+        price_max = norm_params.get('price_max', 1.0)
+
+        if price_max > price_min:
+            return normalized_values * (price_max - price_min) + price_min
+        else:
+            return normalized_values
+
+    @staticmethod
+    def denormalize_candle(normalized_candle: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
+        """
+        Denormalize a full OHLCV candle back to real values
+
+        Args:
+            normalized_candle: Tensor of shape [..., 5] with normalized OHLCV
+            norm_params: Dict with normalization parameters
+
+        Returns:
+            Denormalized OHLCV tensor
+        """
+        denorm = normalized_candle.clone()
+
+        # Denormalize OHLC (first 4 values)
+        price_min = norm_params.get('price_min', 0.0)
+        price_max = norm_params.get('price_max', 1.0)
+        if price_max > price_min:
+            denorm[..., :4] = denorm[..., :4] * (price_max - price_min) + price_min
+
+        # Denormalize volume (5th value)
+        volume_min = norm_params.get('volume_min', 0.0)
+        volume_max = norm_params.get('volume_max', 1.0)
+        if volume_max > volume_min:
+            denorm[..., 4] = denorm[..., 4] * (volume_max - volume_min) + volume_min
+
+        return denorm
+
     def train_step(self, batch: Dict[str, torch.Tensor], accumulate_gradients: bool = False) -> Dict[str, float]:
         """Single training step with optional gradient accumulation
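A hypothetical usage of the new helper, with made-up tensor values and norm_params purely for illustration:

    import torch

    pred = torch.tensor([0.50, 0.60, 0.40, 0.55, 0.20])  # normalized OHLCV from a next-candle head
    params = {'price_min': 100.0, 'price_max': 200.0, 'volume_min': 0.0, 'volume_max': 1e6}
    real = TradingTransformerTrainer.denormalize_candle(pred, params)
    # real[:4] are prices in [100, 200]; real[4] is volume in [0, 1e6]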
@@ -1217,8 +1272,50 @@ class TradingTransformerTrainer:
             trend_loss = self.price_criterion(trend_pred, trend_target)
             logger.debug(f"Trend loss: {trend_loss.item():.6f} (pred={trend_pred[0].tolist()}, target={trend_target[0].tolist()})")
 
+            # NEW: Next candle prediction loss for each timeframe
+            # This trains the model to predict full OHLCV for the next candle on each timeframe
+            candle_loss = torch.tensor(0.0, device=self.device)
+            candle_losses_detail = {}  # Track per-timeframe losses (normalized space)
+            candle_losses_denorm = {}  # Track per-timeframe losses (denormalized/real space)
+
+            if 'next_candles' in outputs:
+                timeframe_losses = []
+
+                # Get normalization parameters if available
+                norm_params = batch.get('norm_params', {})
+
+                # Calculate loss for each timeframe that has target data
+                for tf in ['1s', '1m', '1h', '1d']:
+                    future_key = f'future_candle_{tf}'
+
+                    if tf in outputs['next_candles'] and future_key in batch:
+                        pred_candle = outputs['next_candles'][tf]  # [batch, 5] - predicted OHLCV (normalized)
+                        target_candle = batch[future_key]  # [batch, 5] - actual OHLCV (normalized)
+
+                        if target_candle is not None and pred_candle.shape == target_candle.shape:
+                            # MSE loss on normalized values (used for backprop)
+                            tf_loss = self.price_criterion(pred_candle, target_candle)
+                            timeframe_losses.append(tf_loss)
+                            candle_losses_detail[tf] = tf_loss.item()
+
+                            # ALSO calculate denormalized loss for better interpretability
+                            if tf in norm_params:
+                                with torch.no_grad():
+                                    pred_denorm = self.denormalize_candle(pred_candle, norm_params[tf])
+                                    target_denorm = self.denormalize_candle(target_candle, norm_params[tf])
+                                    denorm_loss = self.price_criterion(pred_denorm, target_denorm)
+                                    candle_losses_denorm[tf] = denorm_loss.item()
+
+                # Average loss across available timeframes
+                if timeframe_losses:
+                    candle_loss = torch.stack(timeframe_losses).mean()
+                    if candle_losses_denorm:
+                        logger.debug(f"Candle losses (normalized): {candle_losses_detail}")
+                        logger.debug(f"Candle losses (real prices): {candle_losses_denorm}")
+
             # Start with base losses - avoid inplace operations on computation graph
-            total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss  # Weight auxiliary tasks
+            # Weight: action=1.0, price=0.1, trend=0.05, candle=0.15
+            total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss + 0.15 * candle_loss
 
             # CRITICAL FIX: Scale loss for gradient accumulation
             # This prevents gradient explosion when accumulating over multiple batches
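The "scale loss for gradient accumulation" comment refers to the standard trick of dividing the loss by the number of accumulation steps before backward, so gradients summed over several small batches match one large batch. A generic sketch (accumulation_steps is illustrative; the trainer's actual attribute name is not shown in this diff):

    loss = total_loss / accumulation_steps if accumulate_gradients else total_loss
    loss.backward()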
@@ -1322,7 +1419,9 @@ class TradingTransformerTrainer:
                 'total_loss': total_loss.item(),
                 'action_loss': action_loss.item(),
                 'price_loss': price_loss.item(),
-                'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,  # NEW
+                'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,
+                'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0,
+                'candle_loss_denorm': candle_losses_denorm,  # Dict of denormalized losses per timeframe
                 'accuracy': accuracy.item(),
                 'candle_accuracy': candle_accuracy,
                 'learning_rate': self.scheduler.get_last_lr()[0]
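Since candle_loss_denorm is a per-timeframe dict rather than a scalar, downstream logging has to treat it differently from the other metric entries. A hypothetical consumer (trainer and batch are stand-ins):

    metrics = trainer.train_step(batch)
    for tf, loss in metrics['candle_loss_denorm'].items():
        print(f"{tf}: real-price candle MSE {loss:.4f}")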
@@ -1330,7 +1429,7 @@
 
             # CRITICAL: Delete large tensors to free memory immediately
             # This prevents memory accumulation across batches
-            del outputs, total_loss, action_loss, price_loss, predictions, accuracy
+            del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, predictions, accuracy
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()