predict price direction
This commit is contained in:
@ -265,8 +265,9 @@ class EnhancedCNN(nn.Module):
|
||||
nn.Linear(256, 3) # 0=bottom, 1=top, 2=neither
|
||||
)
|
||||
|
||||
# ULTRA MASSIVE multi-timeframe price prediction heads
|
||||
self.price_pred_immediate = nn.Sequential(
|
||||
# ULTRA MASSIVE price direction prediction head
|
||||
# Outputs single direction and confidence values
|
||||
self.price_direction_head = nn.Sequential(
|
||||
nn.Linear(1024, 1024), # Increased from 512
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
@ -275,32 +276,13 @@ class EnhancedCNN(nn.Module):
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(512, 256), # Increased from 128
|
||||
nn.ReLU(),
|
||||
nn.Linear(256, 3) # Up, Down, Sideways
|
||||
nn.Linear(256, 2) # [direction, confidence]
|
||||
)
|
||||
|
||||
self.price_pred_midterm = nn.Sequential(
|
||||
nn.Linear(1024, 1024), # Increased from 512
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(1024, 512), # Increased from 256
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(512, 256), # Increased from 128
|
||||
nn.ReLU(),
|
||||
nn.Linear(256, 3) # Up, Down, Sideways
|
||||
)
|
||||
|
||||
self.price_pred_longterm = nn.Sequential(
|
||||
nn.Linear(1024, 1024), # Increased from 512
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(1024, 512), # Increased from 256
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(512, 256), # Increased from 128
|
||||
nn.ReLU(),
|
||||
nn.Linear(256, 3) # Up, Down, Sideways
|
||||
)
|
||||
# Direction activation (tanh for -1 to 1)
|
||||
self.direction_activation = nn.Tanh()
|
||||
# Confidence activation (sigmoid for 0 to 1)
|
||||
self.confidence_activation = nn.Sigmoid()
|
||||
|
||||
# ULTRA MASSIVE value prediction with ensemble approaches
|
||||
self.price_pred_value = nn.Sequential(
|
||||
@ -490,10 +472,14 @@ class EnhancedCNN(nn.Module):
|
||||
# Extrema predictions (bottom/top/neither detection)
|
||||
extrema_pred = self.extrema_head(features_refined)
|
||||
|
||||
# Multi-timeframe price movement predictions
|
||||
price_immediate = self.price_pred_immediate(features_refined)
|
||||
price_midterm = self.price_pred_midterm(features_refined)
|
||||
price_longterm = self.price_pred_longterm(features_refined)
|
||||
# Price direction predictions
|
||||
price_direction_raw = self.price_direction_head(features_refined)
|
||||
|
||||
# Apply separate activations to direction and confidence
|
||||
direction = self.direction_activation(price_direction_raw[:, 0:1]) # -1 to 1
|
||||
confidence = self.confidence_activation(price_direction_raw[:, 1:2]) # 0 to 1
|
||||
price_direction_pred = torch.cat([direction, confidence], dim=1) # [batch, 2]
|
||||
|
||||
price_values = self.price_pred_value(features_refined)
|
||||
|
||||
# Additional specialized predictions for enhanced accuracy
|
||||
@ -502,15 +488,14 @@ class EnhancedCNN(nn.Module):
|
||||
market_regime_pred = self.market_regime_head(features_refined)
|
||||
risk_pred = self.risk_head(features_refined)
|
||||
|
||||
# Package all price predictions into a single tensor (use immediate as primary)
|
||||
# For compatibility with DQN agent, we return price_immediate as the price prediction tensor
|
||||
price_pred_tensor = price_immediate
|
||||
# Use the price direction prediction directly (already [batch, 2])
|
||||
price_direction_tensor = price_direction_pred
|
||||
|
||||
# Package additional predictions into a single tensor (use volatility as primary)
|
||||
# For compatibility with DQN agent, we return volatility_pred as the advanced prediction tensor
|
||||
advanced_pred_tensor = volatility_pred
|
||||
|
||||
return q_values, extrema_pred, price_pred_tensor, features_refined, advanced_pred_tensor
|
||||
return q_values, extrema_pred, price_direction_tensor, features_refined, advanced_pred_tensor
|
||||
|
||||
def act(self, state, explore=True) -> Tuple[int, float, List[float]]:
|
||||
"""Enhanced action selection with ultra massive model predictions"""
|
||||
@ -528,7 +513,11 @@ class EnhancedCNN(nn.Module):
|
||||
state_tensor = state_tensor.unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)
|
||||
q_values, extrema_pred, price_direction_predictions, features, advanced_predictions = self(state_tensor)
|
||||
|
||||
# Process price direction predictions
|
||||
if price_direction_predictions is not None:
|
||||
self.process_price_direction_predictions(price_direction_predictions)
|
||||
|
||||
# Apply softmax to get action probabilities
|
||||
action_probs_tensor = torch.softmax(q_values, dim=1)
|
||||
@ -565,6 +554,100 @@ class EnhancedCNN(nn.Module):
|
||||
logger.info(f" Risk Level: {risk_labels[risk_class]} ({risk[risk_class]:.3f})")
|
||||
|
||||
return action_idx, confidence, action_probs
|
||||
|
||||
def process_price_direction_predictions(self, price_direction_pred: torch.Tensor) -> Dict[str, float]:
|
||||
"""
|
||||
Process price direction predictions and convert to standardized format
|
||||
|
||||
Args:
|
||||
price_direction_pred: Tensor of shape (batch_size, 2) containing [direction, confidence]
|
||||
|
||||
Returns:
|
||||
Dict with direction (-1 to 1) and confidence (0 to 1)
|
||||
"""
|
||||
try:
|
||||
if price_direction_pred is None or price_direction_pred.numel() == 0:
|
||||
return {}
|
||||
|
||||
# Extract direction and confidence values
|
||||
direction_value = float(price_direction_pred[0, 0].item()) # -1 to 1
|
||||
confidence_value = float(price_direction_pred[0, 1].item()) # 0 to 1
|
||||
|
||||
processed_directions = {
|
||||
'direction': direction_value,
|
||||
'confidence': confidence_value
|
||||
}
|
||||
|
||||
# Store for later access
|
||||
self.last_price_direction = processed_directions
|
||||
|
||||
return processed_directions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing price direction predictions: {e}")
|
||||
return {}
|
||||
|
||||
def get_price_direction_vector(self) -> Dict[str, float]:
|
||||
"""
|
||||
Get the current price direction and confidence
|
||||
|
||||
Returns:
|
||||
Dict with direction (-1 to 1) and confidence (0 to 1)
|
||||
"""
|
||||
return getattr(self, 'last_price_direction', {})
|
||||
|
||||
def get_price_direction_summary(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get a summary of price direction prediction
|
||||
|
||||
Returns:
|
||||
Dict containing direction and confidence information
|
||||
"""
|
||||
try:
|
||||
last_direction = getattr(self, 'last_price_direction', {})
|
||||
if not last_direction:
|
||||
return {
|
||||
'direction_value': 0.0,
|
||||
'confidence_value': 0.0,
|
||||
'direction_label': "SIDEWAYS",
|
||||
'discrete_direction': 0,
|
||||
'strength': 0.0,
|
||||
'weighted_strength': 0.0
|
||||
}
|
||||
|
||||
direction_value = last_direction['direction']
|
||||
confidence_value = last_direction['confidence']
|
||||
|
||||
# Convert to discrete direction
|
||||
if direction_value > 0.1:
|
||||
direction_label = "UP"
|
||||
discrete_direction = 1
|
||||
elif direction_value < -0.1:
|
||||
direction_label = "DOWN"
|
||||
discrete_direction = -1
|
||||
else:
|
||||
direction_label = "SIDEWAYS"
|
||||
discrete_direction = 0
|
||||
|
||||
return {
|
||||
'direction_value': float(direction_value),
|
||||
'confidence_value': float(confidence_value),
|
||||
'direction_label': direction_label,
|
||||
'discrete_direction': discrete_direction,
|
||||
'strength': abs(float(direction_value)),
|
||||
'weighted_strength': abs(float(direction_value)) * float(confidence_value)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating price direction summary: {e}")
|
||||
return {
|
||||
'direction_value': 0.0,
|
||||
'confidence_value': 0.0,
|
||||
'direction_label': "SIDEWAYS",
|
||||
'discrete_direction': 0,
|
||||
'strength': 0.0,
|
||||
'weighted_strength': 0.0
|
||||
}
|
||||
|
||||
def save(self, path):
|
||||
"""Save model weights and architecture"""
|
||||
|
Reference in New Issue
Block a user