added more predictions
This commit is contained in:
@ -125,6 +125,14 @@ class SimpleCNN(nn.Module):
|
||||
|
||||
# Extrema detection head
|
||||
self.extrema_head = nn.Linear(256, 3) # 0=bottom, 1=top, 2=neither
|
||||
|
||||
# Price prediction heads for different timeframes
|
||||
self.price_pred_immediate = nn.Linear(256, 3) # Up, Down, Sideways for immediate term (1s, 1m)
|
||||
self.price_pred_midterm = nn.Linear(256, 3) # Up, Down, Sideways for mid-term (1h)
|
||||
self.price_pred_longterm = nn.Linear(256, 3) # Up, Down, Sideways for long-term (1d)
|
||||
|
||||
# Regression heads for exact price prediction
|
||||
self.price_pred_value = nn.Linear(256, 4) # Predicts % change for each timeframe (1s, 1m, 1h, 1d)
|
||||
|
||||
def _check_rebuild_network(self, features):
|
||||
"""Check if network needs to be rebuilt for different feature dimensions"""
|
||||
@ -140,7 +148,7 @@ class SimpleCNN(nn.Module):
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the network
|
||||
Returns both action values and extrema predictions
|
||||
Returns action values, extrema predictions, and price movement predictions for multiple timeframes
|
||||
"""
|
||||
# Handle different input shapes
|
||||
if len(x.shape) == 2: # [batch_size, features]
|
||||
@ -173,7 +181,50 @@ class SimpleCNN(nn.Module):
|
||||
# Extrema predictions
|
||||
extrema_pred = self.extrema_head(fc_out)
|
||||
|
||||
return action_values, extrema_pred
|
||||
# Price movement predictions for different timeframes
|
||||
price_immediate = self.price_pred_immediate(fc_out) # 1s, 1m
|
||||
price_midterm = self.price_pred_midterm(fc_out) # 1h
|
||||
price_longterm = self.price_pred_longterm(fc_out) # 1d
|
||||
|
||||
# Regression values for exact price predictions (percentage changes)
|
||||
price_values = self.price_pred_value(fc_out)
|
||||
|
||||
# Return all predictions in a structured dictionary
|
||||
price_predictions = {
|
||||
'immediate': price_immediate,
|
||||
'midterm': price_midterm,
|
||||
'longterm': price_longterm,
|
||||
'values': price_values
|
||||
}
|
||||
|
||||
return action_values, extrema_pred, price_predictions
|
||||
|
||||
def save(self, path):
|
||||
"""Save model weights and architecture"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
torch.save({
|
||||
'state_dict': self.state_dict(),
|
||||
'input_shape': self.input_shape,
|
||||
'n_actions': self.n_actions,
|
||||
'feature_dim': self.feature_dim
|
||||
}, f"{path}.pt")
|
||||
logger.info(f"Model saved to {path}.pt")
|
||||
|
||||
def load(self, path):
|
||||
"""Load model weights and architecture"""
|
||||
try:
|
||||
checkpoint = torch.load(f"{path}.pt", map_location=self.device)
|
||||
self.input_shape = checkpoint['input_shape']
|
||||
self.n_actions = checkpoint['n_actions']
|
||||
self.feature_dim = checkpoint['feature_dim']
|
||||
self._build_network()
|
||||
self.load_state_dict(checkpoint['state_dict'])
|
||||
self.to(self.device)
|
||||
logger.info(f"Model loaded from {path}.pt")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model: {str(e)}")
|
||||
return False
|
||||
|
||||
class CNNModelPyTorch(nn.Module):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user