cleanup models; beef up models to 500M
@@ -13,15 +13,13 @@ import torch.nn.functional as F
 # Add parent directory to path
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

 from NN.models.simple_cnn import CNNModelPyTorch

 # Configure logger
 logger = logging.getLogger(__name__)

 class DQNAgent:
     """
     Deep Q-Network agent for trading
-    Uses CNN model as the base network with GPU support
+    Uses Enhanced CNN model as the base network with GPU support for improved performance
     """
     def __init__(self,
                  state_shape: Tuple[int, ...],
@@ -59,23 +57,18 @@ class DQNAgent:
         self.batch_size = batch_size
         self.target_update = target_update

-        # Set device for computation (default to CPU)
+        # Set device for computation (default to GPU if available)
         if device is None:
             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         else:
             self.device = device

-        # Initialize models with appropriate architecture based on state shape
-        if isinstance(self.state_dim, tuple) and len(self.state_dim) > 1:
-            # For image-like states (from RL environment with CNN)
-            from NN.models.simple_cnn import SimpleCNN
-            self.policy_net = SimpleCNN(self.state_dim, self.n_actions)
-            self.target_net = SimpleCNN(self.state_dim, self.n_actions)
-        else:
-            # For 1D state vectors (most environments)
-            from NN.models.simple_mlp import SimpleMLP
-            self.policy_net = SimpleMLP(self.state_dim, self.n_actions)
-            self.target_net = SimpleMLP(self.state_dim, self.n_actions)
+        # Initialize models with Enhanced CNN architecture for better performance
+        from NN.models.enhanced_cnn import EnhancedCNN
+
+        # Use Enhanced CNN for both policy and target networks
+        self.policy_net = EnhancedCNN(self.state_dim, self.n_actions)
+        self.target_net = EnhancedCNN(self.state_dim, self.n_actions)

         # Initialize the target network with the same weights as the policy network
         self.target_net.load_state_dict(self.policy_net.state_dict())
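The later hunks unpack five values from `policy_net(...)` and `target_net(...)`, so the `EnhancedCNN` wired in here presumably returns a 5-tuple of `(action_probs, extrema_pred, price_predictions, hidden_features, advanced_predictions)`. A minimal stub sketching that assumed interface; the layer sizes and head dimensions below are illustrative guesses, not the actual `NN.models.enhanced_cnn` implementation:

```python
# Stub of the interface the diff appears to assume for EnhancedCNN: constructed as
# EnhancedCNN(state_dim, n_actions) and returning a 5-tuple from forward().
import numpy as np
import torch
import torch.nn as nn

class EnhancedCNNStub(nn.Module):
    def __init__(self, state_dim, n_actions, hidden_dim=256):
        super().__init__()
        in_features = int(np.prod(state_dim)) if isinstance(state_dim, tuple) else int(state_dim)
        self.backbone = nn.Sequential(
            nn.Linear(in_features, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.q_head = nn.Linear(hidden_dim, n_actions)  # action_probs / Q-values
        self.extrema_head = nn.Linear(hidden_dim, 3)    # extrema_pred (size is a guess)
        self.price_head = nn.Linear(hidden_dim, 1)      # price_predictions (size is a guess)
        self.advanced_head = nn.Linear(hidden_dim, 1)   # advanced_predictions (size is a guess)

    def forward(self, x):
        hidden = self.backbone(x.flatten(start_dim=1))  # hidden_features
        return (self.q_head(hidden), self.extrema_head(hidden),
                self.price_head(hidden), hidden, self.advanced_head(hidden))
```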
@@ -166,11 +159,15 @@ class DQNAgent:
         self.state_size = np.prod(state_shape)
         self.action_size = n_actions
         self.memory_size = buffer_size
-        self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0]] # Default timeframes
+        self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] # Default timeframes

-        logger.info(f"DQN Agent using device: {self.device}")
+        logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
         logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
+
+        # Log model parameters
+        total_params = sum(p.numel() for p in self.policy_net.parameters())
+        logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")

     def move_models_to_device(self, device=None):
         """Move models to the specified device (GPU/CPU)"""
         if device is not None:
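The guarded slice added above evaluates the conditional before slicing, and the `[0]` indexing is skipped entirely when `state_dim` is not a tuple, so a flat integer state dimension no longer raises. A small illustration with made-up values:

```python
# How the guarded default-timeframes slice behaves (state_dim values are made up):
timeframes = ["1m", "5m", "15m"]

state_dim = (2, 100)   # tuple-shaped state: keep the first state_dim[0] timeframes
print(timeframes[:state_dim[0] if isinstance(state_dim, tuple) else 3])  # ['1m', '5m']

state_dim = 200        # flat int state: the [0] branch is never evaluated, all 3 kept
print(timeframes[:state_dim[0] if isinstance(state_dim, tuple) else 3])  # ['1m', '5m', '15m']
```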
@@ -300,7 +297,7 @@ class DQNAgent:

         # Get predictions using the policy network
         self.policy_net.eval() # Set to evaluation mode for inference
-        action_probs, extrema_pred, price_predictions, hidden_features = self.policy_net(state_tensor)
+        action_probs, extrema_pred, price_predictions, hidden_features, advanced_predictions = self.policy_net(state_tensor)
         self.policy_net.train() # Back to training mode

         # Store hidden features for integration
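Since `forward()` now yields an extra `advanced_predictions` tensor, every call site has to unpack five values. A hedged sketch of the inference pattern around this hunk; the `act_greedy` helper and the state preprocessing are assumptions, not code from this repo:

```python
# Greedy action selection against the 5-output network (illustrative only).
import torch

def act_greedy(agent, state):
    state_tensor = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
    agent.policy_net.eval()     # evaluation mode for inference
    with torch.no_grad():
        action_probs, extrema_pred, price_predictions, hidden_features, advanced_predictions = \
            agent.policy_net(state_tensor)
    agent.policy_net.train()    # back to training mode
    return int(action_probs.argmax(dim=1).item())
```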
@@ -650,12 +647,12 @@ class DQNAgent:
         dones = torch.FloatTensor(np.array(dones)).to(self.device)

         # Get current Q values
-        current_q_values, current_extrema_pred, current_price_pred, hidden_features = self.policy_net(states)
+        current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
         current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

         # Get next Q values with target network
         with torch.no_grad():
-            next_q_values, next_extrema_pred, next_price_pred, next_hidden_features = self.target_net(next_states)
+            next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
             next_q_values = next_q_values.max(1)[0]

         # Check for dimension mismatch between rewards and next_q_values
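The hunk stops at extracting the chosen-action Q-values and the max next-state Q-values. The usual continuation of a DQN update builds the bootstrapped target from `rewards`, a discount factor, and the `dones` mask, then regresses the current Q-values onto it; a sketch under those assumptions (the discount factor and the Huber-loss choice are not shown in this diff):

```python
# Sketch of the standard Bellman-target step that typically follows this hunk.
import torch
import torch.nn.functional as F

def dqn_loss(current_q_values, next_q_values, rewards, dones, gamma=0.99):
    # r + gamma * max_a' Q_target(s', a'), with the bootstrap zeroed on terminal steps
    target_q_values = rewards + gamma * next_q_values * (1 - dones)
    return F.smooth_l1_loss(current_q_values, target_q_values)
```

The optimizer step (zero_grad, backward, optional gradient clipping, step) would follow as usual.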
@@ -727,12 +724,12 @@ class DQNAgent:
         # Forward pass with amp autocasting
         with torch.cuda.amp.autocast():
             # Get current Q values and extrema predictions
-            current_q_values, current_extrema_pred, current_price_pred, hidden_features = self.policy_net(states)
+            current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
             current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

             # Get next Q values from target network
             with torch.no_grad():
-                next_q_values, next_extrema_pred, next_price_pred, next_hidden_features = self.target_net(next_states)
+                next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
                 next_q_values = next_q_values.max(1)[0]

             # Check for dimension mismatch and fix it
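This mixed-precision branch computes the loss inside `torch.cuda.amp.autocast()`; the standard pairing is a `GradScaler` that scales the loss for the backward pass and unscales gradients at the optimizer step. A sketch of that pattern with stand-in names; the agent's actual optimizer, loss choice, and scaler attribute are not shown in this diff, and the scaler would normally be created once rather than per step:

```python
# Typical torch.cuda.amp training step for a DQN-style update (illustrative only).
import torch
import torch.nn.functional as F

scaler = torch.cuda.amp.GradScaler()

def amp_replay_step(policy_net, target_net, optimizer,
                    states, actions, rewards, next_states, dones, gamma=0.99):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        current_q, *_ = policy_net(states)   # 5-tuple, keep only the Q-values here
        current_q = current_q.gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q, *_ = target_net(next_states)
        target_q = rewards + gamma * next_q.max(1)[0] * (1 - dones)
        loss = F.mse_loss(current_q, target_q)
    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.step(optimizer)          # unscales gradients, then optimizer.step()
    scaler.update()
    return loss.item()
```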