fix training - loss calculation;

added memory guard
This commit is contained in:
Dobromir Popov
2025-11-13 15:58:42 +02:00
parent bf2a6cf96e
commit 13b6fafaf8
3 changed files with 357 additions and 20 deletions

View File

@@ -1407,23 +1407,63 @@ class TradingTransformerTrainer:
self.scheduler.step()
# Calculate accuracy without gradients
# PRIMARY: Next candle OHLCV prediction accuracy (realistic values)
with torch.no_grad():
predictions = torch.argmax(outputs['action_logits'], dim=-1)
accuracy = (predictions == batch['actions']).float().mean()
# Calculate candle prediction accuracy (price direction)
candle_accuracy = 0.0
if 'next_candles' in outputs and 'future_prices' in batch:
# Use 1m timeframe prediction as primary
if '1m' in outputs['next_candles']:
predicted_candle = outputs['next_candles']['1m'] # [batch, 5]
# Predicted close is the 4th value (index 3)
predicted_close_change = predicted_candle[:, 3] # Predicted close price change
actual_close_change = batch['future_prices'] # Actual price change ratio
candle_rmse = {}
if 'next_candles' in outputs:
# Use 1m timeframe as primary metric
if '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
pred_candle = outputs['next_candles']['1m'] # [batch, 5]
actual_candle = batch['future_candle_1m'] # [batch, 5]
# Check if direction matches (both positive or both negative)
direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
candle_accuracy = direction_match.mean().item()
if actual_candle is not None and pred_candle.shape == actual_candle.shape:
# Calculate RMSE for each OHLCV component
rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
# Average RMSE for OHLC (exclude volume)
avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
# Convert to accuracy: lower RMSE = higher accuracy
# Normalize by price range
price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
candle_rmse = {
'open': rmse_open.item(),
'high': rmse_high.item(),
'low': rmse_low.item(),
'close': rmse_close.item(),
'avg': avg_rmse.item()
}
# SECONDARY: Trend vector prediction accuracy
trend_accuracy = 0.0
if 'trend_analysis' in outputs and 'trend_target' in batch:
pred_angle = outputs['trend_analysis']['angle_radians']
pred_steepness = outputs['trend_analysis']['steepness']
actual_angle = batch['trend_target'][:, 0:1]
actual_steepness = batch['trend_target'][:, 1:2]
# Angle error (degrees)
angle_error_rad = torch.abs(pred_angle - actual_angle)
angle_error_deg = angle_error_rad * 180.0 / 3.14159
angle_accuracy = (1.0 - torch.clamp(angle_error_deg / 180.0, 0, 1)).mean()
# Steepness error (percentage)
steepness_error = torch.abs(pred_steepness - actual_steepness) / (actual_steepness + 1e-8)
steepness_accuracy = (1.0 - torch.clamp(steepness_error, 0, 1)).mean()
trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()
# LEGACY: Action accuracy (for comparison)
action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
action_accuracy = (action_predictions == batch['actions']).float().mean().item()
# Extract values and delete tensors to free memory
result = {
@@ -1433,14 +1473,20 @@ class TradingTransformerTrainer:
'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,
'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0,
'candle_loss_denorm': candle_losses_denorm, # Dict of denormalized losses per timeframe
'accuracy': accuracy.item(),
'candle_accuracy': candle_accuracy,
# NEW: Realistic accuracy metrics based on next candle prediction
'accuracy': candle_accuracy, # PRIMARY: Next candle prediction accuracy
'candle_accuracy': candle_accuracy, # Same as accuracy
'candle_rmse': candle_rmse, # Detailed RMSE per OHLC component
'trend_accuracy': trend_accuracy, # Trend vector accuracy (angle + steepness)
'action_accuracy': action_accuracy, # Legacy action accuracy
'learning_rate': self.scheduler.get_last_lr()[0]
}
# CRITICAL: Delete large tensors to free memory immediately
# This prevents memory accumulation across batches
del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, predictions, accuracy
del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, action_predictions
if torch.cuda.is_available():
torch.cuda.empty_cache()