Dobromir Popov
2025-11-23 02:16:34 +02:00
parent 24aeefda9d
commit 53ce4a355a
8 changed files with 1088 additions and 155 deletions


@@ -162,7 +162,8 @@ class DeepMultiScaleAttention(nn.Module):
         scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
         if mask is not None:
-            scores.masked_fill_(mask == 0, -1e9)
+            # Use non-inplace version to avoid gradient computation issues
+            scores = scores.masked_fill(mask == 0, -1e9)
         attention = F.softmax(scores, dim=-1)
         attention = self.dropout(attention)
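
For context on the gradient issue the new comment refers to, here is a minimal standalone sketch (toy tensors, not this repo's module) of the failure class: an in-place masked_fill_ on a tensor that autograd has already saved for the backward pass raises at backward time, while the out-of-place masked_fill used in the hunk above does not.

import torch

x = torch.randn(4, requires_grad=True)
y = torch.exp(x)                      # exp saves its output for the backward pass
y.masked_fill_(y > 1.0, 1.0)          # in-place edit bumps y's version counter
try:
    y.sum().backward()                # RuntimeError: a variable needed for gradient
except RuntimeError as e:             # computation was modified by an inplace operation
    print("in-place:", e)

y = torch.exp(x)
y = y.masked_fill(y > 1.0, 1.0)       # out-of-place: the saved tensor stays intact
y.sum().backward()                    # gradients flow normally
print("out-of-place grad:", x.grad)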
@@ -1089,8 +1090,11 @@ class TradingTransformerTrainer:
             pct_start=0.1
         )
-        # Loss functions
-        self.action_criterion = nn.CrossEntropyLoss()
+        # Loss functions with class weights
+        # Pivot-based training: BUY at L pivots, SELL at H pivots (naturally balanced)
+        # Weights: [HOLD=0, BUY=1, SELL=2] - equal weighting for pivot-based trades
+        class_weights = torch.tensor([0.5, 1.0, 1.0], dtype=torch.float32, device=self.device)
+        self.action_criterion = nn.CrossEntropyLoss(weight=class_weights)
         self.price_criterion = nn.MSELoss()
         self.confidence_criterion = nn.BCELoss()
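
The weight tensor passed to nn.CrossEntropyLoss rescales the loss per target class, so with [0.5, 1.0, 1.0] a mispredicted HOLD sample contributes half as much as a BUY or SELL sample, and the default 'mean' reduction becomes a weighted mean. A small self-contained sketch (toy logits; the class order is taken from the comment in the hunk above):

import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 0.5, 0.1],    # toy predictions over 3 classes
                       [0.2, 1.5, 0.3]])
targets = torch.tensor([0, 1])             # HOLD, BUY (order per the comment above)

plain = nn.CrossEntropyLoss()
weighted = nn.CrossEntropyLoss(weight=torch.tensor([0.5, 1.0, 1.0]))

# reduction='mean' (the default) gives sum(w[t_i] * ce_i) / sum(w[t_i]),
# so the HOLD term is scaled by 0.5 before averaging.
print(plain(logits, targets), weighted(logits, targets))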
@@ -1182,19 +1186,30 @@ class TradingTransformerTrainer:
         Returns:
             Denormalized OHLCV tensor
         """
-        denorm = normalized_candle.clone()
-        # Denormalize OHLC (first 4 values)
+        # Avoid inplace operations by creating new tensors instead of slice assignment
         price_min = norm_params.get('price_min', 0.0)
         price_max = norm_params.get('price_max', 1.0)
-        if price_max > price_min:
-            denorm[..., :4] = denorm[..., :4] * (price_max - price_min) + price_min
-        # Denormalize volume (5th value)
         volume_min = norm_params.get('volume_min', 0.0)
         volume_max = norm_params.get('volume_max', 1.0)
+        # Denormalize OHLC (first 4 values) - create new tensor, no inplace operations
+        if price_max > price_min:
+            price_scale = (price_max - price_min)
+            price_offset = price_min
+            denorm_ohlc = normalized_candle[..., :4] * price_scale + price_offset
+        else:
+            denorm_ohlc = normalized_candle[..., :4]
+        # Denormalize volume (5th value) - create new tensor, no inplace operations
         if volume_max > volume_min:
-            denorm[..., 4] = denorm[..., 4] * (volume_max - volume_min) + volume_min
+            volume_scale = (volume_max - volume_min)
+            volume_offset = volume_min
+            denorm_volume = (normalized_candle[..., 4:5] * volume_scale + volume_offset)
+        else:
+            denorm_volume = normalized_candle[..., 4:5]
+        # Concatenate OHLC and Volume to create final tensor (no inplace operations)
+        denorm = torch.cat([denorm_ohlc, denorm_volume], dim=-1)
         return denorm
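
Taken out of the class, the rebuilt denormalization looks roughly like the sketch below (the standalone function name and the gradient check are mine, not the repo's). The point of torch.cat is that no slice of normalized_candle is ever written to, so any tensor autograd has saved for backward stays untouched.

import torch

def denormalize_candle(normalized_candle, norm_params):
    """Out-of-place OHLCV denormalization (standalone sketch)."""
    p_min, p_max = norm_params.get('price_min', 0.0), norm_params.get('price_max', 1.0)
    v_min, v_max = norm_params.get('volume_min', 0.0), norm_params.get('volume_max', 1.0)
    ohlc = normalized_candle[..., :4]
    if p_max > p_min:
        ohlc = ohlc * (p_max - p_min) + p_min
    volume = normalized_candle[..., 4:5]
    if v_max > v_min:
        volume = volume * (v_max - v_min) + v_min
    return torch.cat([ohlc, volume], dim=-1)

candle = torch.rand(3, 5, requires_grad=True)        # batch of 3 normalized candles
out = denormalize_candle(candle, {'price_min': 100.0, 'price_max': 110.0,
                                  'volume_min': 0.0, 'volume_max': 1e6})
out.sum().backward()                                  # gradients flow, no inplace errors
print(out.shape, candle.grad.shape)                   # torch.Size([3, 5]) twice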
@@ -1675,9 +1690,46 @@ class TradingTransformerTrainer:
"""Load model and training state"""
checkpoint = torch.load(path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# Load model state (with strict=False to handle architecture changes)
try:
self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
except Exception as e:
logger.warning(f"Error loading model state dict: {e}, continuing with partial load")
# Load optimizer state (handle mismatched states gracefully)
try:
optimizer_state = checkpoint.get('optimizer_state_dict')
if optimizer_state:
try:
# Try to load optimizer state
self.optimizer.load_state_dict(optimizer_state)
except (KeyError, ValueError, RuntimeError) as e:
logger.warning(f"Error loading optimizer state: {e}. Resetting optimizer.")
# Recreate optimizer (same pattern as __init__)
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
else:
logger.warning("No optimizer state found in checkpoint. Using fresh optimizer.")
except Exception as e:
logger.warning(f"Error loading optimizer state: {e}. Resetting optimizer.")
# Recreate optimizer (same pattern as __init__)
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
# Load scheduler state
try:
scheduler_state = checkpoint.get('scheduler_state_dict')
if scheduler_state:
self.scheduler.load_state_dict(scheduler_state)
except Exception as e:
logger.warning(f"Error loading scheduler state: {e}, continuing without scheduler state")
self.training_history = checkpoint.get('training_history', self.training_history)
logger.info(f"Model loaded from {path}")