fx T checkpoint and model loss measure

This commit is contained in:
Dobromir Popov
2025-11-10 12:41:39 +02:00
parent 86ae8b499b
commit a2d34c6d7c
2 changed files with 138 additions and 1 deletions

View File

@@ -1267,6 +1267,20 @@ class TradingTransformerTrainer:
with torch.no_grad():
predictions = torch.argmax(outputs['action_logits'], dim=-1)
accuracy = (predictions == batch['actions']).float().mean()
# Calculate candle prediction accuracy (price direction)
candle_accuracy = 0.0
if 'next_candles' in outputs and 'future_prices' in batch:
# Use 1m timeframe prediction as primary
if '1m' in outputs['next_candles']:
predicted_candle = outputs['next_candles']['1m'] # [batch, 5]
# Predicted close is the 4th value (index 3)
predicted_close_change = predicted_candle[:, 3] # Predicted close price change
actual_close_change = batch['future_prices'] # Actual price change ratio
# Check if direction matches (both positive or both negative)
direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
candle_accuracy = direction_match.mean().item()
# Extract values and delete tensors to free memory
result = {
@@ -1274,6 +1288,7 @@ class TradingTransformerTrainer:
'action_loss': action_loss.item(),
'price_loss': price_loss.item(),
'accuracy': accuracy.item(),
'candle_accuracy': candle_accuracy,
'learning_rate': self.scheduler.get_last_lr()[0]
}