import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SimpleMLP(nn.Module):
    """
    Simple Multi-Layer Perceptron for reinforcement learning with vector state inputs.
    Implements a dueling architecture for better Q-learning.
    """

    def __init__(self, state_dim, n_actions):
        super(SimpleMLP, self).__init__()

        # Store dimensions
        self.state_dim = state_dim
        self.n_actions = n_actions

        # Calculate flattened input size (state_dim may be a tuple or an int)
        if isinstance(state_dim, tuple):
            self.input_size = int(np.prod(state_dim))
        else:
            self.input_size = state_dim

        # Hidden layers
        self.fc1 = nn.Linear(self.input_size, 256)
        self.fc2 = nn.Linear(256, 256)

        # Dueling architecture: separate advantage and state-value streams
        self.advantage = nn.Linear(256, n_actions)
        self.value = nn.Linear(256, 1)

        # Extrema detection head: 0=bottom, 1=top, 2=neither
        self.extrema_head = nn.Linear(256, 3)

        # Move to appropriate device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

        logger.info(f"SimpleMLP initialized with input size: {self.input_size}, actions: {n_actions}")

    def forward(self, x):
        """
        Forward pass through the network.
        Returns both action values and extrema predictions.
        """
        # Flatten multi-dimensional state inputs
        if isinstance(self.state_dim, tuple) and len(self.state_dim) > 1:
            x = x.view(-1, self.input_size)

        # Main network
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        # Dueling architecture
        advantage = self.advantage(x)
        value = self.value(x)

        # Combine value and advantage (Q = V + A - mean(A))
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)

        # Extrema predictions
        extrema = F.softmax(self.extrema_head(x), dim=1)

        return q_values, extrema
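

# A minimal usage sketch (assumed example only: the 10-dimensional state space,
# 4 discrete actions, and batch size below are illustrative and not part of the
# original module). Running this file directly performs a quick forward-pass
# sanity check on a batch of random states.
if __name__ == "__main__":
    state_dim = 10   # assumed example state dimensionality
    n_actions = 4    # assumed example action count

    model = SimpleMLP(state_dim, n_actions)

    # Batch of 8 random states, created on the same device as the model
    states = torch.randn(8, state_dim, device=model.device)

    with torch.no_grad():
        q_values, extrema = model(states)

    logger.info(f"Q-values shape: {tuple(q_values.shape)}")      # expected: (8, 4)
    logger.info(f"Extrema probs shape: {tuple(extrema.shape)}")  # expected: (8, 3)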