initial model changes to fix performance
This commit is contained in:
@ -74,6 +74,107 @@ class AdaptiveNorm(nn.Module):
|
||||
self.layer_norm_1d = nn.LayerNorm([channels, seq_len]).to(x.device)
|
||||
return self.layer_norm_1d(x)
|
||||
|
||||
class SimpleCNN(nn.Module):
    """
    Simple dueling Q-network for reinforcement learning on image-like states.

    Despite the name, the model contains no convolutional layers: every input
    is flattened to ``[batch_size, features]`` and passed through a small
    fully connected trunk that feeds three linear heads (advantage, value and
    extrema classification).
    """

    def __init__(self, input_shape, n_actions):
        """
        Build the network and move it to CUDA when available, otherwise CPU.

        Args:
            input_shape: Sequence of length 1 (``[features]``), 2
                (``[timeframes, features]``) or 3
                (``[channels, height, width]``).
            n_actions: Number of discrete actions (width of the Q-value
                output).

        Raises:
            ValueError: If ``input_shape`` does not have 1, 2 or 3 entries.
        """
        super().__init__()

        # Store dimensions
        self.input_shape = input_shape
        self.n_actions = n_actions

        # Derive the initial width of the first linear layer.
        # NOTE(review): for 2- and 3-entry shapes this width excludes the
        # channel/timeframe axis, while forward() flattens that axis into the
        # features. The first forward pass on such data therefore rebuilds
        # the layers. Kept as is so existing callers/checkpoints are
        # unaffected -- confirm whether this is intended.
        if len(input_shape) == 3:  # [channels, height, width]
            self.channels, self.height, self.width = input_shape
            self.feature_dim = self.height * self.width
        elif len(input_shape) == 2:  # [timeframes, features]
            self.channels = input_shape[0]
            self.features = input_shape[1]
            self.feature_dim = self.features
        elif len(input_shape) == 1:  # [features]
            self.channels = 1
            self.features = input_shape[0]
            self.feature_dim = self.features
        else:
            raise ValueError(f"Unsupported input shape: {input_shape}")

        # Build network
        self._build_network()

        # Initialize device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

        logger.info(f"SimpleCNN initialized with input shape: {input_shape}, actions: {n_actions}")

    def _build_network(self):
        """Create fresh layers sized for the current ``self.feature_dim``."""
        # Shared trunk: two 256-unit fully connected layers with ReLU.
        self.fc_layers = nn.Sequential(
            nn.Linear(self.feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )

        # Output heads (Dueling DQN architecture)
        self.advantage_head = nn.Linear(256, self.n_actions)
        self.value_head = nn.Linear(256, 1)

        # Extrema detection head
        self.extrema_head = nn.Linear(256, 3)  # 0=bottom, 1=top, 2=neither

    def _check_rebuild_network(self, features):
        """
        Rebuild the layers when the incoming feature width has changed.

        A rebuild replaces the trunk and all heads with newly constructed
        modules, so previously learned weights are dropped.
        NOTE(review): an optimizer created before the rebuild presumably
        still references the old parameters -- verify against the trainer.

        Args:
            features: Flattened per-sample feature count of the input.

        Returns:
            True if the network was rebuilt, False otherwise.
        """
        if features != self.feature_dim:
            logger.info(f"Rebuilding network for new feature dimension: {features} (was {self.feature_dim})")
            self.feature_dim = features
            self._build_network()
            # Move to device after rebuilding
            self.to(self.device)
            return True
        return False

    def forward(self, x):
        """
        Forward pass through the network.

        Accepted input layouts:
            * ``[features]`` -- treated as a batch of one.
            * ``[batch_size, features]``
            * ``[batch_size, timeframes/channels, features]``
            * ``[batch_size, channels, height, width]`` (or any higher rank)

        Everything after the batch axis is flattened into one feature axis.

        Args:
            x: Input tensor in one of the layouts above.

        Returns:
            Tuple ``(action_values, extrema_pred)`` with shapes
            ``[batch_size, n_actions]`` and ``[batch_size, 3]``.

        Raises:
            ValueError: If ``x`` is a 0-dimensional tensor.
        """
        if x.dim() == 0:
            raise ValueError(f"Unsupported input tensor shape: {tuple(x.shape)}")

        if x.dim() == 1:
            # Unbatched feature vector: add the batch axis so the dim=1
            # reduction in the dueling combination below is valid.
            x = x.unsqueeze(0)
        elif x.dim() > 2:
            # Collapse timeframes/channels/spatial axes into the feature
            # axis. Previously only 3-D tensors were flattened, so 4-D image
            # batches reached nn.Linear with the wrong trailing dimension.
            x = torch.flatten(x, start_dim=1)

        # Check if we need to rebuild the network for new dimensions
        self._check_rebuild_network(x.shape[1])

        # Apply fully connected layers
        fc_out = self.fc_layers(x)

        # Dueling architecture
        advantage = self.advantage_head(fc_out)
        value = self.value_head(fc_out)

        # Q-values = value + (advantage - mean(advantage))
        action_values = value + advantage - advantage.mean(dim=1, keepdim=True)

        # Extrema predictions
        extrema_pred = self.extrema_head(fc_out)

        return action_values, extrema_pred
|
||||
|
||||
class CNNModelPyTorch(nn.Module):
|
||||
"""
|
||||
CNN model for trading with multiple timeframes
|
||||
|
Reference in New Issue
Block a user