beef up DQN model, fix training issues

Author: Dobromir Popov
Date: 2025-07-27 20:48:44 +03:00
Parent: 1894d453c9
Commit: bd986f4534
6 changed files with 414 additions and 55 deletions


@@ -23,8 +23,9 @@ logger = logging.getLogger(__name__)

 class DQNNetwork(nn.Module):
     """
-    Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
+    Massive Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
     Handles 7850 input features from multi-timeframe, multi-asset data
+    TARGET: 50M parameters for enhanced learning capacity
     """
     def __init__(self, input_dim: int, n_actions: int):
         super(DQNNetwork, self).__init__()
@@ -40,36 +41,102 @@ class DQNNetwork(nn.Module):
         self.n_actions = n_actions

-        # Deep network architecture optimized for trading features
-        self.network = nn.Sequential(
-            # Input layer
-            nn.Linear(self.input_size, 2048),
-            nn.ReLU(),
-            nn.Dropout(0.3),
+        # MASSIVE network architecture optimized for trading features
+        # Target: ~50M parameters
+        self.feature_extractor = nn.Sequential(
+            # Initial feature extraction with massive width
+            nn.Linear(self.input_size, 8192),  # 7850 -> 8192 = ~64M weights
+            nn.LayerNorm(8192),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),

-            # Hidden layers with residual-like connections
+            # Deep feature processing layers
+            nn.Linear(8192, 6144),  # 8192 -> 6144 = ~50M weights
+            nn.LayerNorm(6144),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),

+            nn.Linear(6144, 4096),  # 6144 -> 4096 = ~25M weights
+            nn.LayerNorm(4096),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),

+            nn.Linear(4096, 3072),  # 4096 -> 3072 = ~12M weights
+            nn.LayerNorm(3072),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),

+            nn.Linear(3072, 2048),  # 3072 -> 2048 = ~6M weights
+            nn.LayerNorm(2048),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+        )

+        # Market regime detection head
+        self.regime_head = nn.Sequential(
             nn.Linear(2048, 1024),
-            nn.ReLU(),
-            nn.Dropout(0.3),
+            nn.LayerNorm(1024),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
             nn.Linear(1024, 512),
-            nn.ReLU(),
-            nn.Dropout(0.3),
-            nn.Linear(512, 256),
-            nn.ReLU(),
-            nn.Dropout(0.2),
-            nn.Linear(256, 128),
-            nn.ReLU(),
-            nn.Dropout(0.2),
-            # Output layer for Q-values
-            nn.Linear(128, n_actions)
+            nn.LayerNorm(512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, 4)  # trending, ranging, volatile, mixed
         )

+        # Price prediction head
+        self.price_head = nn.Sequential(
+            nn.Linear(2048, 1024),
+            nn.LayerNorm(1024),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+            nn.Linear(1024, 512),
+            nn.LayerNorm(512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, 3)  # short, medium, long term price direction
+        )

+        # Volatility prediction head
+        self.volatility_head = nn.Sequential(
+            nn.Linear(2048, 1024),
+            nn.LayerNorm(1024),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+            nn.Linear(1024, 256),
+            nn.LayerNorm(256),
+            nn.ReLU(inplace=True),
+            nn.Linear(256, 1)  # predicted volatility
+        )

+        # Main Q-value head (dueling architecture)
+        self.value_head = nn.Sequential(
+            nn.Linear(2048, 1024),
+            nn.LayerNorm(1024),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+            nn.Linear(1024, 512),
+            nn.LayerNorm(512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, 1)  # State value
+        )

+        self.advantage_head = nn.Sequential(
+            nn.Linear(2048, 1024),
+            nn.LayerNorm(1024),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+            nn.Linear(1024, 512),
+            nn.LayerNorm(512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, n_actions)  # Action advantages
+        )

         # Initialize weights
         self._initialize_weights()

+        # Log parameter count
+        total_params = sum(p.numel() for p in self.parameters())
+        logger.info(f"DQN Network initialized with {total_params:,} parameters (target: 50M)")

     def _initialize_weights(self):
         """Initialize network weights using Xavier initialization"""
@@ -78,6 +145,9 @@ class DQNNetwork(nn.Module):
                 nn.init.xavier_uniform_(module.weight)
                 if module.bias is not None:
                     nn.init.constant_(module.bias, 0)
+            elif isinstance(module, nn.LayerNorm):
+                nn.init.constant_(module.bias, 0)
+                nn.init.constant_(module.weight, 1.0)

     def forward(self, x):
         """Forward pass through the network"""
@@ -87,7 +157,22 @@ class DQNNetwork(nn.Module):
         elif x.dim() == 1:
             x = x.unsqueeze(0)  # Add batch dimension if needed

-        return self.network(x)
+        # Feature extraction
+        features = self.feature_extractor(x)

+        # Multiple prediction heads
+        regime_pred = self.regime_head(features)
+        price_pred = self.price_head(features)
+        volatility_pred = self.volatility_head(features)

+        # Dueling Q-network
+        value = self.value_head(features)
+        advantage = self.advantage_head(features)

+        # Combine value and advantage for Q-values
+        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)

+        return q_values, regime_pred, price_pred, volatility_pred, features

     def act(self, state, explore=True):
         """
@@ -111,7 +196,7 @@ class DQNNetwork(nn.Module):
             state = state.unsqueeze(0)

         with torch.no_grad():
-            q_values = self.forward(state)
+            q_values, regime_pred, price_pred, volatility_pred, features = self.forward(state)

             # Get action probabilities using softmax
             action_probs = F.softmax(q_values, dim=1)
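act() turns the Q-values into probabilities with a softmax; the remainder of the method is outside this hunk. A common pattern, sketched here only as an assumption about what follows (sample when exploring, greedy otherwise), not the repository's exact logic:

    import torch
    import torch.nn.functional as F

    def select_action(q_values: torch.Tensor, explore: bool = True) -> int:
        """q_values: shape (1, n_actions). Sample from the softmax when exploring, else argmax."""
        action_probs = F.softmax(q_values, dim=1)
        if explore:
            return int(torch.multinomial(action_probs, num_samples=1).item())
        return int(action_probs.argmax(dim=1).item())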
@@ -1010,22 +1095,34 @@ class DQNAgent:
             logger.warning("Empty batch in _replay_standard")
             return 0.0

-        # Get current Q values using safe wrapper
-        current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
-        current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
+        # Ensure model is in training mode for gradients
+        self.policy_net.train()

+        # Get current Q values - use the updated forward method
+        q_values_output = self.policy_net(states)
+        if isinstance(q_values_output, tuple):
+            current_q_values_all = q_values_output[0]  # Extract Q-values from tuple
+        else:
+            current_q_values_all = q_values_output
+        current_q_values = current_q_values_all.gather(1, actions.unsqueeze(1)).squeeze(1)

         # Enhanced Double DQN implementation
         with torch.no_grad():
             if self.use_double_dqn:
                 # Double DQN: Use policy network to select actions, target network to evaluate
-                policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
+                policy_output = self.policy_net(next_states)
+                policy_q_values = policy_output[0] if isinstance(policy_output, tuple) else policy_output
                 next_actions = policy_q_values.argmax(1)
-                target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
+                target_output = self.target_net(next_states)
+                target_q_values_all = target_output[0] if isinstance(target_output, tuple) else target_output
                 next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
             else:
                 # Standard DQN: Use target network for both selection and evaluation
-                next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
-                next_q_values = next_q_values.max(1)[0]
+                target_output = self.target_net(next_states)
+                target_q_values = target_output[0] if isinstance(target_output, tuple) else target_output
+                next_q_values = target_q_values.max(1)[0]

         # Ensure tensor shapes are consistent
         batch_size = states.shape[0]
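The Double DQN branch above follows the usual recipe: the policy network selects the next action, the target network evaluates it. A self-contained sketch of the full target computation, using plain Q-value networks (unlike the tuple-returning DQNNetwork above) and a hypothetical gamma:

    import torch
    import torch.nn as nn

    def double_dqn_targets(policy_net: nn.Module, target_net: nn.Module,
                           rewards: torch.Tensor, next_states: torch.Tensor,
                           dones: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
        """TD targets r + gamma * Q_target(s', argmax_a Q_policy(s', a)) for non-terminal s'."""
        with torch.no_grad():
            next_actions = policy_net(next_states).argmax(dim=1)  # select with the policy net
            next_q = target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)  # evaluate with the target net
            return rewards + gamma * next_q * (1.0 - dones)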
@@ -1043,26 +1140,15 @@ class DQNAgent:
        # Compute loss for Q value - ensure tensors require gradients
        if not current_q_values.requires_grad:
            logger.warning("Current Q values do not require gradients")
+           # Force training mode
+           self.policy_net.train()
            return 0.0

        q_loss = self.criterion(current_q_values, target_q_values.detach())

-       # Initialize total loss with Q loss
+       # Use only Q-loss for now to ensure clean gradients
        total_loss = q_loss

-       # Add auxiliary losses if available and valid
-       try:
-           if current_extrema_pred is not None and current_extrema_pred.shape[0] > 0:
-               # Create simple extrema targets based on Q-values
-               with torch.no_grad():
-                   extrema_targets = torch.ones(current_extrema_pred.shape[0], dtype=torch.long, device=current_extrema_pred.device) * 2  # Default to "neither"
-               extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)
-               total_loss = total_loss + 0.1 * extrema_loss
-       except Exception as e:
-           logger.debug(f"Could not calculate auxiliary loss: {e}")

        # Reset gradients
        self.optimizer.zero_grad()
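The hunk ends at optimizer.zero_grad(); the backward pass and optimizer step are outside this excerpt. For orientation, a generic sketch of how such an update is usually completed (detached targets, backward, optional gradient clipping, step); the names and the choice of Huber loss are assumptions, not the repository's code:

    import torch
    import torch.nn as nn

    def dqn_update_step(model: nn.Module, optimizer: torch.optim.Optimizer,
                        current_q: torch.Tensor, target_q: torch.Tensor,
                        max_grad_norm: float = 1.0) -> float:
        criterion = nn.SmoothL1Loss()                   # Huber loss; the diff uses self.criterion
        loss = criterion(current_q, target_q.detach())  # detach so gradients flow only through the online net
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # common stabilizer
        optimizer.step()
        return loss.item()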