wip
@@ -162,7 +162,8 @@ class DeepMultiScaleAttention(nn.Module):
         scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

         if mask is not None:
-            scores.masked_fill_(mask == 0, -1e9)
+            # Use non-inplace version to avoid gradient computation issues
+            scores = scores.masked_fill(mask == 0, -1e9)

         attention = F.softmax(scores, dim=-1)
         attention = self.dropout(attention)
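The hunk above swaps the in-place masked_fill_ for its out-of-place counterpart, so the scores tensor that autograd saved for the backward pass is never mutated. A minimal standalone sketch of the same pattern (shapes, the toy mask, and variable names are illustrative, not taken from this repo):

# Mask attention scores without mutating the original tensor, then backprop.
import math
import torch
import torch.nn.functional as F

head_dim = 8
Q = torch.randn(2, 4, 10, head_dim, requires_grad=True)  # (batch, heads, seq, head_dim)
K = torch.randn(2, 4, 10, head_dim, requires_grad=True)

scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(head_dim)
mask = torch.ones(2, 1, 10, 10)
mask[..., 5:] = 0                                         # hide the last 5 key positions

# Out-of-place: returns a new tensor and leaves `scores` intact for autograd.
masked = scores.masked_fill(mask == 0, -1e9)
attention = F.softmax(masked, dim=-1)
attention.sum().backward()                                # gradients flow back to Q and K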
@@ -1089,8 +1090,11 @@ class TradingTransformerTrainer:
             pct_start=0.1
         )

-        # Loss functions
-        self.action_criterion = nn.CrossEntropyLoss()
+        # Loss functions with class weights
+        # Pivot-based training: BUY at L pivots, SELL at H pivots (naturally balanced)
+        # Weights: [HOLD=0, BUY=1, SELL=2] - equal weighting for pivot-based trades
+        class_weights = torch.tensor([0.5, 1.0, 1.0], dtype=torch.float32, device=self.device)
+        self.action_criterion = nn.CrossEntropyLoss(weight=class_weights)
         self.price_criterion = nn.MSELoss()
         self.confidence_criterion = nn.BCELoss()

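The loss-function change keeps nn.CrossEntropyLoss but passes a per-class weight tensor, halving the contribution of HOLD relative to BUY and SELL. Illustrative usage of the same weighting (batch size, logits, and the CPU device are made up for the example):

# How a per-class weight tensor changes nn.CrossEntropyLoss.
import torch
import torch.nn as nn

device = torch.device('cpu')
class_weights = torch.tensor([0.5, 1.0, 1.0], dtype=torch.float32, device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(16, 3)               # (batch, num_actions)
targets = torch.randint(0, 3, (16,))      # 0=HOLD, 1=BUY, 2=SELL
loss = criterion(logits, targets)         # weighted mean over the batch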
@@ -1182,19 +1186,30 @@ class TradingTransformerTrainer:
         Returns:
             Denormalized OHLCV tensor
         """
-        denorm = normalized_candle.clone()
-
-        # Denormalize OHLC (first 4 values)
+        # Avoid inplace operations by creating new tensors instead of slice assignment
         price_min = norm_params.get('price_min', 0.0)
         price_max = norm_params.get('price_max', 1.0)
-        if price_max > price_min:
-            denorm[..., :4] = denorm[..., :4] * (price_max - price_min) + price_min
-
-        # Denormalize volume (5th value)
         volume_min = norm_params.get('volume_min', 0.0)
         volume_max = norm_params.get('volume_max', 1.0)
+
+        # Denormalize OHLC (first 4 values) - create new tensor, no inplace operations
+        if price_max > price_min:
+            price_scale = (price_max - price_min)
+            price_offset = price_min
+            denorm_ohlc = normalized_candle[..., :4] * price_scale + price_offset
+        else:
+            denorm_ohlc = normalized_candle[..., :4]
+
+        # Denormalize volume (5th value) - create new tensor, no inplace operations
         if volume_max > volume_min:
-            denorm[..., 4] = denorm[..., 4] * (volume_max - volume_min) + volume_min
+            volume_scale = (volume_max - volume_min)
+            volume_offset = volume_min
+            denorm_volume = (normalized_candle[..., 4:5] * volume_scale + volume_offset)
+        else:
+            denorm_volume = normalized_candle[..., 4:5]
+
+        # Concatenate OHLC and Volume to create final tensor (no inplace operations)
+        denorm = torch.cat([denorm_ohlc, denorm_volume], dim=-1)

         return denorm

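The denormalization hunk drops the clone-and-slice-assign approach in favor of building new tensors and concatenating them, so nothing reachable by autograd is written in place. A hypothetical standalone version of that step (the denormalize_candle name and the test min/max values are invented for illustration):

# Scale OHLC and volume separately, then concatenate instead of writing into slices.
import torch

def denormalize_candle(normalized_candle, norm_params):
    price_min = norm_params.get('price_min', 0.0)
    price_max = norm_params.get('price_max', 1.0)
    volume_min = norm_params.get('volume_min', 0.0)
    volume_max = norm_params.get('volume_max', 1.0)

    ohlc = normalized_candle[..., :4]
    volume = normalized_candle[..., 4:5]
    if price_max > price_min:
        ohlc = ohlc * (price_max - price_min) + price_min
    if volume_max > volume_min:
        volume = volume * (volume_max - volume_min) + volume_min
    return torch.cat([ohlc, volume], dim=-1)  # new tensor, no in-place writes

candles = torch.rand(32, 5, requires_grad=True)  # normalized OHLCV rows
denorm = denormalize_candle(candles, {'price_min': 100.0, 'price_max': 200.0,
                                      'volume_min': 0.0, 'volume_max': 1e6})
denorm.sum().backward()                          # works: the graph was never mutated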
@@ -1675,9 +1690,46 @@ class TradingTransformerTrainer:
         """Load model and training state"""
         checkpoint = torch.load(path, map_location=self.device)

-        self.model.load_state_dict(checkpoint['model_state_dict'])
-        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+        # Load model state (with strict=False to handle architecture changes)
+        try:
+            self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
+        except Exception as e:
+            logger.warning(f"Error loading model state dict: {e}, continuing with partial load")
+
+        # Load optimizer state (handle mismatched states gracefully)
+        try:
+            optimizer_state = checkpoint.get('optimizer_state_dict')
+            if optimizer_state:
+                try:
+                    # Try to load optimizer state
+                    self.optimizer.load_state_dict(optimizer_state)
+                except (KeyError, ValueError, RuntimeError) as e:
+                    logger.warning(f"Error loading optimizer state: {e}. Resetting optimizer.")
+                    # Recreate optimizer (same pattern as __init__)
+                    self.optimizer = torch.optim.AdamW(
+                        self.model.parameters(),
+                        lr=self.config.learning_rate,
+                        weight_decay=self.config.weight_decay
+                    )
+            else:
+                logger.warning("No optimizer state found in checkpoint. Using fresh optimizer.")
+        except Exception as e:
+            logger.warning(f"Error loading optimizer state: {e}. Resetting optimizer.")
+            # Recreate optimizer (same pattern as __init__)
+            self.optimizer = torch.optim.AdamW(
+                self.model.parameters(),
+                lr=self.config.learning_rate,
+                weight_decay=self.config.weight_decay
+            )
+
+        # Load scheduler state
+        try:
+            scheduler_state = checkpoint.get('scheduler_state_dict')
+            if scheduler_state:
+                self.scheduler.load_state_dict(scheduler_state)
+        except Exception as e:
+            logger.warning(f"Error loading scheduler state: {e}, continuing without scheduler state")
+
         self.training_history = checkpoint.get('training_history', self.training_history)

         logger.info(f"Model loaded from {path}")
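The checkpoint-loading hunk wraps each load_state_dict call so an architecture or hyperparameter change no longer aborts the restore: strict=False for the model, and a fresh AdamW when the optimizer state cannot be applied. A condensed sketch of that pattern on a toy model (the nn.Linear model and hyperparameters are placeholders, not the trainer class itself):

# Tolerant checkpoint restore: partial model load, optimizer fallback on mismatch.
import logging
import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

model = nn.Linear(10, 3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

checkpoint = {'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict()}

# strict=False keeps going when keys were added or removed by an architecture change.
model.load_state_dict(checkpoint['model_state_dict'], strict=False)

optimizer_state = checkpoint.get('optimizer_state_dict')
if optimizer_state:
    try:
        optimizer.load_state_dict(optimizer_state)
    except (KeyError, ValueError, RuntimeError) as e:
        logger.warning("Optimizer state mismatch (%s); starting with a fresh optimizer", e)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
else:
    logger.warning("No optimizer state found in checkpoint; using a fresh optimizer")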