reduce logging. actual training

This commit is contained in:
Dobromir Popov
2025-10-31 03:52:41 +02:00
parent 6ac324289c
commit 1bf41e06a8
9 changed files with 1700 additions and 74 deletions

View File

@@ -947,73 +947,92 @@ class TradingTransformerTrainer:
def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
    """Run a single training step on ``batch`` and return scalar metrics.

    The step zeroes gradients, runs the model forward, combines the losses,
    back-propagates, clips gradients to ``self.config.max_grad_norm`` and
    then steps both the optimizer and the LR scheduler.

    Args:
        batch: Mapping that must contain ``price_data``, ``cob_data``,
            ``tech_data``, ``market_data`` (model inputs), ``actions`` and
            ``future_prices`` (targets). ``trade_success`` is optional and
            is only used when the model also returns ``confidence``.
            Tensor values are detached, cloned and moved to ``self.device``;
            non-tensor values are passed through unchanged.

    Returns:
        Dict with ``total_loss``, ``action_loss``, ``price_loss``,
        ``accuracy`` and ``learning_rate`` as Python floats. If anything in
        the step raises, the error is logged and the same keys are returned
        with the loss/accuracy entries set to ``0.0`` so the surrounding
        training loop keeps running.
    """

    def _as_column(tensor):
        # 0-D -> [1, 1]; 1-D [N] -> [N, 1]; higher ranks are left untouched.
        if tensor.dim() == 0:
            return tensor.reshape(1, 1)
        if tensor.dim() == 1:
            return tensor.unsqueeze(-1)
        return tensor

    try:
        self.model.train()
        self.optimizer.zero_grad()

        # Detach + clone before the device move so every step works on its
        # own copy of the data and cannot trip in-place/autograd errors on
        # tensors shared with the caller. Non-tensor entries (if any) are
        # kept as-is instead of failing the whole step.
        batch = {
            key: (
                value.detach().clone().to(self.device)
                if isinstance(value, torch.Tensor)
                else value
            )
            for key, value in batch.items()
        }

        # Forward pass
        outputs = self.model(
            batch['price_data'],
            batch['cob_data'],
            batch['tech_data'],
            batch['market_data']
        )

        # Primary loss plus the down-weighted auxiliary price loss.
        action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
        price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])
        # Built with out-of-place additions so the autograd graph is never
        # modified in place.
        total_loss = action_loss + 0.1 * price_loss

        # Optional confidence loss, only when both prediction and target exist.
        if 'confidence' in outputs and 'trade_success' in batch:
            confidence_pred = _as_column(outputs['confidence'])
            trade_target = _as_column(batch['trade_success'].float())

            # The original comments state the criterion (BCELoss) needs an
            # exact shape match, so the target is aligned to the prediction.
            # reshape() is used rather than view() because view() raises on
            # tensors whose strides are not view-compatible.
            if confidence_pred.shape != trade_target.shape:
                trade_target = trade_target.reshape(confidence_pred.shape)

            confidence_loss = self.confidence_criterion(confidence_pred, trade_target)
            total_loss = total_loss + 0.1 * confidence_loss

        # Backward pass
        total_loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

        # Optimizer step, then advance the LR schedule
        self.optimizer.step()
        self.scheduler.step()

        # Fraction of samples whose arg-max action matches the label.
        predictions = torch.argmax(outputs['action_logits'], dim=-1)
        accuracy = (predictions == batch['actions']).float().mean()

        return {
            'total_loss': total_loss.item(),
            'action_loss': action_loss.item(),
            'price_loss': price_loss.item(),
            'accuracy': accuracy.item(),
            'learning_rate': self.scheduler.get_last_lr()[0]
        }

    except Exception as e:
        # Deliberate best-effort boundary: log with traceback and hand back
        # zeroed metrics so one bad batch does not abort training.
        # NOTE(review): a 0.0 loss is indistinguishable from a perfect step
        # in aggregated metrics - callers may want to filter these out.
        logger.error(f"Error in train_step: {e}", exc_info=True)
        return {
            'total_loss': 0.0,
            'action_loss': 0.0,
            'price_loss': 0.0,
            'accuracy': 0.0,
            'learning_rate': self.scheduler.get_last_lr()[0] if hasattr(self, 'scheduler') else 0.0
        }
def validate(self, val_loader: DataLoader) -> Dict[str, float]:
"""Validation step"""