This commit is contained in:
Dobromir Popov
2025-12-08 19:28:27 +02:00
parent 8263534b74
commit cf4808aa47
4 changed files with 226 additions and 82 deletions

View File

@@ -1407,18 +1407,37 @@ class TradingTransformerTrainer:
)
# Calculate losses (use batch_on_device for consistency)
# Handle case where actions key is missing (e.g., when no timeframe data available)
if 'actions' not in batch_on_device:
logger.warning("No 'actions' key in batch - skipping this training step")
return {
'total_loss': 0.0,
'action_loss': 0.0,
'price_loss': 0.0,
'accuracy': 0.0,
'candle_accuracy': 0.0,
'trend_accuracy': 0.0,
'action_accuracy': 0.0
}
action_loss = self.action_criterion(outputs['action_logits'], batch_on_device['actions'])
# FIXED: Ensure shapes match for MSELoss
price_pred = outputs['price_prediction']
price_target = batch_on_device['future_prices']
# Both should be [batch, 1], but ensure they match
if price_pred.shape != price_target.shape:
logger.debug(f"Reshaping price target from {price_target.shape} to {price_pred.shape}")
price_target = price_target.view(price_pred.shape)
price_loss = self.price_criterion(price_pred, price_target)
# Handle case where future_prices key is missing
if 'future_prices' not in batch_on_device:
logger.warning("No 'future_prices' key in batch - using zero loss for price prediction")
price_loss = torch.tensor(0.0, device=self.device)
else:
price_target = batch_on_device['future_prices']
# Both should be [batch, 1], but ensure they match
if price_pred.shape != price_target.shape:
logger.debug(f"Reshaping price target from {price_target.shape} to {price_pred.shape}")
price_target = price_target.view(price_pred.shape)
price_loss = self.price_criterion(price_pred, price_target)
# NEW: Trend analysis loss (if trend_target provided)
trend_loss = torch.tensor(0.0, device=self.device)
@@ -1677,8 +1696,11 @@ class TradingTransformerTrainer:
trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()
# LEGACY: Action accuracy (for comparison)
action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
action_accuracy = (action_predictions == batch_on_device['actions']).float().mean().item()
if 'actions' in batch_on_device:
action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
action_accuracy = (action_predictions == batch_on_device['actions']).float().mean().item()
else:
action_accuracy = 0.0
# Extract values and delete tensors to free memory
result = {