clean up models; beef up models to 500M parameters

Dobromir Popov
2025-05-24 23:22:34 +03:00
parent 01f0a2608f
commit d418f6ce59
10 changed files with 3918 additions and 730 deletions
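
The "500M" in the commit message presumably refers to the parameter count of the upgraded networks. A quick way to check it for a PyTorch model, using the same counting expression the DQN agent now logs (the state shape and action count below are placeholders, not values from this repo):

from NN.models.enhanced_cnn import EnhancedCNN  # network this commit switches the DQN agent to

model = EnhancedCNN((3, 300), 3)  # placeholder state shape and action count, for illustration only
total_params = sum(p.numel() for p in model.parameters())  # same expression the agent logs below
print(f"EnhancedCNN parameters: {total_params:,}")
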


@@ -13,15 +13,13 @@ import torch.nn.functional as F
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from NN.models.simple_cnn import CNNModelPyTorch
# Configure logger
logger = logging.getLogger(__name__)
class DQNAgent:
"""
Deep Q-Network agent for trading
-Uses CNN model as the base network with GPU support
+Uses Enhanced CNN model as the base network with GPU support for improved performance
"""
def __init__(self,
state_shape: Tuple[int, ...],
@@ -59,23 +57,18 @@ class DQNAgent:
self.batch_size = batch_size
self.target_update = target_update
-# Set device for computation (default to CPU)
+# Set device for computation (default to GPU if available)
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = device
-# Initialize models with appropriate architecture based on state shape
-if isinstance(self.state_dim, tuple) and len(self.state_dim) > 1:
-    # For image-like states (from RL environment with CNN)
-    from NN.models.simple_cnn import SimpleCNN
-    self.policy_net = SimpleCNN(self.state_dim, self.n_actions)
-    self.target_net = SimpleCNN(self.state_dim, self.n_actions)
-else:
-    # For 1D state vectors (most environments)
-    from NN.models.simple_mlp import SimpleMLP
-    self.policy_net = SimpleMLP(self.state_dim, self.n_actions)
-    self.target_net = SimpleMLP(self.state_dim, self.n_actions)
+# Initialize models with Enhanced CNN architecture for better performance
+from NN.models.enhanced_cnn import EnhancedCNN
+# Use Enhanced CNN for both policy and target networks
+self.policy_net = EnhancedCNN(self.state_dim, self.n_actions)
+self.target_net = EnhancedCNN(self.state_dim, self.n_actions)
# Initialize the target network with the same weights as the policy network
self.target_net.load_state_dict(self.policy_net.state_dict())
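
Both networks now use the same Enhanced CNN architecture, and the target network starts as a copy of the policy network; during training it is typically re-synced by hard-copying the weights every target_update steps. A minimal sketch of that pattern (the step-counter attribute name is illustrative, not necessarily the one used in this file):

# inside the training loop, after each optimizer step
self.update_count += 1  # hypothetical step counter
if self.update_count % self.target_update == 0:
    self.target_net.load_state_dict(self.policy_net.state_dict())  # hard target-network update
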
@@ -166,11 +159,15 @@ class DQNAgent:
self.state_size = np.prod(state_shape)
self.action_size = n_actions
self.memory_size = buffer_size
-self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0]] # Default timeframes
+self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] # Default timeframes
-logger.info(f"DQN Agent using device: {self.device}")
+logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
+# Log model parameters
+total_params = sum(p.numel() for p in self.policy_net.parameters())
+logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
def move_models_to_device(self, device=None):
"""Move models to the specified device (GPU/CPU)"""
if device is not None:
@@ -300,7 +297,7 @@ class DQNAgent:
# Get predictions using the policy network
self.policy_net.eval() # Set to evaluation mode for inference
-action_probs, extrema_pred, price_predictions, hidden_features = self.policy_net(state_tensor)
+action_probs, extrema_pred, price_predictions, hidden_features, advanced_predictions = self.policy_net(state_tensor)
self.policy_net.train() # Back to training mode
# Store hidden features for integration
@@ -650,12 +647,12 @@ class DQNAgent:
dones = torch.FloatTensor(np.array(dones)).to(self.device)
# Get current Q values
-current_q_values, current_extrema_pred, current_price_pred, hidden_features = self.policy_net(states)
+current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Get next Q values with target network
with torch.no_grad():
-next_q_values, next_extrema_pred, next_price_pred, next_hidden_features = self.target_net(next_states)
+next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
# Check for dimension mismatch between rewards and next_q_values
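
The replay hunks above and below only change how many outputs are unpacked from the network; the Q-learning update itself is unchanged. For reference, a minimal sketch of the temporal-difference target these lines feed into (generic names, assuming the usual DQN formulation rather than this file's exact variables):

def dqn_td_target(rewards, next_q, dones, gamma=0.99):
    # next_q: per-action Q values from the target network, shape [batch, n_actions]
    max_next_q = next_q.max(1)[0]  # greedy value of the next state, as in the lines above
    return rewards + gamma * max_next_q * (1.0 - dones)  # no bootstrapping on terminal states
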
@@ -727,12 +724,12 @@ class DQNAgent:
# Forward pass with amp autocasting
with torch.cuda.amp.autocast():
# Get current Q values and extrema predictions
-current_q_values, current_extrema_pred, current_price_pred, hidden_features = self.policy_net(states)
+current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Get next Q values from target network
with torch.no_grad():
-next_q_values, next_extrema_pred, next_price_pred, next_hidden_features = self.target_net(next_states)
+next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
# Check for dimension mismatch and fix it
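
Every call site updated in this commit now unpacks five return values instead of four, so the Enhanced CNN forward pass is expected to return one extra head ("advanced predictions"). A minimal stub of that interface, with layer sizes and head dimensions assumed for illustration rather than taken from enhanced_cnn.py:

import math
import torch.nn as nn

class EnhancedCNNStub(nn.Module):
    """Illustrative stand-in exposing the 5-tuple forward interface the agent unpacks."""
    def __init__(self, state_dim, n_actions, hidden=256):
        super().__init__()
        flat = math.prod(state_dim) if isinstance(state_dim, tuple) else int(state_dim)
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(flat, hidden), nn.ReLU())
        self.q_head = nn.Linear(hidden, n_actions)   # Q values / action logits
        self.extrema_head = nn.Linear(hidden, 3)     # extrema predictions (size assumed)
        self.price_head = nn.Linear(hidden, 1)       # price predictions (size assumed)
        self.advanced_head = nn.Linear(hidden, 1)    # new "advanced predictions" head (size assumed)

    def forward(self, x):
        h = self.backbone(x)
        # Same tuple order the agent unpacks: q, extrema, price, hidden features, advanced
        return self.q_head(h), self.extrema_head(h), self.price_head(h), h, self.advanced_head(h)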