This commit is contained in:
Dobromir Popov
2025-03-31 02:22:51 +03:00
parent 1b9f471076
commit 8981ad0691
5 changed files with 124 additions and 147 deletions

View File

@ -263,8 +263,12 @@ class CNNModelPyTorch:
local_min_mask = (price_changes < -0.001).to(torch.bool) # -0.1% threshold for local minimum
# Apply masks to set retrospective targets using torch.where
retrospective_targets = torch.where(local_max_mask, torch.zeros_like(targets), retrospective_targets) # SELL at local max
retrospective_targets = torch.where(local_min_mask, 2 * torch.ones_like(targets), retrospective_targets) # BUY at local min
# Use indices where the masks have True values
for i in range(len(retrospective_targets)):
if local_max_mask[i]:
retrospective_targets[i] = 0 # SELL at local max
elif local_min_mask[i]:
retrospective_targets[i] = 2 # BUY at local min
# Calculate retrospective loss with higher weight for profitable signals
retrospective_loss = self.criterion(outputs, retrospective_targets)
@ -372,8 +376,12 @@ class CNNModelPyTorch:
local_min_mask = (price_changes < -0.001).to(torch.bool) # -0.1% threshold for local minimum
# Apply masks to set retrospective targets using torch.where
retrospective_targets = torch.where(local_max_mask, torch.zeros_like(targets), retrospective_targets) # SELL at local max
retrospective_targets = torch.where(local_min_mask, 2 * torch.ones_like(targets), retrospective_targets) # BUY at local min
# Use indices where the masks have True values
for i in range(len(retrospective_targets)):
if local_max_mask[i]:
retrospective_targets[i] = 0 # SELL at local max
elif local_min_mask[i]:
retrospective_targets[i] = 2 # BUY at local min
# Calculate retrospective loss with higher weight for profitable signals
retrospective_loss = self.criterion(outputs, retrospective_targets)
@ -403,13 +411,29 @@ class CNNModelPyTorch:
def predict(self, X):
"""Make predictions on input data"""
self.model.eval()
X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
# Convert to tensor if not already
if not isinstance(X, torch.Tensor):
X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
else:
X_tensor = X.to(self.device)
with torch.no_grad():
outputs = self.model(X_tensor)
# To maintain compatibility with the transformer model, return the action probs
# And a dummy price prediction of zeros
return outputs.cpu().numpy(), np.zeros((len(X), 1))
# Get the current close prices from the input
current_prices = X_tensor[:, -1, 3].cpu().numpy() # Last timestamp's close price
# For price predictions, we'll estimate based on the action probabilities
# Buy (2) means price likely to go up, Sell (0) means price likely to go down
action_probs = outputs.cpu().numpy()
price_directions = np.argmax(action_probs, axis=1) - 1 # -1, 0, or 1
# Simple price prediction: current price + small change based on predicted direction
# Use 0.001 (0.1%) as a baseline change
price_preds = current_prices * (1 + price_directions * 0.001)
return action_probs, price_preds.reshape(-1, 1)
def predict_next_candles(self, X, n_candles=3):
"""