misc
This commit is contained in:
@ -78,17 +78,25 @@ class CNNPyTorch(nn.Module):
|
||||
window_size, num_features = input_shape
|
||||
self.window_size = window_size
|
||||
|
||||
# Simpler architecture with fewer layers and dropout
|
||||
# Increased complexity
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv1d(num_features, 32, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(32),
|
||||
nn.Conv1d(num_features, 64, kernel_size=3, padding=1), # Increased filters
|
||||
nn.BatchNorm1d(64),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2)
|
||||
)
|
||||
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv1d(32, 64, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(64),
|
||||
nn.Conv1d(64, 128, kernel_size=3, padding=1), # Increased filters
|
||||
nn.BatchNorm1d(128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2)
|
||||
)
|
||||
|
||||
# Added third conv layer
|
||||
self.conv3 = nn.Sequential(
|
||||
nn.Conv1d(128, 128, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2)
|
||||
)
|
||||
@ -96,12 +104,12 @@ class CNNPyTorch(nn.Module):
|
||||
# Global average pooling to handle variable length sequences
|
||||
self.global_pool = nn.AdaptiveAvgPool1d(1)
|
||||
|
||||
# Fully connected layers
|
||||
# Fully connected layers (updated input size and hidden size)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(64, 32),
|
||||
nn.Linear(128, 64), # Updated input size from conv3, increased hidden size
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(32, output_size)
|
||||
nn.Linear(64, output_size)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -120,10 +128,11 @@ class CNNPyTorch(nn.Module):
|
||||
# Convolutional layers
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x) # Added conv3 pass
|
||||
|
||||
# Global pooling
|
||||
x = self.global_pool(x)
|
||||
x = x.squeeze(-1)
|
||||
x = x.squeeze(-1) # Shape becomes [batch, 128]
|
||||
|
||||
# Fully connected layers
|
||||
action_logits = self.fc(x)
|
||||
@ -216,6 +225,8 @@ class CNNModelPyTorch:
|
||||
self.last_actions = [[] for _ in range(num_pairs)] # Track recent actions per pair
|
||||
|
||||
def train_epoch(self, X_train, y_train, future_prices, batch_size):
|
||||
# Add a call to predict_extrema here
|
||||
self.predict_extrema(X_train)
|
||||
"""Train the model for one epoch with focus on short-term pattern recognition"""
|
||||
self.model.train()
|
||||
total_loss = 0
|
||||
@ -321,7 +332,8 @@ class CNNModelPyTorch:
|
||||
|
||||
return avg_loss, 0, accuracy # Return 0 for price_loss as we're not using it
|
||||
|
||||
def predict(self, X):
|
||||
def predict_extrema(self, X):
|
||||
# Predict local extrema (lows and highs) based on input data
|
||||
"""Make predictions optimized for short-term high-leverage trading signals"""
|
||||
self.model.eval()
|
||||
|
||||
|
Reference in New Issue
Block a user