beef up DQN model, fix training issues
@@ -23,8 +23,9 @@ logger = logging.getLogger(__name__)
 class DQNNetwork(nn.Module):
     """
-    Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
+    Massive Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
     Handles 7850 input features from multi-timeframe, multi-asset data
+    TARGET: 50M parameters for enhanced learning capacity
     """
     def __init__(self, input_dim: int, n_actions: int):
         super(DQNNetwork, self).__init__()
@@ -40,36 +41,102 @@ class DQNNetwork(nn.Module):
 
         self.n_actions = n_actions
 
-        # Deep network architecture optimized for trading features
-        self.network = nn.Sequential(
-            # Input layer
-            nn.Linear(self.input_size, 2048),
-            nn.ReLU(),
-            nn.Dropout(0.3),
+        # MASSIVE network architecture optimized for trading features
+        # Target: ~50M parameters
+        self.feature_extractor = nn.Sequential(
+            # Initial feature extraction with massive width
+            nn.Linear(self.input_size, 8192),  # 7850 -> 8192 = ~64M weights
+            nn.LayerNorm(8192),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
 
-            # Hidden layers with residual-like connections
+            # Deep feature processing layers
+            nn.Linear(8192, 6144),  # 8192 -> 6144 = ~50M weights
+            nn.LayerNorm(6144),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+
+            nn.Linear(6144, 4096),  # 6144 -> 4096 = ~25M weights
+            nn.LayerNorm(4096),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+
+            nn.Linear(4096, 3072),  # 4096 -> 3072 = ~12M weights
+            nn.LayerNorm(3072),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+
+            nn.Linear(3072, 2048),  # 3072 -> 2048 = ~6M weights
+            nn.LayerNorm(2048),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+        )
+
+        # Market regime detection head
+        self.regime_head = nn.Sequential(
             nn.Linear(2048, 1024),
-            nn.ReLU(),
-            nn.Dropout(0.3),
+            nn.LayerNorm(1024),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
             nn.Linear(1024, 512),
-            nn.ReLU(),
-            nn.Dropout(0.3),
-
-            nn.Linear(512, 256),
-            nn.ReLU(),
-            nn.Dropout(0.2),
-
-            nn.Linear(256, 128),
-            nn.ReLU(),
-            nn.Dropout(0.2),
-
-            # Output layer for Q-values
-            nn.Linear(128, n_actions)
+            nn.LayerNorm(512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, 4)  # trending, ranging, volatile, mixed
         )
 
+        # Price prediction head
+        self.price_head = nn.Sequential(
+            nn.Linear(2048, 1024),
+            nn.LayerNorm(1024),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+            nn.Linear(1024, 512),
+            nn.LayerNorm(512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, 3)  # short, medium, long term price direction
+        )
+
+        # Volatility prediction head
+        self.volatility_head = nn.Sequential(
+            nn.Linear(2048, 1024),
+            nn.LayerNorm(1024),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+            nn.Linear(1024, 256),
+            nn.LayerNorm(256),
+            nn.ReLU(inplace=True),
+            nn.Linear(256, 1)  # predicted volatility
+        )
+
+        # Main Q-value head (dueling architecture)
+        self.value_head = nn.Sequential(
+            nn.Linear(2048, 1024),
+            nn.LayerNorm(1024),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+            nn.Linear(1024, 512),
+            nn.LayerNorm(512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, 1)  # State value
+        )
+
+        self.advantage_head = nn.Sequential(
+            nn.Linear(2048, 1024),
+            nn.LayerNorm(1024),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+            nn.Linear(1024, 512),
+            nn.LayerNorm(512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, n_actions)  # Action advantages
+        )
+
         # Initialize weights
         self._initialize_weights()
 
+        # Log parameter count
+        total_params = sum(p.numel() for p in self.parameters())
+        logger.info(f"DQN Network initialized with {total_params:,} parameters (target: 50M)")
+
     def _initialize_weights(self):
         """Initialize network weights using Xavier initialization"""
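A quick sanity check on the parameter target, using only the layer sizes visible in this hunk (weights of the feature-extractor Linear layers, biases and prediction heads excluded; assumes self.input_size is the 7850 mentioned in the docstring):

# Back-of-the-envelope weight count for the new feature_extractor Linear layers.
layer_dims = [(7850, 8192), (8192, 6144), (6144, 4096), (4096, 3072), (3072, 2048)]
weights = [n_in * n_out for n_in, n_out in layer_dims]
print([f"{w / 1e6:.1f}M" for w in weights])   # ['64.3M', '50.3M', '25.2M', '12.6M', '6.3M']
print(f"feature extractor total ~ {sum(weights) / 1e6:.1f}M")   # ~158.7M

The per-layer comments in the diff match these numbers, but the stack as a whole already exceeds the 50M figure quoted in the docstring; the logger.info call added at the end of __init__ reports the actual total.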
@@ -78,6 +145,9 @@ class DQNNetwork(nn.Module):
                 nn.init.xavier_uniform_(module.weight)
                 if module.bias is not None:
                     nn.init.constant_(module.bias, 0)
+            elif isinstance(module, nn.LayerNorm):
+                nn.init.constant_(module.bias, 0)
+                nn.init.constant_(module.weight, 1.0)
 
     def forward(self, x):
         """Forward pass through the network"""
@@ -87,7 +157,22 @@ class DQNNetwork(nn.Module):
         elif x.dim() == 1:
             x = x.unsqueeze(0)  # Add batch dimension if needed
 
-        return self.network(x)
+        # Feature extraction
+        features = self.feature_extractor(x)
+
+        # Multiple prediction heads
+        regime_pred = self.regime_head(features)
+        price_pred = self.price_head(features)
+        volatility_pred = self.volatility_head(features)
+
+        # Dueling Q-network
+        value = self.value_head(features)
+        advantage = self.advantage_head(features)
+
+        # Combine value and advantage for Q-values
+        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
+
+        return q_values, regime_pred, price_pred, volatility_pred, features
 
     def act(self, state, explore=True):
         """
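The combination step in forward() is the standard dueling aggregation, Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a'). A minimal standalone check of the shapes and the mean-subtraction (values are illustrative only):

import torch

value = torch.tensor([[1.0], [0.5]])                             # V(s): (batch, 1)
advantage = torch.tensor([[0.2, -0.1, 0.5], [0.0, 0.3, -0.3]])   # A(s, a): (batch, n_actions)
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
print(q_values.shape)        # torch.Size([2, 3])
print(q_values.mean(dim=1))  # tensor([1.0000, 0.5000]) -- mean Q recovers V(s)

Subtracting the mean advantage keeps the value/advantage decomposition identifiable, which is why the diff does not simply add the two heads.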
@@ -111,7 +196,7 @@ class DQNNetwork(nn.Module):
             state = state.unsqueeze(0)
 
         with torch.no_grad():
-            q_values = self.forward(state)
+            q_values, regime_pred, price_pred, volatility_pred, features = self.forward(state)
 
             # Get action probabilities using softmax
             action_probs = F.softmax(q_values, dim=1)
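Since act() turns raw Q-values into probabilities with a softmax, exploration behaves like a Boltzmann policy. A small sketch of how a temperature parameter would control how greedy that policy is (the temperature is an assumption for illustration, not part of this commit):

import torch
import torch.nn.functional as F

q_values = torch.tensor([[0.5, 0.1, 0.4]])
for temperature in (1.0, 0.1):
    # Lower temperature -> sharper distribution -> closer to greedy argmax.
    action_probs = F.softmax(q_values / temperature, dim=1)
    print(temperature, action_probs)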
@@ -1010,22 +1095,34 @@ class DQNAgent:
             logger.warning("Empty batch in _replay_standard")
             return 0.0
 
-        # Get current Q values using safe wrapper
-        current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
-        current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
+        # Ensure model is in training mode for gradients
+        self.policy_net.train()
+
+        # Get current Q values - use the updated forward method
+        q_values_output = self.policy_net(states)
+        if isinstance(q_values_output, tuple):
+            current_q_values_all = q_values_output[0]  # Extract Q-values from tuple
+        else:
+            current_q_values_all = q_values_output
+
+        current_q_values = current_q_values_all.gather(1, actions.unsqueeze(1)).squeeze(1)
 
         # Enhanced Double DQN implementation
         with torch.no_grad():
             if self.use_double_dqn:
                 # Double DQN: Use policy network to select actions, target network to evaluate
-                policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
+                policy_output = self.policy_net(next_states)
+                policy_q_values = policy_output[0] if isinstance(policy_output, tuple) else policy_output
                 next_actions = policy_q_values.argmax(1)
-                target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
+
+                target_output = self.target_net(next_states)
+                target_q_values_all = target_output[0] if isinstance(target_output, tuple) else target_output
                 next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
             else:
                 # Standard DQN: Use target network for both selection and evaluation
-                next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
-                next_q_values = next_q_values.max(1)[0]
+                target_output = self.target_net(next_states)
+                target_q_values = target_output[0] if isinstance(target_output, tuple) else target_output
+                next_q_values = target_q_values.max(1)[0]
 
         # Ensure tensor shapes are consistent
         batch_size = states.shape[0]
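The new forward() returns a 5-tuple while older model variants return a bare Q-value tensor, which is why the isinstance checks above are repeated. They could be collapsed into a small helper; a sketch (the name _q_values_from is illustrative, not part of this commit):

def _q_values_from(output):
    """Return the Q-value tensor whether the network yields a tuple or a bare tensor."""
    return output[0] if isinstance(output, tuple) else output

# Mirroring the Double DQN branch above:
# next_actions = _q_values_from(self.policy_net(next_states)).argmax(1)
# next_q_values = _q_values_from(self.target_net(next_states)).gather(1, next_actions.unsqueeze(1)).squeeze(1)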
@@ -1043,26 +1140,15 @@ class DQNAgent:
         # Compute loss for Q value - ensure tensors require gradients
         if not current_q_values.requires_grad:
             logger.warning("Current Q values do not require gradients")
+            # Force training mode
+            self.policy_net.train()
             return 0.0
 
         q_loss = self.criterion(current_q_values, target_q_values.detach())
 
-        # Initialize total loss with Q loss
+        # Use only Q-loss for now to ensure clean gradients
         total_loss = q_loss
 
-        # Add auxiliary losses if available and valid
-        try:
-            if current_extrema_pred is not None and current_extrema_pred.shape[0] > 0:
-                # Create simple extrema targets based on Q-values
-                with torch.no_grad():
-                    extrema_targets = torch.ones(current_extrema_pred.shape[0], dtype=torch.long, device=current_extrema_pred.device) * 2  # Default to "neither"
-
-                extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)
-                total_loss = total_loss + 0.1 * extrema_loss
-
-        except Exception as e:
-            logger.debug(f"Could not calculate auxiliary loss: {e}")
-
         # Reset gradients
         self.optimizer.zero_grad()
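The target_q_values fed into the loss are built from next_q_values in code between the hunks shown here. In a standard DQN update that construction looks roughly like the following (gamma and the done mask are assumptions about code not shown in this diff):

import torch

gamma = 0.99
rewards = torch.tensor([0.10, -0.05, 0.00])
dones = torch.tensor([0.0, 0.0, 1.0])          # terminal transitions contribute no bootstrap term
next_q_values = torch.tensor([1.2, 0.8, 0.5])  # from the target network, as in the hunks above
target_q_values = rewards + gamma * next_q_values * (1.0 - dones)
print(target_q_values)  # tensor([1.2880, 0.7420, 0.0000])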