T model trend prediction added

This commit is contained in:
Dobromir Popov
2025-11-10 20:12:22 +02:00
parent 999dea9eb0
commit 27039c70a3
2 changed files with 152 additions and 70 deletions

View File

@@ -1203,8 +1203,22 @@ class TradingTransformerTrainer:
price_loss = self.price_criterion(price_pred, price_target)
# NEW: Trend analysis loss (if trend_target provided)
trend_loss = torch.tensor(0.0, device=self.device)
if 'trend_target' in batch and 'trend_analysis' in outputs:
trend_pred = torch.cat([
outputs['trend_analysis']['angle_radians'],
outputs['trend_analysis']['steepness'],
outputs['trend_analysis']['direction']
], dim=1) # [batch, 3]
trend_target = batch['trend_target']
if trend_pred.shape == trend_target.shape:
trend_loss = self.price_criterion(trend_pred, trend_target)
logger.debug(f"Trend loss: {trend_loss.item():.6f} (pred={trend_pred[0].tolist()}, target={trend_target[0].tolist()})")
# Start with base losses - avoid inplace operations on computation graph
total_loss = action_loss + 0.1 * price_loss # Weight auxiliary task
total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss # Weight auxiliary tasks
# CRITICAL FIX: Scale loss for gradient accumulation
# This prevents gradient explosion when accumulating over multiple batches
@@ -1308,6 +1322,7 @@ class TradingTransformerTrainer:
'total_loss': total_loss.item(),
'action_loss': action_loss.item(),
'price_loss': price_loss.item(),
'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0, # NEW
'accuracy': accuracy.item(),
'candle_accuracy': candle_accuracy,
'learning_rate': self.scheduler.get_last_lr()[0]