IMPLEMENTED: WIP; realtime candle predictions training

This commit is contained in:
Dobromir Popov
2025-11-22 17:57:58 +02:00
parent 423132dc8f
commit 26cbfd771b
4 changed files with 672 additions and 74 deletions

View File

@@ -1446,33 +1446,39 @@ class TradingTransformerTrainer:
candle_rmse = {}
if 'next_candles' in outputs:
# Use 1m timeframe as primary metric
if '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
# Use 1s or 1m timeframe as primary metric (try 1s first)
if '1s' in outputs['next_candles'] and 'future_candle_1s' in batch:
pred_candle = outputs['next_candles']['1s'] # [batch, 5]
actual_candle = batch['future_candle_1s'] # [batch, 5]
elif '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
pred_candle = outputs['next_candles']['1m'] # [batch, 5]
actual_candle = batch['future_candle_1m'] # [batch, 5]
else:
pred_candle = None
actual_candle = None
if actual_candle is not None and pred_candle is not None and pred_candle.shape == actual_candle.shape:
# Calculate RMSE for each OHLCV component
rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
if actual_candle is not None and pred_candle.shape == actual_candle.shape:
# Calculate RMSE for each OHLCV component
rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
# Average RMSE for OHLC (exclude volume)
avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
# Convert to accuracy: lower RMSE = higher accuracy
# Normalize by price range
price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
candle_rmse = {
'open': rmse_open.item(),
'high': rmse_high.item(),
'low': rmse_low.item(),
'close': rmse_close.item(),
'avg': avg_rmse.item()
}
# Average RMSE for OHLC (exclude volume)
avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
# Convert to accuracy: lower RMSE = higher accuracy
# Normalize by price range
price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
candle_rmse = {
'open': rmse_open.item(),
'high': rmse_high.item(),
'low': rmse_low.item(),
'close': rmse_close.item(),
'avg': avg_rmse.item()
}
# SECONDARY: Trend vector prediction accuracy
trend_accuracy = 0.0