T model trend prediction added
This commit is contained in:
@@ -1203,8 +1203,22 @@ class TradingTransformerTrainer:
|
||||
|
||||
price_loss = self.price_criterion(price_pred, price_target)
|
||||
|
||||
# NEW: Trend analysis loss (if trend_target provided)
|
||||
trend_loss = torch.tensor(0.0, device=self.device)
|
||||
if 'trend_target' in batch and 'trend_analysis' in outputs:
|
||||
trend_pred = torch.cat([
|
||||
outputs['trend_analysis']['angle_radians'],
|
||||
outputs['trend_analysis']['steepness'],
|
||||
outputs['trend_analysis']['direction']
|
||||
], dim=1) # [batch, 3]
|
||||
|
||||
trend_target = batch['trend_target']
|
||||
if trend_pred.shape == trend_target.shape:
|
||||
trend_loss = self.price_criterion(trend_pred, trend_target)
|
||||
logger.debug(f"Trend loss: {trend_loss.item():.6f} (pred={trend_pred[0].tolist()}, target={trend_target[0].tolist()})")
|
||||
|
||||
# Start with base losses - avoid inplace operations on computation graph
|
||||
total_loss = action_loss + 0.1 * price_loss # Weight auxiliary task
|
||||
total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss # Weight auxiliary tasks
|
||||
|
||||
# CRITICAL FIX: Scale loss for gradient accumulation
|
||||
# This prevents gradient explosion when accumulating over multiple batches
|
||||
@@ -1308,6 +1322,7 @@ class TradingTransformerTrainer:
|
||||
'total_loss': total_loss.item(),
|
||||
'action_loss': action_loss.item(),
|
||||
'price_loss': price_loss.item(),
|
||||
'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0, # NEW
|
||||
'accuracy': accuracy.item(),
|
||||
'candle_accuracy': candle_accuracy,
|
||||
'learning_rate': self.scheduler.get_last_lr()[0]
|
||||
|
||||
Reference in New Issue
Block a user