gogo2/NN/models/simple_mlp.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import logging
# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SimpleMLP(nn.Module):
    """
    Simple multi-layer perceptron for reinforcement learning with vector state inputs.
    Implements a dueling architecture for more stable Q-learning.
    """

    def __init__(self, state_dim, n_actions):
        super().__init__()

        # Store dimensions
        self.state_dim = state_dim
        self.n_actions = n_actions

        # Calculate the flattened input size
        if isinstance(state_dim, tuple):
            self.input_size = int(np.prod(state_dim))
        else:
            self.input_size = state_dim

        # Hidden layers
        self.fc1 = nn.Linear(self.input_size, 256)
        self.fc2 = nn.Linear(256, 256)

        # Dueling architecture: separate advantage and value streams
        self.advantage = nn.Linear(256, n_actions)
        self.value = nn.Linear(256, 1)

        # Extrema detection head: 0=bottom, 1=top, 2=neither
        self.extrema_head = nn.Linear(256, 3)

        # Move to the appropriate device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

        logger.info(f"SimpleMLP initialized with input size: {self.input_size}, actions: {n_actions}")

    def forward(self, x):
        """
        Forward pass through the network.
        Returns both action values (Q-values) and extrema predictions.
        """
        # Flatten multi-dimensional state inputs into (batch, input_size);
        # reshape (rather than view) also handles non-contiguous tensors
        if isinstance(self.state_dim, tuple) and len(self.state_dim) > 1:
            x = x.reshape(-1, self.input_size)

        # Main network
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        # Dueling architecture: separate value and advantage streams
        advantage = self.advantage(x)
        value = self.value(x)

        # Combine as Q = V + A - mean(A); subtracting the mean advantage
        # keeps the value/advantage decomposition identifiable
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)

        # Extrema predictions as class probabilities
        extrema = F.softmax(self.extrema_head(x), dim=1)

        return q_values, extrema
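

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): exercises a
    # forward pass with a hypothetical 10-dimensional state and 3 actions,
    # then picks greedy actions from the dueling Q-values.
    model = SimpleMLP(state_dim=10, n_actions=3)

    # Dummy batch of 4 states, created on the model's device
    states = torch.randn(4, 10, device=model.device)

    q_values, extrema = model(states)
    print(q_values.shape)  # torch.Size([4, 3])
    print(extrema.shape)   # torch.Size([4, 3])

    # Greedy action selection from the Q-values
    actions = q_values.argmax(dim=1)
    print(actions)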