Fix training loss calculation;
add memory guard
This commit is contained in:
@@ -1407,23 +1407,63 @@ class TradingTransformerTrainer:
|
||||
self.scheduler.step()
|
||||
|
||||
# Calculate accuracy without gradients
|
||||
# PRIMARY: Next candle OHLCV prediction accuracy (realistic values)
|
||||
with torch.no_grad():
|
||||
predictions = torch.argmax(outputs['action_logits'], dim=-1)
|
||||
accuracy = (predictions == batch['actions']).float().mean()
|
||||
|
||||
# Calculate candle prediction accuracy (price direction)
|
||||
candle_accuracy = 0.0
|
||||
if 'next_candles' in outputs and 'future_prices' in batch:
|
||||
# Use 1m timeframe prediction as primary
|
||||
if '1m' in outputs['next_candles']:
|
||||
predicted_candle = outputs['next_candles']['1m'] # [batch, 5]
|
||||
# Predicted close is the 4th value (index 3)
|
||||
predicted_close_change = predicted_candle[:, 3] # Predicted close price change
|
||||
actual_close_change = batch['future_prices'] # Actual price change ratio
|
||||
candle_rmse = {}
|
||||
|
||||
if 'next_candles' in outputs:
|
||||
# Use 1m timeframe as primary metric
|
||||
if '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
|
||||
pred_candle = outputs['next_candles']['1m'] # [batch, 5]
|
||||
actual_candle = batch['future_candle_1m'] # [batch, 5]
|
||||
|
||||
# Check if direction matches (both positive or both negative)
|
||||
direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
|
||||
candle_accuracy = direction_match.mean().item()
|
||||
if actual_candle is not None and pred_candle.shape == actual_candle.shape:
|
||||
# Calculate RMSE for each OHLCV component
|
||||
rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
|
||||
rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
|
||||
rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
|
||||
rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
|
||||
|
||||
# Average RMSE for OHLC (exclude volume)
|
||||
avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
|
||||
|
||||
# Convert to accuracy: lower RMSE = higher accuracy
|
||||
# Normalize by price range
|
||||
price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
|
||||
candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
|
||||
|
||||
candle_rmse = {
|
||||
'open': rmse_open.item(),
|
||||
'high': rmse_high.item(),
|
||||
'low': rmse_low.item(),
|
||||
'close': rmse_close.item(),
|
||||
'avg': avg_rmse.item()
|
||||
}
|
||||
|
||||
# SECONDARY: Trend vector prediction accuracy
|
||||
trend_accuracy = 0.0
|
||||
if 'trend_analysis' in outputs and 'trend_target' in batch:
|
||||
pred_angle = outputs['trend_analysis']['angle_radians']
|
||||
pred_steepness = outputs['trend_analysis']['steepness']
|
||||
|
||||
actual_angle = batch['trend_target'][:, 0:1]
|
||||
actual_steepness = batch['trend_target'][:, 1:2]
|
||||
|
||||
# Angle error (degrees)
|
||||
angle_error_rad = torch.abs(pred_angle - actual_angle)
|
||||
angle_error_deg = angle_error_rad * 180.0 / 3.14159
|
||||
angle_accuracy = (1.0 - torch.clamp(angle_error_deg / 180.0, 0, 1)).mean()
|
||||
|
||||
# Steepness error (percentage)
|
||||
steepness_error = torch.abs(pred_steepness - actual_steepness) / (actual_steepness + 1e-8)
|
||||
steepness_accuracy = (1.0 - torch.clamp(steepness_error, 0, 1)).mean()
|
||||
|
||||
trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()
|
||||
|
||||
# LEGACY: Action accuracy (for comparison)
|
||||
action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
|
||||
action_accuracy = (action_predictions == batch['actions']).float().mean().item()
|
||||
|
||||
# Extract values and delete tensors to free memory
|
||||
result = {
|
||||
@@ -1433,14 +1473,20 @@ class TradingTransformerTrainer:
|
||||
'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,
|
||||
'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0,
|
||||
'candle_loss_denorm': candle_losses_denorm, # Dict of denormalized losses per timeframe
|
||||
'accuracy': accuracy.item(),
|
||||
'candle_accuracy': candle_accuracy,
|
||||
|
||||
# NEW: Realistic accuracy metrics based on next candle prediction
|
||||
'accuracy': candle_accuracy, # PRIMARY: Next candle prediction accuracy
|
||||
'candle_accuracy': candle_accuracy, # Same as accuracy
|
||||
'candle_rmse': candle_rmse, # Detailed RMSE per OHLC component
|
||||
'trend_accuracy': trend_accuracy, # Trend vector accuracy (angle + steepness)
|
||||
'action_accuracy': action_accuracy, # Legacy action accuracy
|
||||
|
||||
'learning_rate': self.scheduler.get_last_lr()[0]
|
||||
}
|
||||
|
||||
# CRITICAL: Delete large tensors to free memory immediately
|
||||
# This prevents memory accumulation across batches
|
||||
del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, predictions, accuracy
|
||||
del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, action_predictions
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user