current PnL in models
@@ -479,7 +479,8 @@ class AdvancedTradingTransformer(nn.Module):
 
     def forward(self, price_data: torch.Tensor, cob_data: torch.Tensor,
                 tech_data: torch.Tensor, market_data: torch.Tensor,
-                mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+                mask: Optional[torch.Tensor] = None,
+                position_state: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
         """
         Forward pass of the trading transformer
 
@@ -489,6 +490,7 @@ class AdvancedTradingTransformer(nn.Module):
             tech_data: (batch, seq_len, tech_features) - Technical indicators
             market_data: (batch, seq_len, market_features) - Market microstructure
             mask: Optional attention mask
+            position_state: (batch, 5) - Position state [has_position, pnl, size, entry_price, time_in_position]
 
         Returns:
             Dictionary containing model outputs
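For reference, a minimal sketch of how a caller could assemble the new position_state tensor in the documented [has_position, pnl, size, entry_price, time_in_position] layout. The helper name and the example values are illustrative assumptions, not part of this commit.

import torch

def build_position_state(has_position: bool, pnl: float, size: float,
                         entry_price: float, time_in_position: float) -> torch.Tensor:
    # One row per sample, matching the documented (batch, 5) layout:
    # [has_position, pnl, size, entry_price, time_in_position]
    return torch.tensor([[float(has_position), pnl, size, entry_price, time_in_position]],
                        dtype=torch.float32)

# Example: long 0.5 units entered at 100.0, unrealized PnL +2.5, held for 300 seconds
position_state = build_position_state(True, 2.5, 0.5, 100.0, 300.0)  # shape (1, 5)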
@@ -512,6 +514,22 @@ class AdvancedTradingTransformer(nn.Module):
         # Combine embeddings (could also use cross-attention)
         x = price_emb + cob_emb + tech_emb + market_emb
 
+        # Add position state if provided - critical for loss minimization and profit taking
+        if position_state is not None:
+            # Project position state to model dimension and add to all sequence positions
+            # This allows the model to condition all predictions on current position state
+            position_emb = torch.tanh(position_state)  # Normalize to [-1, 1]
+            position_emb = position_emb.unsqueeze(1).expand(batch_size, seq_len, -1)  # (batch, seq_len, 5)
+
+            # Pad to match model dimension if needed
+            if position_emb.size(-1) < self.config.d_model:
+                padding = torch.zeros(batch_size, seq_len, self.config.d_model - position_emb.size(-1),
+                                      device=position_emb.device, dtype=position_emb.dtype)
+                position_emb = torch.cat([position_emb, padding], dim=-1)
+
+            # Add position state as a bias to the embeddings
+            x = x + position_emb[:, :, :self.config.d_model]
+
         # Add positional encoding
         if isinstance(self.pos_encoding, RelativePositionalEncoding):
             # Relative position encoding is applied in attention
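A standalone sketch of the shape handling in the block above, assuming d_model > 5: the 5-value state is squashed with tanh, broadcast across the sequence, zero-padded up to d_model, and added to the combined embeddings as a bias. The concrete sizes below are placeholders; raw PnL or entry-price magnitudes would saturate tanh, so callers would likely pre-normalize them (an assumption, not shown in this commit).

import torch

batch_size, seq_len, d_model = 2, 50, 256                  # placeholder sizes
x = torch.randn(batch_size, seq_len, d_model)              # combined embeddings
position_state = torch.tensor([[1.0, 0.02, 0.5, 0.9, 0.3],
                               [0.0, 0.00, 0.0, 0.0, 0.0]])  # pre-normalized (batch, 5)

position_emb = torch.tanh(position_state)                                  # squash to [-1, 1]
position_emb = position_emb.unsqueeze(1).expand(batch_size, seq_len, -1)   # (batch, seq_len, 5)
padding = torch.zeros(batch_size, seq_len, d_model - position_emb.size(-1))
position_emb = torch.cat([position_emb, padding], dim=-1)                  # (batch, seq_len, d_model)
x = x + position_emb                                                       # position-conditioned embeddings
assert x.shape == (batch_size, seq_len, d_model)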
@@ -951,16 +969,18 @@ class TradingTransformerTrainer:
         self.model.train()
         self.optimizer.zero_grad()
 
-        # Clone and detach batch tensors before moving to device to avoid in-place operation issues
-        # This ensures each batch is independent and prevents gradient computation errors
-        batch = {k: v.detach().clone().to(self.device) for k, v in batch.items()}
+        # Move batch to device WITHOUT cloning to avoid version tracking issues
+        # The detach().clone() was causing gradient computation errors
+        batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+                 for k, v in batch.items()}
 
-        # Forward pass
+        # Forward pass with position state for loss minimization
         outputs = self.model(
             batch['price_data'],
             batch['cob_data'],
             batch['tech_data'],
-            batch['market_data']
+            batch['market_data'],
+            position_state=batch.get('position_state', None)  # Pass position state if available
         )
 
         # Calculate losses
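The training step now reads an optional 'position_state' entry from the batch dict; below is a sketch of what such a batch might look like. The shapes and feature counts are placeholders, and loss targets are omitted because they are not visible in this diff.

import torch

batch = {
    'price_data':  torch.randn(4, 50, 5),    # (batch, seq_len, price_features) - placeholder shapes
    'cob_data':    torch.randn(4, 50, 20),
    'tech_data':   torch.randn(4, 50, 10),
    'market_data': torch.randn(4, 50, 8),
    # Optional: if this key is absent, batch.get('position_state', None) keeps the old behaviour
    'position_state': torch.zeros(4, 5),     # [has_position, pnl, size, entry_price, time_in_position]
}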
@@ -1002,7 +1022,21 @@ class TradingTransformerTrainer:
             total_loss = total_loss + 0.1 * confidence_loss
 
         # Backward pass
-        total_loss.backward()
+        try:
+            total_loss.backward()
+        except RuntimeError as e:
+            if "inplace operation" in str(e):
+                logger.error(f"Inplace operation error during backward pass: {e}")
+                # Return zero loss to continue training
+                return {
+                    'total_loss': 0.0,
+                    'action_loss': 0.0,
+                    'price_loss': 0.0,
+                    'accuracy': 0.0,
+                    'learning_rate': self.scheduler.get_last_lr()[0]
+                }
+            else:
+                raise
 
         # Gradient clipping
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
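The guarded backward above logs in-place-operation RuntimeErrors and returns zero-valued metrics so training can continue, while other RuntimeErrors still propagate. A minimal sketch of that pattern in isolation; the helper name and the skip counter are assumptions, not part of this commit.

import logging
import torch

logger = logging.getLogger(__name__)

def safe_backward(loss: torch.Tensor) -> bool:
    """Run backward(); return False when an in-place-operation error forces a skip."""
    try:
        loss.backward()
        return True
    except RuntimeError as e:
        if "inplace operation" in str(e):
            logger.error(f"Inplace operation error during backward pass: {e}")
            return False
        raise

# Usage sketch inside a training loop:
# if not safe_backward(total_loss):
#     skipped_steps += 1   # surface skipped batches instead of silently reporting zero loss
#     continue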