vector predictions inference fix

This commit is contained in:
Dobromir Popov
2025-07-30 00:32:35 +03:00
parent 8335ad8e64
commit c5a9e75ee7
5 changed files with 33 additions and 17 deletions

View File

@ -169,7 +169,12 @@ class DQNNetwork(nn.Module):
# Combine value and advantage for Q-values
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
return q_values, regime_pred, price_direction_pred, volatility_pred, features
# Add placeholder multi-timeframe predictions for compatibility
batch_size = q_values.size(0)
device = q_values.device
multi_timeframe_pred = torch.zeros(batch_size, 12, device=device) # 3 timeframes * 4 values each
return q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred
def act(self, state, explore=True):
"""
@ -197,7 +202,7 @@ class DQNNetwork(nn.Module):
state = state.unsqueeze(0)
with torch.no_grad():
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred = self.forward(state)
# Price direction predictions are processed in the agent's act method
# This is just the network forward pass
@ -781,7 +786,7 @@ class DQNAgent:
# Process price direction predictions from the network
# Get the raw predictions from the network's forward pass
with torch.no_grad():
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state)
q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred = self.policy_net.forward(state)
if price_direction_pred is not None:
self.process_price_direction_predictions(price_direction_pred)
@ -826,7 +831,7 @@ class DQNAgent:
# Get network outputs
with torch.no_grad():
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state_tensor)
q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred = self.policy_net.forward(state_tensor)
# Process price direction predictions
if price_direction_pred is not None:
@ -1025,11 +1030,18 @@ class DQNAgent:
return None
def _safe_cnn_forward(self, network, states):
"""Safely call CNN forward method ensuring we always get 5 return values"""
"""Safely call CNN forward method ensuring we always get 6 return values"""
try:
result = network(states)
if isinstance(result, tuple) and len(result) == 5:
if isinstance(result, tuple) and len(result) == 6:
return result
elif isinstance(result, tuple) and len(result) == 5:
# Handle legacy 5-value return by adding default multi_timeframe_pred
q_values, extrema_pred, price_pred, features, advanced_pred = result
batch_size = q_values.size(0)
device = q_values.device
default_multi_timeframe = torch.zeros(batch_size, 12, device=device) # 3 timeframes * 4 values each
return q_values, extrema_pred, price_pred, features, advanced_pred, default_multi_timeframe
elif isinstance(result, tuple) and len(result) == 1:
# Handle case where only q_values are returned (like in empty tensor case)
q_values = result[0]
@ -1039,7 +1051,8 @@ class DQNAgent:
default_price = torch.zeros(batch_size, 1, device=device)
default_features = torch.zeros(batch_size, 1024, device=device)
default_advanced = torch.zeros(batch_size, 1, device=device)
return q_values, default_extrema, default_price, default_features, default_advanced
default_multi_timeframe = torch.zeros(batch_size, 12, device=device)
return q_values, default_extrema, default_price, default_features, default_advanced, default_multi_timeframe
else:
# Fallback: create all default tensors
batch_size = states.size(0)
@ -1049,7 +1062,8 @@ class DQNAgent:
default_price = torch.zeros(batch_size, 1, device=device)
default_features = torch.zeros(batch_size, 1024, device=device)
default_advanced = torch.zeros(batch_size, 1, device=device)
return default_q_values, default_extrema, default_price, default_features, default_advanced
default_multi_timeframe = torch.zeros(batch_size, 12, device=device)
return default_q_values, default_extrema, default_price, default_features, default_advanced, default_multi_timeframe
except Exception as e:
logger.error(f"Error in CNN forward pass: {e}")
# Fallback: create all default tensors
@ -1060,7 +1074,8 @@ class DQNAgent:
default_price = torch.zeros(batch_size, 1, device=device)
default_features = torch.zeros(batch_size, 1024, device=device)
default_advanced = torch.zeros(batch_size, 1, device=device)
return default_q_values, default_extrema, default_price, default_features, default_advanced
default_multi_timeframe = torch.zeros(batch_size, 12, device=device)
return default_q_values, default_extrema, default_price, default_features, default_advanced, default_multi_timeframe
def replay(self, experiences=None):
"""Train the model using experiences from memory"""
@ -1437,20 +1452,20 @@ class DQNAgent:
warnings.simplefilter("ignore", FutureWarning)
with torch.cuda.amp.autocast():
# Get current Q values and predictions
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred, current_multi_timeframe_pred = self._safe_cnn_forward(self.policy_net, states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Get next Q values from target network
with torch.no_grad():
if self.use_double_dqn:
# Double DQN
policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
policy_q_values, _, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
next_actions = policy_q_values.argmax(1)
target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
target_q_values_all, _, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
else:
# Standard DQN
next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
next_q_values, _, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
next_q_values = next_q_values.max(1)[0]
# Ensure consistent shapes