70 lines
2.2 KiB
Python
70 lines
2.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
import os
|
|
import logging
|
|
|
|
# Module-wide logging: send INFO-and-above records through the root logger,
# then grab the logger named after this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
class SimpleMLP(nn.Module):
    """
    Simple multi-layer perceptron for reinforcement learning with vector state inputs.

    Uses a dueling architecture (separate state-value and advantage streams)
    for Q-learning, plus an auxiliary 3-way head for extrema detection
    (0=bottom, 1=top, 2=neither).

    Args:
        state_dim: State size. Either an int, or a tuple/list shape whose
            elements are multiplied together to give the flattened input size.
        n_actions: Number of discrete actions (width of the Q-value output).
        hidden_dim: Width of both hidden layers. Defaults to 256.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=256):
        super().__init__()

        # Store dimensions
        self.state_dim = state_dim
        self.n_actions = n_actions

        # Flattened input size. Lists are accepted as shapes as well as tuples
        # (torch.Size is a tuple subclass, so it is covered too).
        if isinstance(state_dim, (tuple, list)):
            self.input_size = int(np.prod(state_dim))
        else:
            self.input_size = int(state_dim)

        # Hidden layers
        self.fc1 = nn.Linear(self.input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

        # Dueling architecture: per-action advantage stream and scalar value stream
        self.advantage = nn.Linear(hidden_dim, n_actions)
        self.value = nn.Linear(hidden_dim, 1)

        # Extrema detection
        self.extrema_head = nn.Linear(hidden_dim, 3)  # 0=bottom, 1=top, 2=neither

        # Move to appropriate device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

        # Same logger object as the module-level `logger` (getLogger is keyed
        # by name); fetched here so the class does not rely on that global.
        logging.getLogger(__name__).info(
            "SimpleMLP initialized with input size: %s, actions: %s",
            self.input_size,
            n_actions,
        )

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x: A batch of states (tensor, NumPy array, or nested sequence).
                It is converted to the dtype and device of the network weights.
                A single unbatched state is given a batch dimension of 1.

        Returns:
            Tuple ``(q_values, extrema)``. For a 2-D batch of B states these
            are shaped ``(B, n_actions)`` and ``(B, 3)``. ``extrema`` holds
            softmax probabilities over the last axis.
        """
        # Match the weights' dtype/device. Read them from the parameters, not
        # self.device, so a later .to()/.double() on the module is respected.
        weight = self.fc1.weight
        x = torch.as_tensor(x, dtype=weight.dtype, device=weight.device)

        # Handle different input shapes
        if isinstance(self.state_dim, (tuple, list)) and len(self.state_dim) > 1:
            # reshape (unlike view) also accepts non-contiguous inputs
            x = x.reshape(-1, self.input_size)
        elif x.dim() == 1:
            # A single unbatched vector state: add the batch dimension.
            x = x.unsqueeze(0)

        # Main network
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        # Dueling architecture
        advantage = self.advantage(x)
        value = self.value(x)

        # Combine value and advantage (Q = V + A - mean(A)). Reduce over the
        # action axis (last), which stays correct for extra leading dims.
        q_values = value + advantage - advantage.mean(dim=-1, keepdim=True)

        # Extrema predictions (probabilities over the 3 classes)
        extrema = F.softmax(self.extrema_head(x), dim=-1)

        return q_values, extrema