added more predictions

This commit is contained in:
Dobromir Popov
2025-04-02 14:20:39 +03:00
parent 70eb7bba9b
commit 7dda00b64a
3 changed files with 967 additions and 143 deletions

View File

@ -125,6 +125,14 @@ class SimpleCNN(nn.Module):
# Extrema detection head
self.extrema_head = nn.Linear(256, 3) # 0=bottom, 1=top, 2=neither
# Price prediction heads for different timeframes
self.price_pred_immediate = nn.Linear(256, 3) # Up, Down, Sideways for immediate term (1s, 1m)
self.price_pred_midterm = nn.Linear(256, 3) # Up, Down, Sideways for mid-term (1h)
self.price_pred_longterm = nn.Linear(256, 3) # Up, Down, Sideways for long-term (1d)
# Regression heads for exact price prediction
self.price_pred_value = nn.Linear(256, 4) # Predicts % change for each timeframe (1s, 1m, 1h, 1d)
def _check_rebuild_network(self, features):
"""Check if network needs to be rebuilt for different feature dimensions"""
@ -140,7 +148,7 @@ class SimpleCNN(nn.Module):
def forward(self, x):
"""
Forward pass through the network
Returns both action values and extrema predictions
Returns action values, extrema predictions, and price movement predictions for multiple timeframes
"""
# Handle different input shapes
if len(x.shape) == 2: # [batch_size, features]
@ -173,7 +181,50 @@ class SimpleCNN(nn.Module):
# Extrema predictions
extrema_pred = self.extrema_head(fc_out)
return action_values, extrema_pred
# Price movement predictions for different timeframes
price_immediate = self.price_pred_immediate(fc_out) # 1s, 1m
price_midterm = self.price_pred_midterm(fc_out) # 1h
price_longterm = self.price_pred_longterm(fc_out) # 1d
# Regression values for exact price predictions (percentage changes)
price_values = self.price_pred_value(fc_out)
# Return all predictions in a structured dictionary
price_predictions = {
'immediate': price_immediate,
'midterm': price_midterm,
'longterm': price_longterm,
'values': price_values
}
return action_values, extrema_pred, price_predictions
def save(self, path):
"""Save model weights and architecture"""
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save({
'state_dict': self.state_dict(),
'input_shape': self.input_shape,
'n_actions': self.n_actions,
'feature_dim': self.feature_dim
}, f"{path}.pt")
logger.info(f"Model saved to {path}.pt")
def load(self, path):
"""Load model weights and architecture"""
try:
checkpoint = torch.load(f"{path}.pt", map_location=self.device)
self.input_shape = checkpoint['input_shape']
self.n_actions = checkpoint['n_actions']
self.feature_dim = checkpoint['feature_dim']
self._build_network()
self.load_state_dict(checkpoint['state_dict'])
self.to(self.device)
logger.info(f"Model loaded from {path}.pt")
return True
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
return False
class CNNModelPyTorch(nn.Module):
"""