fixes
This commit is contained in:
@@ -1407,18 +1407,37 @@ class TradingTransformerTrainer:
|
||||
)
|
||||
|
||||
# Calculate losses (use batch_on_device for consistency)
|
||||
# Handle case where actions key is missing (e.g., when no timeframe data available)
|
||||
if 'actions' not in batch_on_device:
|
||||
logger.warning("No 'actions' key in batch - skipping this training step")
|
||||
return {
|
||||
'total_loss': 0.0,
|
||||
'action_loss': 0.0,
|
||||
'price_loss': 0.0,
|
||||
'accuracy': 0.0,
|
||||
'candle_accuracy': 0.0,
|
||||
'trend_accuracy': 0.0,
|
||||
'action_accuracy': 0.0
|
||||
}
|
||||
|
||||
action_loss = self.action_criterion(outputs['action_logits'], batch_on_device['actions'])
|
||||
|
||||
# FIXED: Ensure shapes match for MSELoss
|
||||
price_pred = outputs['price_prediction']
|
||||
price_target = batch_on_device['future_prices']
|
||||
|
||||
# Both should be [batch, 1], but ensure they match
|
||||
if price_pred.shape != price_target.shape:
|
||||
logger.debug(f"Reshaping price target from {price_target.shape} to {price_pred.shape}")
|
||||
price_target = price_target.view(price_pred.shape)
|
||||
|
||||
price_loss = self.price_criterion(price_pred, price_target)
|
||||
# Handle case where future_prices key is missing
|
||||
if 'future_prices' not in batch_on_device:
|
||||
logger.warning("No 'future_prices' key in batch - using zero loss for price prediction")
|
||||
price_loss = torch.tensor(0.0, device=self.device)
|
||||
else:
|
||||
price_target = batch_on_device['future_prices']
|
||||
|
||||
# Both should be [batch, 1], but ensure they match
|
||||
if price_pred.shape != price_target.shape:
|
||||
logger.debug(f"Reshaping price target from {price_target.shape} to {price_pred.shape}")
|
||||
price_target = price_target.view(price_pred.shape)
|
||||
|
||||
price_loss = self.price_criterion(price_pred, price_target)
|
||||
|
||||
# NEW: Trend analysis loss (if trend_target provided)
|
||||
trend_loss = torch.tensor(0.0, device=self.device)
|
||||
@@ -1677,8 +1696,11 @@ class TradingTransformerTrainer:
|
||||
trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()
|
||||
|
||||
# LEGACY: Action accuracy (for comparison)
|
||||
action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
|
||||
action_accuracy = (action_predictions == batch_on_device['actions']).float().mean().item()
|
||||
if 'actions' in batch_on_device:
|
||||
action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
|
||||
action_accuracy = (action_predictions == batch_on_device['actions']).float().mean().item()
|
||||
else:
|
||||
action_accuracy = 0.0
|
||||
|
||||
# Extract values and delete tensors to free memory
|
||||
result = {
|
||||
|
||||
Reference in New Issue
Block a user