improved model loading and training
This commit is contained in:
@@ -625,7 +625,7 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
'calculated_angle': trend_angles.unsqueeze(-1), # (batch, 1)
|
||||
'calculated_steepness': trend_steepness.unsqueeze(-1), # (batch, 1)
|
||||
'calculated_direction': trend_direction.unsqueeze(-1), # (batch, 1)
|
||||
'vector': torch.stack([price_deltas.squeeze(), time_deltas.squeeze()], dim=1) # (batch, 2) - [price_delta, time_delta]
|
||||
'vector': torch.stack([price_deltas.squeeze(), time_deltas.squeeze()], dim=0).unsqueeze(0) if batch_size == 1 else torch.stack([price_deltas.squeeze(), time_deltas.squeeze()], dim=1) # (batch, 2) - [price_delta, time_delta]
|
||||
}
|
||||
else:
|
||||
outputs['trend_vector'] = {
|
||||
@@ -663,8 +663,13 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
|
||||
# Calculate action probabilities based on trend
|
||||
for i in range(batch_size):
|
||||
angle = trend_angle[i].item() if batch_size > 0 else 0.0
|
||||
steep = trend_steepness_val[i].item() if batch_size > 0 else 0.0
|
||||
# Handle both 0-dim and 1-dim tensors
|
||||
if trend_angle.dim() == 0:
|
||||
angle = trend_angle.item()
|
||||
steep = trend_steepness_val.item()
|
||||
else:
|
||||
angle = trend_angle[i].item()
|
||||
steep = trend_steepness_val[i].item()
|
||||
|
||||
# Normalize steepness to [0, 1] range (assuming max steepness of 10 units)
|
||||
normalized_steepness = min(steep / 10.0, 1.0) if steep > 0 else 0.0
|
||||
@@ -964,10 +969,19 @@ class TradingTransformerTrainer:
|
||||
|
||||
# Add confidence loss if available
|
||||
if 'confidence' in outputs and 'trade_success' in batch:
|
||||
confidence_loss = self.confidence_criterion(
|
||||
outputs['confidence'].squeeze(),
|
||||
batch['trade_success'].float()
|
||||
)
|
||||
# Ensure both tensors have compatible shapes
|
||||
# confidence: [batch_size, 1] -> squeeze last dim to [batch_size]
|
||||
# trade_success: [batch_size] - ensure same shape
|
||||
confidence_pred = outputs['confidence'].squeeze(-1) # Only remove last dimension
|
||||
trade_target = batch['trade_success'].float()
|
||||
|
||||
# Ensure shapes match (handle edge case where batch_size=1)
|
||||
if confidence_pred.dim() == 0: # scalar case
|
||||
confidence_pred = confidence_pred.unsqueeze(0)
|
||||
if trade_target.dim() == 0: # scalar case
|
||||
trade_target = trade_target.unsqueeze(0)
|
||||
|
||||
confidence_loss = self.confidence_criterion(confidence_pred, trade_target)
|
||||
total_loss += 0.1 * confidence_loss
|
||||
|
||||
# Backward pass
|
||||
|
||||
Reference in New Issue
Block a user