vector predictions inference fix
@@ -169,7 +169,12 @@ class DQNNetwork(nn.Module):
         # Combine value and advantage for Q-values
         q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
 
-        return q_values, regime_pred, price_direction_pred, volatility_pred, features
+        # Add placeholder multi-timeframe predictions for compatibility
+        batch_size = q_values.size(0)
+        device = q_values.device
+        multi_timeframe_pred = torch.zeros(batch_size, 12, device=device)  # 3 timeframes * 4 values each
+
+        return q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred
 
     def act(self, state, explore=True):
         """
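The hunk above changes the contract of DQNNetwork.forward from five return values to six, padding the new slot with a zeroed (batch_size, 12) tensor. A minimal standalone sketch (names and shapes are illustrative, not taken from the repository) of how a caller can view that flat placeholder as 3 timeframes of 4 values each:

import torch

batch_size = 4
multi_timeframe_pred = torch.zeros(batch_size, 12)            # placeholder, as created in the diff
per_timeframe = multi_timeframe_pred.view(batch_size, 3, 4)   # (batch, timeframes, values)
assert per_timeframe.shape == (4, 3, 4)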
@@ -197,7 +202,7 @@ class DQNNetwork(nn.Module):
             state = state.unsqueeze(0)
 
         with torch.no_grad():
-            q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
+            q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred = self.forward(state)
 
             # Price direction predictions are processed in the agent's act method
             # This is just the network forward pass
@@ -781,7 +786,7 @@ class DQNAgent:
         # Process price direction predictions from the network
         # Get the raw predictions from the network's forward pass
         with torch.no_grad():
-            q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state)
+            q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred = self.policy_net.forward(state)
             if price_direction_pred is not None:
                 self.process_price_direction_predictions(price_direction_pred)
 
@@ -826,7 +831,7 @@ class DQNAgent:
 
         # Get network outputs
         with torch.no_grad():
-            q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state_tensor)
+            q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred = self.policy_net.forward(state_tensor)
 
             # Process price direction predictions
             if price_direction_pred is not None:
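The three inference call sites updated above all switch from 5-value to 6-value unpacking. If older checkpoints or alternative network classes can still return the legacy 5-tuple, a caller can unpack defensively; this is an illustrative sketch, not code from the commit:

import torch

def unpack_forward(outputs):
    """Accept either the legacy 5-tuple or the new 6-tuple from forward()."""
    if len(outputs) == 6:
        q_values, regime, direction, volatility, features, multi_tf = outputs
    else:
        # Legacy 5-tuple: synthesize the zero multi-timeframe placeholder
        q_values, regime, direction, volatility, features = outputs
        multi_tf = torch.zeros(q_values.size(0), 12, device=q_values.device)
    return q_values, regime, direction, volatility, features, multi_tf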
@@ -1025,11 +1030,18 @@ class DQNAgent:
             return None
 
     def _safe_cnn_forward(self, network, states):
-        """Safely call CNN forward method ensuring we always get 5 return values"""
+        """Safely call CNN forward method ensuring we always get 6 return values"""
         try:
             result = network(states)
-            if isinstance(result, tuple) and len(result) == 5:
+            if isinstance(result, tuple) and len(result) == 6:
                 return result
+            elif isinstance(result, tuple) and len(result) == 5:
+                # Handle legacy 5-value return by adding default multi_timeframe_pred
+                q_values, extrema_pred, price_pred, features, advanced_pred = result
+                batch_size = q_values.size(0)
+                device = q_values.device
+                default_multi_timeframe = torch.zeros(batch_size, 12, device=device)  # 3 timeframes * 4 values each
+                return q_values, extrema_pred, price_pred, features, advanced_pred, default_multi_timeframe
             elif isinstance(result, tuple) and len(result) == 1:
                 # Handle case where only q_values are returned (like in empty tensor case)
                 q_values = result[0]
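_safe_cnn_forward now normalizes every network output to a 6-tuple, padding legacy 5-tuples with a zero multi-timeframe tensor. A self-contained sketch of that normalization idea (function and variable names are illustrative, not the repository's):

import torch

def normalize_outputs(result, batch_size, device="cpu"):
    """Return a 6-tuple regardless of whether the network produced 5 or 6 tensors."""
    if isinstance(result, tuple) and len(result) == 6:
        return result
    if isinstance(result, tuple) and len(result) == 5:
        q, extrema, price, feats, adv = result
        return q, extrema, price, feats, adv, torch.zeros(batch_size, 12, device=device)
    raise ValueError(f"unexpected network output: {type(result)!r}")

# A stub 5-tuple, shaped like the legacy outputs handled above
legacy = tuple(torch.zeros(2, n) for n in (3, 1, 1, 1024, 1))
outputs = normalize_outputs(legacy, batch_size=2)
assert len(outputs) == 6 and outputs[-1].shape == (2, 12)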
@@ -1039,7 +1051,8 @@ class DQNAgent:
                 default_price = torch.zeros(batch_size, 1, device=device)
                 default_features = torch.zeros(batch_size, 1024, device=device)
                 default_advanced = torch.zeros(batch_size, 1, device=device)
-                return q_values, default_extrema, default_price, default_features, default_advanced
+                default_multi_timeframe = torch.zeros(batch_size, 12, device=device)
+                return q_values, default_extrema, default_price, default_features, default_advanced, default_multi_timeframe
             else:
                 # Fallback: create all default tensors
                 batch_size = states.size(0)
@@ -1049,7 +1062,8 @@ class DQNAgent:
                 default_price = torch.zeros(batch_size, 1, device=device)
                 default_features = torch.zeros(batch_size, 1024, device=device)
                 default_advanced = torch.zeros(batch_size, 1, device=device)
-                return default_q_values, default_extrema, default_price, default_features, default_advanced
+                default_multi_timeframe = torch.zeros(batch_size, 12, device=device)
+                return default_q_values, default_extrema, default_price, default_features, default_advanced, default_multi_timeframe
         except Exception as e:
             logger.error(f"Error in CNN forward pass: {e}")
             # Fallback: create all default tensors
@@ -1060,7 +1074,8 @@ class DQNAgent:
             default_price = torch.zeros(batch_size, 1, device=device)
             default_features = torch.zeros(batch_size, 1024, device=device)
             default_advanced = torch.zeros(batch_size, 1, device=device)
-            return default_q_values, default_extrema, default_price, default_features, default_advanced
+            default_multi_timeframe = torch.zeros(batch_size, 12, device=device)
+            return default_q_values, default_extrema, default_price, default_features, default_advanced, default_multi_timeframe
 
     def replay(self, experiences=None):
         """Train the model using experiences from memory"""
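The fallback branches in this and the earlier fallback hunks each rebuild the same set of default tensors by hand, now including the (batch, 12) multi-timeframe placeholder. An illustrative refactor sketch, not part of this commit, that keeps those shapes in one place (the action dimension n_actions is an assumption):

import torch

def default_outputs(batch_size, n_actions, device):
    """Zeroed 6-tuple matching the fallback outputs built in _safe_cnn_forward."""
    return (
        torch.zeros(batch_size, n_actions, device=device),  # q_values
        torch.zeros(batch_size, 1, device=device),           # extrema
        torch.zeros(batch_size, 1, device=device),            # price
        torch.zeros(batch_size, 1024, device=device),         # features
        torch.zeros(batch_size, 1, device=device),             # advanced
        torch.zeros(batch_size, 12, device=device),            # multi-timeframe placeholder
    )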
@@ -1437,20 +1452,20 @@ class DQNAgent:
                 warnings.simplefilter("ignore", FutureWarning)
                 with torch.cuda.amp.autocast():
                     # Get current Q values and predictions
-                    current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
+                    current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred, current_multi_timeframe_pred = self._safe_cnn_forward(self.policy_net, states)
                     current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
 
                     # Get next Q values from target network
                     with torch.no_grad():
                         if self.use_double_dqn:
                             # Double DQN
-                            policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
+                            policy_q_values, _, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
                             next_actions = policy_q_values.argmax(1)
-                            target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
+                            target_q_values_all, _, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
                             next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
                         else:
                             # Standard DQN
-                            next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
+                            next_q_values, _, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
                             next_q_values = next_q_values.max(1)[0]
 
                     # Ensure consistent shapes
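For reference, the Double DQN branch in the training hunk above selects next actions with the policy network and evaluates them with the target network. A minimal standalone sketch of that target computation (gamma, rewards, and dones are illustrative names, not taken from the diff):

import torch

def double_dqn_targets(policy_q_next, target_q_next, rewards, dones, gamma=0.99):
    """targets = r + gamma * Q_target(s', argmax_a Q_policy(s', a)) * (1 - done)"""
    next_actions = policy_q_next.argmax(dim=1)                               # select with policy net
    next_q = target_q_next.gather(1, next_actions.unsqueeze(1)).squeeze(1)   # evaluate with target net
    return rewards + gamma * next_q * (1.0 - dones.float())

# Toy usage with random Q-values for a batch of two transitions
targets = double_dqn_targets(torch.randn(2, 3), torch.randn(2, 3),
                             rewards=torch.ones(2), dones=torch.zeros(2))
assert targets.shape == (2,)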