reduce logging. actual training
@@ -947,73 +947,92 @@ class TradingTransformerTrainer:
    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Single training step"""
        try:
            self.model.train()
            self.optimizer.zero_grad()

            # Clone and detach batch tensors before moving to device to avoid in-place operation issues
            # This ensures each batch is independent and prevents gradient computation errors
            batch = {k: v.detach().clone().to(self.device) for k, v in batch.items()}

            # Forward pass
            outputs = self.model(
                batch['price_data'],
                batch['cob_data'],
                batch['tech_data'],
                batch['market_data']
            )

            # Calculate losses
            action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
            price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])

            # Start with base losses - avoid in-place operations on the computation graph
            total_loss = action_loss + 0.1 * price_loss  # Weight auxiliary task

            # Add confidence loss if available
            if 'confidence' in outputs and 'trade_success' in batch:
                # Both tensors should have shape [batch_size, 1] for BCELoss
                confidence_pred = outputs['confidence']
                trade_target = batch['trade_success'].float()

                # Ensure both are 2D tensors [batch_size, 1]
                # Handle different input shapes robustly
                if confidence_pred.dim() == 0:
                    # Scalar -> [1, 1]
                    confidence_pred = confidence_pred.unsqueeze(0).unsqueeze(0)
                elif confidence_pred.dim() == 1:
                    # [batch_size] -> [batch_size, 1]
                    confidence_pred = confidence_pred.unsqueeze(-1)

                if trade_target.dim() == 0:
                    # Scalar -> [1, 1]
                    trade_target = trade_target.unsqueeze(0).unsqueeze(0)
                elif trade_target.dim() == 1:
                    # [batch_size] -> [batch_size, 1]
                    trade_target = trade_target.unsqueeze(-1)

                # Ensure shapes match exactly - BCELoss requires an exact match
                if confidence_pred.shape != trade_target.shape:
                    # Reshape trade_target to match confidence_pred
                    trade_target = trade_target.view(confidence_pred.shape)

                confidence_loss = self.confidence_criterion(confidence_pred, trade_target)
                # Use addition instead of += to avoid an in-place operation
                total_loss = total_loss + 0.1 * confidence_loss

            # Backward pass
            total_loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

            # Optimizer step
            self.optimizer.step()
            self.scheduler.step()

            # Calculate accuracy
            predictions = torch.argmax(outputs['action_logits'], dim=-1)
            accuracy = (predictions == batch['actions']).float().mean()

            return {
                'total_loss': total_loss.item(),
                'action_loss': action_loss.item(),
                'price_loss': price_loss.item(),
                'accuracy': accuracy.item(),
                'learning_rate': self.scheduler.get_last_lr()[0]
            }
        except Exception as e:
            logger.error(f"Error in train_step: {e}", exc_info=True)
            # Return a zero-loss dict to prevent training from crashing,
            # but log the error so we can debug
            return {
                'total_loss': 0.0,
                'action_loss': 0.0,
                'price_loss': 0.0,
                'accuracy': 0.0,
                'learning_rate': self.scheduler.get_last_lr()[0] if hasattr(self, 'scheduler') else 0.0
            }
    def validate(self, val_loader: DataLoader) -> Dict[str, float]:
        """Validation step"""
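A note on the shape handling in the confidence branch: `nn.BCELoss` (through `F.binary_cross_entropy`) raises a `ValueError` in current PyTorch versions when the prediction and target sizes differ, so the target has to be normalized to the prediction's `[batch_size, 1]` shape. A minimal standalone sketch of that requirement, using hypothetical stand-in tensors rather than the trainer's real outputs:

import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical stand-ins: a confidence head emits probabilities [batch_size, 1],
# while the success labels arrive as a flat [batch_size] vector
confidence_pred = torch.sigmoid(torch.randn(4, 1))   # [4, 1]
trade_target = torch.tensor([1.0, 0.0, 1.0, 1.0])    # [4]

# criterion(confidence_pred, trade_target) would raise a ValueError here,
# because binary_cross_entropy rejects mismatched input/target sizes
if trade_target.dim() == 1:
    trade_target = trade_target.unsqueeze(-1)        # [4] -> [4, 1]

loss = criterion(confidence_pred, trade_target)      # shapes now match exactly
print(loss.item())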
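The `total_loss = total_loss + 0.1 * confidence_loss` form (instead of `+=`) avoids mutating a tensor in place: autograd raises a `RuntimeError` at backward time when an in-place op overwrites a value the backward pass still needs. A toy illustration of that failure mode and the out-of-place fix, unrelated to the trainer's own tensors:

import torch

x = torch.tensor([2.0], requires_grad=True)
y = torch.exp(x)   # exp saves its output for the backward pass
y += 1.0           # in-place add_ overwrites the saved output

try:
    y.sum().backward()
except RuntimeError as err:
    print(err)     # "...has been modified by an inplace operation..."

# Out-of-place form leaves the saved output intact
x = torch.tensor([2.0], requires_grad=True)
y = torch.exp(x)
y = y + 1.0        # builds a new tensor instead of mutating y
y.sum().backward()
print(x.grad)      # tensor([7.3891]) == exp(2)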
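For reference, the update order used in `train_step` above (zero gradients, backward, clip the global gradient norm, optimizer step, then scheduler step) as a self-contained toy loop; the linear model, `AdamW`, `StepLR`, and the `max_norm=1.0` cap below are placeholder choices, not the trainer's actual configuration:

import torch
from torch import nn

model = nn.Linear(8, 3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(4, 8)
targets = torch.randint(0, 3, (4,))

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()

# Clip the global gradient norm before the parameter update
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()
scheduler.step()   # PyTorch expects scheduler.step() after optimizer.step()
print(scheduler.get_last_lr()[0])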