wip
This commit is contained in:
@ -263,8 +263,12 @@ class CNNModelPyTorch:
|
||||
local_min_mask = (price_changes < -0.001).to(torch.bool) # -0.1% threshold for local minimum
|
||||
|
||||
# Apply masks to set retrospective targets using torch.where
|
||||
retrospective_targets = torch.where(local_max_mask, torch.zeros_like(targets), retrospective_targets) # SELL at local max
|
||||
retrospective_targets = torch.where(local_min_mask, 2 * torch.ones_like(targets), retrospective_targets) # BUY at local min
|
||||
# Use indices where the masks have True values
|
||||
for i in range(len(retrospective_targets)):
|
||||
if local_max_mask[i]:
|
||||
retrospective_targets[i] = 0 # SELL at local max
|
||||
elif local_min_mask[i]:
|
||||
retrospective_targets[i] = 2 # BUY at local min
|
||||
|
||||
# Calculate retrospective loss with higher weight for profitable signals
|
||||
retrospective_loss = self.criterion(outputs, retrospective_targets)
|
||||
@ -372,8 +376,12 @@ class CNNModelPyTorch:
|
||||
local_min_mask = (price_changes < -0.001).to(torch.bool) # -0.1% threshold for local minimum
|
||||
|
||||
# Apply masks to set retrospective targets using torch.where
|
||||
retrospective_targets = torch.where(local_max_mask, torch.zeros_like(targets), retrospective_targets) # SELL at local max
|
||||
retrospective_targets = torch.where(local_min_mask, 2 * torch.ones_like(targets), retrospective_targets) # BUY at local min
|
||||
# Use indices where the masks have True values
|
||||
for i in range(len(retrospective_targets)):
|
||||
if local_max_mask[i]:
|
||||
retrospective_targets[i] = 0 # SELL at local max
|
||||
elif local_min_mask[i]:
|
||||
retrospective_targets[i] = 2 # BUY at local min
|
||||
|
||||
# Calculate retrospective loss with higher weight for profitable signals
|
||||
retrospective_loss = self.criterion(outputs, retrospective_targets)
|
||||
@ -403,13 +411,29 @@ class CNNModelPyTorch:
|
||||
def predict(self, X):
|
||||
"""Make predictions on input data"""
|
||||
self.model.eval()
|
||||
X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
|
||||
|
||||
# Convert to tensor if not already
|
||||
if not isinstance(X, torch.Tensor):
|
||||
X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
|
||||
else:
|
||||
X_tensor = X.to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.model(X_tensor)
|
||||
# To maintain compatibility with the transformer model, return the action probs
|
||||
# And a dummy price prediction of zeros
|
||||
return outputs.cpu().numpy(), np.zeros((len(X), 1))
|
||||
|
||||
# Get the current close prices from the input
|
||||
current_prices = X_tensor[:, -1, 3].cpu().numpy() # Last timestamp's close price
|
||||
|
||||
# For price predictions, we'll estimate based on the action probabilities
|
||||
# Buy (2) means price likely to go up, Sell (0) means price likely to go down
|
||||
action_probs = outputs.cpu().numpy()
|
||||
price_directions = np.argmax(action_probs, axis=1) - 1 # -1, 0, or 1
|
||||
|
||||
# Simple price prediction: current price + small change based on predicted direction
|
||||
# Use 0.001 (0.1%) as a baseline change
|
||||
price_preds = current_prices * (1 + price_directions * 0.001)
|
||||
|
||||
return action_probs, price_preds.reshape(-1, 1)
|
||||
|
||||
def predict_next_candles(self, X, n_candles=3):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user