Fix vector predictions inference

The network forward passes now return a sixth value, multi_timeframe_pred (a zero placeholder of shape (batch, 12): 3 timeframes * 4 values each), and every call site (DQN agent, safe CNN wrapper, standardized CNN, orchestrator, and dashboard) unpacks the 6-tuple accordingly. Legacy 5-value returns are padded with a default multi-timeframe tensor.
@@ -169,7 +169,12 @@ class DQNNetwork(nn.Module):
         # Combine value and advantage for Q-values
         q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
 
-        return q_values, regime_pred, price_direction_pred, volatility_pred, features
+        # Add placeholder multi-timeframe predictions for compatibility
+        batch_size = q_values.size(0)
+        device = q_values.device
+        multi_timeframe_pred = torch.zeros(batch_size, 12, device=device)  # 3 timeframes * 4 values each
+
+        return q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred
 
     def act(self, state, explore=True):
         """
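A note on the new placeholder: the `(batch, 12)` tensor follows the diff's own comment (3 timeframes * 4 values each), so a consumer can view it per timeframe. A minimal sketch with illustrative sizes; the timeframe/value semantics are not defined in this commit:

```python
import torch

# Placeholder exactly as DQNNetwork.forward now builds it: zeros of shape (batch, 12)
batch_size = 2
multi_timeframe_pred = torch.zeros(batch_size, 12)

# Per the diff comment, 12 = 3 timeframes * 4 values each, so a consumer
# could view it as (batch, timeframes, values); the axis meaning is an assumption.
per_timeframe = multi_timeframe_pred.view(batch_size, 3, 4)
print(per_timeframe.shape)  # torch.Size([2, 3, 4])
```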
@@ -197,7 +202,7 @@ class DQNNetwork(nn.Module):
             state = state.unsqueeze(0)
 
         with torch.no_grad():
-            q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
+            q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred = self.forward(state)
 
         # Price direction predictions are processed in the agent's act method
         # This is just the network forward pass
@@ -781,7 +786,7 @@ class DQNAgent:
         # Process price direction predictions from the network
         # Get the raw predictions from the network's forward pass
         with torch.no_grad():
-            q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state)
+            q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred = self.policy_net.forward(state)
         if price_direction_pred is not None:
             self.process_price_direction_predictions(price_direction_pred)
 
@@ -826,7 +831,7 @@ class DQNAgent:
 
         # Get network outputs
         with torch.no_grad():
-            q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state_tensor)
+            q_values, regime_pred, price_direction_pred, volatility_pred, features, multi_timeframe_pred = self.policy_net.forward(state_tensor)
 
         # Process price direction predictions
         if price_direction_pred is not None:
@@ -1025,11 +1030,18 @@ class DQNAgent:
             return None
 
     def _safe_cnn_forward(self, network, states):
-        """Safely call CNN forward method ensuring we always get 5 return values"""
+        """Safely call CNN forward method ensuring we always get 6 return values"""
         try:
             result = network(states)
-            if isinstance(result, tuple) and len(result) == 5:
+            if isinstance(result, tuple) and len(result) == 6:
                 return result
+            elif isinstance(result, tuple) and len(result) == 5:
+                # Handle legacy 5-value return by adding default multi_timeframe_pred
+                q_values, extrema_pred, price_pred, features, advanced_pred = result
+                batch_size = q_values.size(0)
+                device = q_values.device
+                default_multi_timeframe = torch.zeros(batch_size, 12, device=device)  # 3 timeframes * 4 values each
+                return q_values, extrema_pred, price_pred, features, advanced_pred, default_multi_timeframe
             elif isinstance(result, tuple) and len(result) == 1:
                 # Handle case where only q_values are returned (like in empty tensor case)
                 q_values = result[0]
@@ -1039,7 +1051,8 @@ class DQNAgent:
                 default_price = torch.zeros(batch_size, 1, device=device)
                 default_features = torch.zeros(batch_size, 1024, device=device)
                 default_advanced = torch.zeros(batch_size, 1, device=device)
-                return q_values, default_extrema, default_price, default_features, default_advanced
+                default_multi_timeframe = torch.zeros(batch_size, 12, device=device)
+                return q_values, default_extrema, default_price, default_features, default_advanced, default_multi_timeframe
             else:
                 # Fallback: create all default tensors
                 batch_size = states.size(0)
@@ -1049,7 +1062,8 @@ class DQNAgent:
                 default_price = torch.zeros(batch_size, 1, device=device)
                 default_features = torch.zeros(batch_size, 1024, device=device)
                 default_advanced = torch.zeros(batch_size, 1, device=device)
-                return default_q_values, default_extrema, default_price, default_features, default_advanced
+                default_multi_timeframe = torch.zeros(batch_size, 12, device=device)
+                return default_q_values, default_extrema, default_price, default_features, default_advanced, default_multi_timeframe
         except Exception as e:
             logger.error(f"Error in CNN forward pass: {e}")
             # Fallback: create all default tensors
@@ -1060,7 +1074,8 @@ class DQNAgent:
             default_price = torch.zeros(batch_size, 1, device=device)
             default_features = torch.zeros(batch_size, 1024, device=device)
             default_advanced = torch.zeros(batch_size, 1, device=device)
-            return default_q_values, default_extrema, default_price, default_features, default_advanced
+            default_multi_timeframe = torch.zeros(batch_size, 12, device=device)
+            return default_q_values, default_extrema, default_price, default_features, default_advanced, default_multi_timeframe
 
     def replay(self, experiences=None):
         """Train the model using experiences from memory"""
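All four `_safe_cnn_forward` hunks apply one pattern: normalize whatever tuple arity the wrapped network returns to the new 6-value contract, padding with zero tensors. A condensed, self-contained sketch of that normalization; the default shapes follow the diff, while the q-value width of 3 is an assumption:

```python
import torch

def pad_to_six(result, batch_size, device):
    """Normalize a network's tuple return to the 6-value contract.

    Mirrors _safe_cnn_forward's padding: a legacy 5-tuple gets a zero
    multi-timeframe tensor appended; anything else falls back to all
    default tensors with the shapes used in the diff.
    """
    if isinstance(result, tuple) and len(result) == 6:
        return result
    if isinstance(result, tuple) and len(result) == 5:
        default_mtf = torch.zeros(batch_size, 12, device=device)
        return (*result, default_mtf)
    return (
        torch.zeros(batch_size, 3, device=device),     # q_values (width assumed)
        torch.zeros(batch_size, 1, device=device),     # extrema
        torch.zeros(batch_size, 1, device=device),     # price
        torch.zeros(batch_size, 1024, device=device),  # features
        torch.zeros(batch_size, 1, device=device),     # advanced
        torch.zeros(batch_size, 12, device=device),    # multi-timeframe
    )
```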
@@ -1437,20 +1452,20 @@ class DQNAgent:
                 warnings.simplefilter("ignore", FutureWarning)
                 with torch.cuda.amp.autocast():
                     # Get current Q values and predictions
-                    current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
+                    current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred, current_multi_timeframe_pred = self._safe_cnn_forward(self.policy_net, states)
                     current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
 
                     # Get next Q values from target network
                     with torch.no_grad():
                         if self.use_double_dqn:
                             # Double DQN
-                            policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
+                            policy_q_values, _, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
                             next_actions = policy_q_values.argmax(1)
-                            target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
+                            target_q_values_all, _, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
                             next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
                         else:
                             # Standard DQN
-                            next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
+                            next_q_values, _, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
                             next_q_values = next_q_values.max(1)[0]
 
                         # Ensure consistent shapes
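For reference, the Double DQN branch this hunk touches selects next actions with the policy network and evaluates them with the target network. A minimal standalone sketch of that target computation, on illustrative tensors rather than the agent's real batches:

```python
import torch

batch_size, num_actions, gamma = 4, 3, 0.99
policy_q = torch.randn(batch_size, num_actions)  # policy net on next_states
target_q = torch.randn(batch_size, num_actions)  # target net on next_states
rewards = torch.randn(batch_size)
dones = torch.zeros(batch_size)                  # 1.0 where the episode ended

# Double DQN: action choice from the policy net, value from the target net
next_actions = policy_q.argmax(1)
next_q = target_q.gather(1, next_actions.unsqueeze(1)).squeeze(1)
targets = rewards + gamma * next_q * (1 - dones)
print(targets.shape)  # torch.Size([4])
```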
@@ -162,7 +162,7 @@ class StandardizedCNN(nn.Module):
         cnn_input = processed_features.unsqueeze(1)  # Add sequence dimension
 
         try:
-            q_values, extrema_pred, price_pred, cnn_features, advanced_pred = self.enhanced_cnn(cnn_input)
+            q_values, extrema_pred, price_pred, cnn_features, advanced_pred, multi_timeframe_pred = self.enhanced_cnn(cnn_input)
         except Exception as e:
             logger.warning(f"Enhanced CNN forward pass failed: {e}, using fallback")
             # Fallback to direct processing
@@ -4841,6 +4841,7 @@ class TradingOrchestrator:
                 price_pred,
                 features_refined,
                 advanced_pred,
+                multi_timeframe_pred,
             ) = self.cnn_model(features_tensor)
 
             # Convert to probabilities using softmax
@@ -14,7 +14,7 @@
     },
     "decision_fusion": {
         "inference_enabled": false,
-        "training_enabled": true
+        "training_enabled": false
     },
     "transformer": {
         "inference_enabled": false,
@@ -25,5 +25,5 @@
         "training_enabled": true
     }
   },
-  "timestamp": "2025-07-29T23:33:51.882579"
+  "timestamp": "2025-07-30T00:17:57.738273"
 }
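The config hunk turns off decision-fusion training (inference was already off) and refreshes the timestamp. A hedged sketch of how such a toggle file might be read; the filename and any wrapper key above "decision_fusion" are assumptions, since the hunk only shows the inner keys:

```python
import json

def decision_fusion_toggles(path):
    """Return (inference_enabled, training_enabled) from a toggle-state file."""
    with open(path) as f:
        state = json.load(f)
    # Look one level deep in case the section sits under a wrapper key.
    section = state.get("decision_fusion") or next(
        v["decision_fusion"]
        for v in state.values()
        if isinstance(v, dict) and "decision_fusion" in v
    )
    return section["inference_enabled"], section["training_enabled"]
```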
@@ -7266,7 +7266,7 @@ class CleanTradingDashboard:
 
         # Get prediction from CNN model
         with torch.no_grad():
-            q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.cnn_adapter(features_tensor)
+            q_values, extrema_pred, price_pred, features_refined, advanced_pred, multi_timeframe_pred = self.cnn_adapter(features_tensor)
 
         # Convert to probabilities using softmax
         action_probs = torch.softmax(q_values, dim=1)
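The dashboard's post-processing is unchanged apart from the extra return value: Q-values go through a softmax to become action probabilities. A tiny sketch of that conversion with an illustrative confidence readout; the action labels are assumptions:

```python
import torch

q_values = torch.tensor([[0.2, 1.5, -0.3]])   # one sample, three actions
action_probs = torch.softmax(q_values, dim=1)
confidence, action = action_probs.max(dim=1)
print(action.item(), confidence.item())       # index 1 wins with ~0.70 weight
```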