import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import sys
import logging
import torch.nn.functional as F

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Import the EnhancedCNN model
from NN.models.enhanced_cnn import EnhancedCNN, ExampleSiftingDataset

# Configure logger
logger = logging.getLogger(__name__)


class EnhancedDQNAgent:
    """
    Enhanced Deep Q-Network agent for trading.
    Uses the improved EnhancedCNN model with residual connections and attention mechanisms.
    """

    def __init__(self,
                 state_shape: Tuple[int, ...],
                 n_actions: int,
                 learning_rate: float = 0.0003,     # Slightly reduced learning rate for stability
                 gamma: float = 0.95,               # Discount factor
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.05,
                 epsilon_decay: float = 0.995,      # Slower decay for more exploration
                 buffer_size: int = 50000,          # Larger memory buffer
                 batch_size: int = 128,             # Larger batch size
                 target_update: int = 10,           # More frequent target updates
                 confidence_threshold: float = 0.4, # Lower confidence threshold
                 device=None):
        # Extract state dimensions
        if isinstance(state_shape, tuple) and len(state_shape) > 1:
            # Multi-dimensional state (like image or sequence)
            self.state_dim = state_shape
        else:
            # 1D state
            if isinstance(state_shape, tuple):
                self.state_dim = state_shape[0]
            else:
                self.state_dim = state_shape

        # Store parameters
        self.n_actions = n_actions
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.target_update = target_update
        self.confidence_threshold = confidence_threshold

        # Set device for computation
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # Initialize models with the enhanced CNN
        self.policy_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)
        self.target_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)

        # Move models to the selected device so they match the tensors built in replay()
        self.policy_net = self.policy_net.to(self.device)
        self.target_net = self.target_net.to(self.device)

        # Initialize the target network with the same weights as the policy network
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Set target model to eval mode (important for batch norm, dropout)
        self.target_net.eval()

        # Optimization components
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

        # Experience replay memory with example sifting
        self.memory = ExampleSiftingDataset(max_examples=buffer_size)
        self.update_count = 0

        # Confidence tracking
        self.confidence_history = []
        self.avg_confidence = 0.0
        self.max_confidence = 0.0
        self.min_confidence = 1.0

        # Performance tracking
        self.losses = []
        self.rewards = []
        self.avg_reward = 0.0

        # Check if mixed precision training should be used
        self.use_mixed_precision = False
        if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
            self.use_mixed_precision = True
            self.scaler = torch.cuda.amp.GradScaler()
            logger.info("Mixed precision training enabled")
        else:
            logger.info("Mixed precision training disabled")

        # For compatibility with old code
        self.action_size = n_actions

        logger.info(f"Enhanced DQN Agent using device: {self.device}")
        logger.info(f"Confidence threshold set to {self.confidence_threshold}")

    def move_models_to_device(self, device=None):
        """Move models to the specified device (GPU/CPU)"""
        if device is not None:
            self.device = device

        try:
            self.policy_net = self.policy_net.to(self.device)
            self.target_net = self.target_net.to(self.device)
            logger.info(f"Moved models to {self.device}")
            return True
        except Exception as e:
            logger.error(f"Failed to move models to {self.device}: {str(e)}")
            return False

    def _normalize_state(self, state):
        """Normalize state for better training stability"""
        try:
            # Convert to numpy array if needed
            if isinstance(state, list):
                state = np.array(state, dtype=np.float32)

            # Apply normalization based on state shape
            if len(state.shape) > 1:
                # Multi-dimensional state - normalize each feature dimension separately
                for i in range(state.shape[0]):
                    # Skip if all zeros (to avoid division by zero)
                    if np.sum(np.abs(state[i])) > 0:
                        # Standardize each feature dimension
                        mean = np.mean(state[i])
                        std = np.std(state[i])
                        if std > 0:
                            state[i] = (state[i] - mean) / std
            else:
                # 1D state vector - skip if all zeros
                if np.sum(np.abs(state)) > 0:
                    mean = np.mean(state)
                    std = np.std(state)
                    if std > 0:
                        state = (state - mean) / std

            return state
        except Exception as e:
            logger.warning(f"Error normalizing state: {str(e)}")
            return state

    def remember(self, state, action, reward, next_state, done):
        """Store experience in memory with example sifting"""
        self.memory.add_example(state, action, reward, next_state, done)

        # Also track rewards for monitoring
        self.rewards.append(reward)
        if len(self.rewards) > 100:
            self.rewards = self.rewards[-100:]
        self.avg_reward = np.mean(self.rewards)

    def act(self, state, explore=True):
        """Choose action using epsilon-greedy policy with built-in confidence thresholding"""
        if explore and random.random() < self.epsilon:
            return random.randrange(self.n_actions), 0.0  # Return action and zero confidence

        # Normalize state before inference
        normalized_state = self._normalize_state(state)

        # Use the EnhancedCNN's act method which includes confidence thresholding
        action, confidence = self.policy_net.act(normalized_state, explore=explore)

        # Track confidence metrics
        self.confidence_history.append(confidence)
        if len(self.confidence_history) > 100:
            self.confidence_history = self.confidence_history[-100:]

        # Update confidence metrics
        self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
        self.max_confidence = max(self.max_confidence, confidence)
        self.min_confidence = min(self.min_confidence, confidence)

        # Log average confidence occasionally (about 1% of the time)
        if random.random() < 0.01:
            logger.info(f"Confidence metrics - Current: {confidence:.4f}, Avg: {self.avg_confidence:.4f}, " +
                        f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")

        return action, confidence

    def replay(self):
        """Train the model using experience replay with high-quality examples"""
        # Check if enough samples in memory
        if len(self.memory) < self.batch_size:
            return 0.0

        # Get batch of experiences
        batch = self.memory.get_batch(self.batch_size)
        if batch is None:
            return 0.0

        states = torch.FloatTensor(batch['states']).to(self.device)
        actions = torch.LongTensor(batch['actions']).to(self.device)
        rewards = torch.FloatTensor(batch['rewards']).to(self.device)
        next_states = torch.FloatTensor(batch['next_states']).to(self.device)
        dones = torch.FloatTensor(batch['dones']).to(self.device)

        # Set policy network to training mode
        self.policy_net.train()

        if self.use_mixed_precision:
            with torch.cuda.amp.autocast():
                # Get current Q values
                q_values, _, _, _ = self.policy_net(states)
                current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

                # Compute target Q values
                with torch.no_grad():
                    self.target_net.eval()
                    next_q_values, _, _, _ = self.target_net(next_states)
                    next_q = next_q_values.max(1)[0]
                    target_q = rewards + (1 - dones) * self.gamma * next_q

                # Compute loss
                loss = self.criterion(current_q, target_q)

            # Perform backpropagation with mixed precision
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            # Standard precision training
            # Get current Q values
            q_values, _, _, _ = self.policy_net(states)
            current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

            # Compute target Q values
            with torch.no_grad():
                self.target_net.eval()
                next_q_values, _, _, _ = self.target_net(next_states)
                next_q = next_q_values.max(1)[0]
                target_q = rewards + (1 - dones) * self.gamma * next_q

            # Compute loss
            loss = self.criterion(current_q, target_q)

            # Perform backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.optimizer.step()

        # Track loss
        loss_value = loss.item()
        self.losses.append(loss_value)
        if len(self.losses) > 100:
            self.losses = self.losses[-100:]

        # Update target network
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
            logger.info(f"Updated target network (step {self.update_count})")

        # Decay epsilon
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        return loss_value

    def save(self, path):
        """Save agent state and models"""
        self.policy_net.save(f"{path}_policy")
        self.target_net.save(f"{path}_target")

        # Save agent state
        torch.save({
            'epsilon': self.epsilon,
            'confidence_threshold': self.confidence_threshold,
            'losses': self.losses,
            'rewards': self.rewards,
            'avg_reward': self.avg_reward,
            'confidence_history': self.confidence_history,
            'avg_confidence': self.avg_confidence,
            'max_confidence': self.max_confidence,
            'min_confidence': self.min_confidence,
            'update_count': self.update_count
        }, f"{path}_agent_state.pt")

        logger.info(f"Agent state saved to {path}_agent_state.pt")

    def load(self, path):
        """Load agent state and models"""
        policy_loaded = self.policy_net.load(f"{path}_policy")
        target_loaded = self.target_net.load(f"{path}_target")

        # Load agent state if available
        agent_state_path = f"{path}_agent_state.pt"
        if os.path.exists(agent_state_path):
            try:
                state = torch.load(agent_state_path)
                self.epsilon = state.get('epsilon', self.epsilon)
                self.confidence_threshold = state.get('confidence_threshold', self.confidence_threshold)
                self.policy_net.confidence_threshold = self.confidence_threshold
                self.target_net.confidence_threshold = self.confidence_threshold
                self.losses = state.get('losses', [])
                self.rewards = state.get('rewards', [])
                self.avg_reward = state.get('avg_reward', 0.0)
                self.confidence_history = state.get('confidence_history', [])
                self.avg_confidence = state.get('avg_confidence', 0.0)
                self.max_confidence = state.get('max_confidence', 0.0)
                self.min_confidence = state.get('min_confidence', 1.0)
                self.update_count = state.get('update_count', 0)
                logger.info(f"Agent state loaded from {agent_state_path}")
            except Exception as e:
                logger.error(f"Error loading agent state: {str(e)}")

        return policy_loaded and target_loaded
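

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It shows how act() / remember() /
# replay() fit together in an epsilon-greedy training loop, using random data
# in place of real market features. It assumes the imported EnhancedCNN and
# ExampleSiftingDataset accept a 1D state vector of the chosen length; the
# state size, action count, and loop structure below are hypothetical and not
# part of the trading pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    state_size = 64   # hypothetical flattened feature-vector length
    n_actions = 3     # e.g. BUY / SELL / HOLD

    agent = EnhancedDQNAgent(state_shape=(state_size,), n_actions=n_actions)

    for step in range(200):
        # Random states and rewards stand in for real market data here
        state = np.random.randn(state_size).astype(np.float32)
        action, confidence = agent.act(state, explore=True)

        reward = float(np.random.randn())
        next_state = np.random.randn(state_size).astype(np.float32)
        done = (step % 50 == 49)

        agent.remember(state, action, reward, next_state, done)
        loss = agent.replay()  # Returns 0.0 until the buffer holds a full batch

        if step % 50 == 0:
            logger.info(f"step={step} action={action} confidence={confidence:.3f} loss={loss:.4f}")

    logger.info(f"Sketch finished - epsilon={agent.epsilon:.3f}, avg_reward={agent.avg_reward:.3f}")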