fx T checkpoint and model loss measure
This commit is contained in:
@@ -1267,6 +1267,20 @@ class TradingTransformerTrainer:
|
||||
with torch.no_grad():
|
||||
predictions = torch.argmax(outputs['action_logits'], dim=-1)
|
||||
accuracy = (predictions == batch['actions']).float().mean()
|
||||
|
||||
# Calculate candle prediction accuracy (price direction)
|
||||
candle_accuracy = 0.0
|
||||
if 'next_candles' in outputs and 'future_prices' in batch:
|
||||
# Use 1m timeframe prediction as primary
|
||||
if '1m' in outputs['next_candles']:
|
||||
predicted_candle = outputs['next_candles']['1m'] # [batch, 5]
|
||||
# Predicted close is the 4th value (index 3)
|
||||
predicted_close_change = predicted_candle[:, 3] # Predicted close price change
|
||||
actual_close_change = batch['future_prices'] # Actual price change ratio
|
||||
|
||||
# Check if direction matches (both positive or both negative)
|
||||
direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
|
||||
candle_accuracy = direction_match.mean().item()
|
||||
|
||||
# Extract values and delete tensors to free memory
|
||||
result = {
|
||||
@@ -1274,6 +1288,7 @@ class TradingTransformerTrainer:
|
||||
'action_loss': action_loss.item(),
|
||||
'price_loss': price_loss.item(),
|
||||
'accuracy': accuracy.item(),
|
||||
'candle_accuracy': candle_accuracy,
|
||||
'learning_rate': self.scheduler.get_last_lr()[0]
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user