import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import sys
import logging
import torch.nn.functional as F

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from NN.models.simple_cnn import CNNModelPyTorch

# Configure logger
logger = logging.getLogger(__name__)

class DQNAgent:
    """
    Deep Q-Network agent for trading

    Uses a CNN model as the base network with GPU support
    """
    def __init__(self,
                 state_shape: Tuple[int, ...],
                 n_actions: int,
                 learning_rate: float = 0.0005,  # Reduced learning rate for more stability
                 gamma: float = 0.97,            # Slightly reduced discount factor
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.05,      # Increased minimum epsilon for more exploration
                 epsilon_decay: float = 0.9975,  # Slower decay rate
                 buffer_size: int = 20000,       # Increased memory size
                 batch_size: int = 128,          # Larger batch size
                 target_update: int = 5,         # More frequent target updates
                 device=None):                   # Device for computations

        # Extract state dimensions
        if isinstance(state_shape, tuple) and len(state_shape) > 1:
            # Multi-dimensional state (like image or sequence)
            self.state_dim = state_shape
        else:
            # 1D state
            if isinstance(state_shape, tuple):
                self.state_dim = state_shape[0]
            else:
                self.state_dim = state_shape

        # Store parameters
        self.n_actions = n_actions
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.target_update = target_update

        # Set device for computation (default to GPU when available)
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # Initialize models with an appropriate architecture based on the state shape
        if isinstance(self.state_dim, tuple) and len(self.state_dim) > 1:
            # For image-like states (from RL environment with CNN)
            from NN.models.simple_cnn import SimpleCNN
            self.policy_net = SimpleCNN(self.state_dim, self.n_actions)
            self.target_net = SimpleCNN(self.state_dim, self.n_actions)
        else:
            # For 1D state vectors (most environments)
            from NN.models.simple_mlp import SimpleMLP
            self.policy_net = SimpleMLP(self.state_dim, self.n_actions)
            self.target_net = SimpleMLP(self.state_dim, self.n_actions)

        # Initialize the target network with the same weights as the policy network
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Set the target network to eval mode (important for batch norm, dropout)
        self.target_net.eval()

        # Optimization components
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

        # Experience replay memory
        self.memory = []
        self.positive_memory = []  # Special memory for storing good experiences
        self.update_count = 0

        # Extrema detection tracking
        self.last_extrema_pred = {
            'class': 2,  # Default to "neither" (not extrema)
            'confidence': 0.0,
            'raw': None
        }
        self.extrema_memory = []  # Special memory for storing extrema points

        # Performance tracking
        self.losses = []
        self.avg_reward = 0.0
        self.best_reward = -float('inf')
        self.no_improvement_count = 0

        # Check if mixed precision training should be used
        self.use_mixed_precision = False
        if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
            self.use_mixed_precision = True
            self.scaler = torch.cuda.amp.GradScaler()
            logger.info("Mixed precision training enabled")
        else:
            logger.info("Mixed precision training disabled")

        # Track if we're in training mode
        self.training = True

        # For compatibility with old code
        self.state_size = np.prod(state_shape)
        self.action_size = n_actions
        self.memory_size = buffer_size
        # Default timeframes; for 1D states self.state_dim is an int, so guard the indexing
        if isinstance(self.state_dim, tuple):
            self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0]]
        else:
            self.timeframes = ["1m"]

        logger.info(f"DQN Agent using device: {self.device}")

    def move_models_to_device(self, device=None):
        """Move models to the specified device (GPU/CPU)"""
        if device is not None:
            self.device = device

        try:
            self.policy_net = self.policy_net.to(self.device)
            self.target_net = self.target_net.to(self.device)
            logger.info(f"Moved models to {self.device}")
            return True
        except Exception as e:
            logger.error(f"Failed to move models to {self.device}: {str(e)}")
            return False

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool, is_extrema: bool = False):
        """
        Store experience in memory with prioritization

        Args:
            state: Current state
            action: Action taken
            reward: Reward received
            next_state: Next state
            done: Whether episode is done
            is_extrema: Whether this is a local extrema sample (for specialized learning)
        """
        experience = (state, action, reward, next_state, done)

        # Always add to main memory
        self.memory.append(experience)

        # Check if this is an extrema point based on our extrema detection head
        if hasattr(self, 'last_extrema_pred') and self.last_extrema_pred['class'] != 2:
            # Class 0 = bottom, 1 = top, 2 = neither
            # Only consider high confidence predictions
            if self.last_extrema_pred['confidence'] > 0.7:
                self.extrema_memory.append(experience)

                # Log this special experience
                extrema_type = "BOTTOM" if self.last_extrema_pred['class'] == 0 else "TOP"
                logger.info(f"Stored {extrema_type} experience with reward {reward:.4f}")

                # For tops and bottoms, also duplicate the experience in memory to learn more from it
                for _ in range(2):  # Add 2 extra copies
                    self.memory.append(experience)
        # Explicitly marked extrema points also go to extrema memory
        elif is_extrema:
            self.extrema_memory.append(experience)

        # Store positive experiences separately for prioritized replay
        if reward > 0:
            self.positive_memory.append(experience)

            # For very good rewards, duplicate to learn more from them
            if reward > 0.1:
                for _ in range(min(int(reward * 10), 5)):  # Cap at 5 extra copies for very high rewards
                    self.positive_memory.append(experience)

        # Keep memory size under control
        if len(self.memory) > self.buffer_size:
            # Keep more recent experiences
            self.memory = self.memory[-self.buffer_size:]

        # Keep specialized memories under control too
        if len(self.positive_memory) > self.buffer_size // 4:
            self.positive_memory = self.positive_memory[-(self.buffer_size // 4):]

        if len(self.extrema_memory) > self.buffer_size // 4:
            self.extrema_memory = self.extrema_memory[-(self.buffer_size // 4):]
    def act(self, state: np.ndarray, explore=True) -> int:
        """Choose action using epsilon-greedy policy with explore flag"""
        if explore and random.random() < self.epsilon:
            return random.randrange(self.n_actions)

        with torch.no_grad():
            # Ensure state is normalized before inference
            state_tensor = self._normalize_state(state)
            state_tensor = torch.FloatTensor(state_tensor).unsqueeze(0).to(self.device)

            # Get predictions using the policy network
            self.policy_net.eval()  # Set to evaluation mode for inference
            action_probs, extrema_pred = self.policy_net(state_tensor)
            self.policy_net.train()  # Back to training mode

            # Get the predicted extrema class (0=bottom, 1=top, 2=neither)
            extrema_class = extrema_pred.argmax(dim=1).item()
            extrema_confidence = torch.softmax(extrema_pred, dim=1)[0, extrema_class].item()

            # Log extrema prediction for significant signals
            if extrema_confidence > 0.7 and extrema_class != 2:  # Only log strong top/bottom signals
                extrema_type = "BOTTOM" if extrema_class == 0 else "TOP"
                logger.info(f"High confidence {extrema_type} detected! Confidence: {extrema_confidence:.4f}")

            # Store extrema prediction for the environment to use
            self.last_extrema_pred = {
                'class': extrema_class,
                'confidence': extrema_confidence,
                'raw': extrema_pred.cpu().numpy()
            }

            # Get the action with highest Q-value
            action = action_probs.argmax().item()

            # Adjust action based on extrema prediction (with some probability)
            if extrema_confidence > 0.8:  # Only adjust for strong signals
                if extrema_class == 0:  # Bottom detected
                    # Bias toward BUY at bottoms
                    if action != 0 and random.random() < 0.3 * extrema_confidence:
                        logger.info("Adjusting action to BUY based on bottom detection")
                        action = 0  # BUY
                elif extrema_class == 1:  # Top detected
                    # Bias toward SELL at tops
                    if action != 1 and random.random() < 0.3 * extrema_confidence:
                        logger.info("Adjusting action to SELL based on top detection")
                        action = 1  # SELL

            return action

    def replay(self, use_prioritized=True) -> float:
        """Experience replay - learn from stored experiences

        Args:
            use_prioritized: Whether to use prioritized experience replay

        Returns:
            float: Training loss
        """
        # Check if we have enough samples
        if len(self.memory) < self.batch_size:
            return 0.0

        # Check if mixed precision should be disabled
        if 'DISABLE_MIXED_PRECISION' in os.environ:
            self.use_mixed_precision = False

        # Sample from memory with or without prioritization
        if use_prioritized and len(self.positive_memory) > self.batch_size // 4:
            # Use prioritized sampling: mix normal samples with positive reward samples
            positive_batch_size = min(self.batch_size // 4, len(self.positive_memory))
            regular_batch_size = self.batch_size - positive_batch_size

            # Get positive examples
            positive_batch = random.sample(self.positive_memory, positive_batch_size)

            # Get regular examples
            regular_batch = random.sample(self.memory, regular_batch_size)

            # Combine batches
            minibatch = positive_batch + regular_batch
        else:
            # Use regular uniform sampling
            minibatch = random.sample(self.memory, self.batch_size)

        # Extract batches with proper tensor conversion
        # Use np.stack (not vstack) so 2D per-sample states keep their batch dimension -> [batch, seq, features]
        states = np.stack([self._normalize_state(x[0]) for x in minibatch])
        actions = np.array([x[1] for x in minibatch])
        rewards = np.array([x[2] for x in minibatch])
        next_states = np.stack([self._normalize_state(x[3]) for x in minibatch])
        dones = np.array([x[4] for x in minibatch], dtype=np.float32)

        # Convert to torch tensors and move to device
        states_tensor = torch.FloatTensor(states).to(self.device)
        actions_tensor = torch.LongTensor(actions).to(self.device)
        rewards_tensor = torch.FloatTensor(rewards).to(self.device)
        next_states_tensor = torch.FloatTensor(next_states).to(self.device)
        dones_tensor = torch.FloatTensor(dones).to(self.device)

        # First training step with mixed precision if available
        if self.use_mixed_precision:
            loss = self._replay_mixed_precision(
                states_tensor, actions_tensor, rewards_tensor,
                next_states_tensor, dones_tensor
            )
        else:
            loss = self._replay_standard(
                states_tensor, actions_tensor, rewards_tensor,
                next_states_tensor, dones_tensor
            )

        # Occasionally train specifically on extrema points, if we have enough
        if hasattr(self, 'extrema_memory') and len(self.extrema_memory) >= self.batch_size // 2:
            if random.random() < 0.3:  # 30% chance to do extra extrema training
                # Sample from extrema memory
                extrema_batch_size = min(self.batch_size // 2, len(self.extrema_memory))
                extrema_batch = random.sample(self.extrema_memory, extrema_batch_size)

                # Extract batches with proper tensor conversion
                extrema_states = np.stack([self._normalize_state(x[0]) for x in extrema_batch])
                extrema_actions = np.array([x[1] for x in extrema_batch])
                extrema_rewards = np.array([x[2] for x in extrema_batch])
                extrema_next_states = np.stack([self._normalize_state(x[3]) for x in extrema_batch])
                extrema_dones = np.array([x[4] for x in extrema_batch], dtype=np.float32)

                # Convert to torch tensors and move to device
                extrema_states_tensor = torch.FloatTensor(extrema_states).to(self.device)
                extrema_actions_tensor = torch.LongTensor(extrema_actions).to(self.device)
                extrema_rewards_tensor = torch.FloatTensor(extrema_rewards).to(self.device)
                extrema_next_states_tensor = torch.FloatTensor(extrema_next_states).to(self.device)
                extrema_dones_tensor = torch.FloatTensor(extrema_dones).to(self.device)

                # Additional training step focused on extrema points (with smaller learning rate)
                original_lr = self.optimizer.param_groups[0]['lr']

                # Temporarily reduce learning rate for fine-tuning on extrema
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = original_lr * 0.5

                # Train on extrema
                if self.use_mixed_precision:
                    extrema_loss = self._replay_mixed_precision(
                        extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
                        extrema_next_states_tensor, extrema_dones_tensor
                    )
                else:
                    extrema_loss = self._replay_standard(
                        extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
                        extrema_next_states_tensor, extrema_dones_tensor
                    )

                # Restore original learning rate
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = original_lr

                logger.info(f"Extra training on extrema points: loss={extrema_loss:.4f}")

                # Average the loss
                loss = (loss + extrema_loss) / 2

        # Store and return loss
        self.losses.append(loss)
        return loss
    def _replay_standard(self, states, actions, rewards, next_states, dones):
        """Standard precision training step"""
        # Zero gradients
        self.optimizer.zero_grad()

        # Get current Q values and extrema predictions
        current_q_values, current_extrema_pred = self.policy_net(states)
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Get next Q values from target network
        with torch.no_grad():
            next_q_values, next_extrema_pred = self.target_net(next_states)
            next_q_values = next_q_values.max(1)[0]

        # Check for dimension mismatch and fix it
        min_size = rewards.shape[0]  # Default to the full batch; shrunk below on mismatch
        if rewards.shape[0] != next_q_values.shape[0]:
            # Log the shape mismatch for debugging
            logger.warning(f"Shape mismatch detected in standard replay: rewards {rewards.shape}, next_q_values {next_q_values.shape}")

            # Use the smaller size to prevent index errors
            min_size = min(rewards.shape[0], next_q_values.shape[0])
            rewards = rewards[:min_size]
            dones = dones[:min_size]
            next_q_values = next_q_values[:min_size]
            current_q_values = current_q_values[:min_size]

        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute Q-value loss (primary task)
        q_loss = nn.MSELoss()(current_q_values, target_q_values)

        # Create extrema labels from price movements (crude approximation)
        # If the next state price is higher than current, we might be in an uptrend (not a bottom)
        # If the next state price is lower than current, we might be in a downtrend (not a top)
        # This is a simplified approximation; in real scenarios we'd want to use actual extrema detection

        # Try to extract price from current and next states
        # Assuming price is in the last feature
        try:
            # Extract price feature from sequence data (if available)
            if len(states.shape) == 3:  # [batch, seq, features]
                current_prices = states[:, -1, -1]  # Last timestep, last feature
                next_prices = next_states[:, -1, -1]
            else:  # [batch, features]
                current_prices = states[:, -1]  # Last feature
                next_prices = next_states[:, -1]

            # Compute price changes
            price_changes = (next_prices - current_prices) / current_prices
            price_changes = price_changes[:min_size]  # Align with any truncation applied above

            # Create crude extrema labels:
            # 0 = bottom: Large negative price change followed by positive change
            # 1 = top: Large positive price change followed by negative change
            # 2 = neither: Small or inconsistent changes

            # Classify based on price change magnitude
            extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2  # Default: neither

            # Identify potential bottoms (significant negative change)
            bottoms = (price_changes < -0.003)
            extrema_labels[bottoms] = 0

            # Identify potential tops (significant positive change)
            tops = (price_changes > 0.003)
            extrema_labels[tops] = 1

            # Calculate extrema prediction loss (auxiliary task)
            if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
                current_extrema_pred = current_extrema_pred[:min_size]
                extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)

                # Combined loss (primary + auxiliary with lower weight)
                # Typically auxiliary tasks should have lower weight to not dominate the primary task
                loss = q_loss + 0.3 * extrema_loss

                # Log separate loss components occasionally
                if random.random() < 0.01:  # Log 1% of the time to avoid flood
                    logger.info(f"Training losses: Q-loss={q_loss.item():.4f}, Extrema-loss={extrema_loss.item():.4f}")
            else:
                # Fall back to just Q-value loss if extrema predictions aren't available
                loss = q_loss
        except Exception as e:
            # Fallback if price extraction fails
            logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
            loss = q_loss

        # Backward pass and optimize
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

        # Update target network if needed
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        # Track and decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        return loss.item()

    def _replay_mixed_precision(self, states, actions, rewards, next_states, dones):
        """Mixed precision training step for better GPU performance"""
        # Check if mixed precision should be explicitly disabled
        if 'DISABLE_MIXED_PRECISION' in os.environ:
            logger.info("Mixed precision explicitly disabled by environment variable")
            return self._replay_standard(states, actions, rewards, next_states, dones)

        try:
            # Zero gradients
            self.optimizer.zero_grad()

            # Forward pass with amp autocasting
            with torch.cuda.amp.autocast():
                # Get current Q values and extrema predictions
                current_q_values, current_extrema_pred = self.policy_net(states)
                current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

                # Get next Q values from target network
                with torch.no_grad():
                    next_q_values, next_extrema_pred = self.target_net(next_states)
                    next_q_values = next_q_values.max(1)[0]

                # Check for dimension mismatch and fix it
                min_size = rewards.shape[0]  # Default to the full batch; shrunk below on mismatch
                if rewards.shape[0] != next_q_values.shape[0]:
                    # Log the shape mismatch for debugging
                    logger.warning(f"Shape mismatch detected: rewards {rewards.shape}, next_q_values {next_q_values.shape}")

                    # Use the smaller size to prevent index errors
                    min_size = min(rewards.shape[0], next_q_values.shape[0])
                    rewards = rewards[:min_size]
                    dones = dones[:min_size]
                    next_q_values = next_q_values[:min_size]
                    current_q_values = current_q_values[:min_size]

                target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

                # Compute Q-value loss (primary task)
                q_loss = nn.MSELoss()(current_q_values, target_q_values)

                # Create extrema labels from price movements (crude approximation)
                # Try to extract price from current and next states
                try:
                    # Extract price feature from sequence data (if available)
                    if len(states.shape) == 3:  # [batch, seq, features]
                        current_prices = states[:, -1, -1]  # Last timestep, last feature
                        next_prices = next_states[:, -1, -1]
                    else:  # [batch, features]
                        current_prices = states[:, -1]  # Last feature
                        next_prices = next_states[:, -1]

                    # Compute price changes
                    price_changes = (next_prices - current_prices) / current_prices
                    price_changes = price_changes[:min_size]  # Align with any truncation applied above

                    # Create crude extrema labels:
                    # 0 = bottom: Large negative price change followed by positive change
                    # 1 = top: Large positive price change followed by negative change
                    # 2 = neither: Small or inconsistent changes

                    # Classify based on price change magnitude
                    extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2  # Default: neither

                    # Identify potential bottoms (significant negative change)
                    bottoms = (price_changes < -0.003)
                    extrema_labels[bottoms] = 0

                    # Identify potential tops (significant positive change)
                    tops = (price_changes > 0.003)
                    extrema_labels[tops] = 1

                    # Calculate extrema prediction loss (auxiliary task)
                    if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
                        current_extrema_pred = current_extrema_pred[:min_size]
                        extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)

                        # Combined loss (primary + auxiliary with lower weight)
                        loss = q_loss + 0.3 * extrema_loss

                        # Log separate loss components occasionally
                        if random.random() < 0.01:  # Log 1% of the time to avoid flood
                            logger.info(f"Mixed precision training losses: Q-loss={q_loss.item():.4f}, Extrema-loss={extrema_loss.item():.4f}")
                    else:
                        # Fall back to just Q-value loss
                        loss = q_loss
                except Exception as e:
                    # Fallback if price extraction fails
                    logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
                    loss = q_loss

            # Backward pass with scaled gradients
            self.scaler.scale(loss).backward()

            # Gradient clipping on scaled gradients
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)

            # Update with scaler
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Update target network if needed
            self.update_count += 1
            if self.update_count % self.target_update == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())

            # Track and decay epsilon
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

            return loss.item()

        except Exception as e:
            logger.error(f"Error in mixed precision training: {str(e)}")
            logger.warning("Falling back to standard precision training")
            # Fall back to standard training
            return self._replay_standard(states, actions, rewards, next_states, dones)
    def train_on_extrema(self, states, actions, rewards, next_states, dones):
        """
        Special training function specifically for extrema points

        Args:
            states: Batch of states at extrema points
            actions: Batch of actions
            rewards: Batch of rewards
            next_states: Batch of next states
            dones: Batch of done flags

        Returns:
            float: Training loss
        """
        # Convert to numpy arrays if not already
        if not isinstance(states, np.ndarray):
            states = np.array(states)
        if not isinstance(actions, np.ndarray):
            actions = np.array(actions)
        if not isinstance(rewards, np.ndarray):
            rewards = np.array(rewards)
        if not isinstance(next_states, np.ndarray):
            next_states = np.array(next_states)
        if not isinstance(dones, np.ndarray):
            dones = np.array(dones, dtype=np.float32)

        # Normalize states (np.stack keeps the batch dimension for 2D per-sample states)
        states = np.stack([self._normalize_state(s) for s in states])
        next_states = np.stack([self._normalize_state(s) for s in next_states])

        # Convert to torch tensors and move to device
        states_tensor = torch.FloatTensor(states).to(self.device)
        actions_tensor = torch.LongTensor(actions).to(self.device)
        rewards_tensor = torch.FloatTensor(rewards).to(self.device)
        next_states_tensor = torch.FloatTensor(next_states).to(self.device)
        dones_tensor = torch.FloatTensor(dones).to(self.device)

        # Choose training method based on precision mode
        if self.use_mixed_precision:
            return self._replay_mixed_precision(
                states_tensor, actions_tensor, rewards_tensor,
                next_states_tensor, dones_tensor
            )
        else:
            return self._replay_standard(
                states_tensor, actions_tensor, rewards_tensor,
                next_states_tensor, dones_tensor
            )

    def _normalize_state(self, state: np.ndarray) -> np.ndarray:
        """Normalize the state data to prevent numerical issues"""
        # Handle NaN and infinite values
        state = np.nan_to_num(state, nan=0.0, posinf=1.0, neginf=-1.0)

        # Check if state is a 1D array (happens in some environments)
        if len(state.shape) == 1:
            # If 1D, we need to normalize the whole array
            normalized_state = state.copy()

            # Convert any timestamp or non-numeric data to float
            for i in range(len(normalized_state)):
                # Check for timestamp-like objects
                if hasattr(normalized_state[i], 'timestamp') and callable(getattr(normalized_state[i], 'timestamp')):
                    # Convert timestamp to float (seconds since epoch)
                    normalized_state[i] = float(normalized_state[i].timestamp())
                elif not isinstance(normalized_state[i], (int, float, np.number)):
                    # Set non-numeric data to 0
                    normalized_state[i] = 0.0

            # Ensure all values are float
            normalized_state = normalized_state.astype(np.float32)

            # Simple min-max normalization for 1D state
            state_min = np.min(normalized_state)
            state_max = np.max(normalized_state)
            if state_max > state_min:
                normalized_state = (normalized_state - state_min) / (state_max - state_min)

            return normalized_state

        # Handle 2D arrays
        normalized_state = np.zeros_like(state, dtype=np.float32)

        # Convert any timestamp or non-numeric data to float
        for i in range(state.shape[0]):
            for j in range(state.shape[1]):
                if hasattr(state[i, j], 'timestamp') and callable(getattr(state[i, j], 'timestamp')):
                    # Convert timestamp to float (seconds since epoch)
                    normalized_state[i, j] = float(state[i, j].timestamp())
                elif isinstance(state[i, j], (int, float, np.number)):
                    normalized_state[i, j] = state[i, j]
                else:
                    # Set non-numeric data to 0
                    normalized_state[i, j] = 0.0

        # Loop through each timeframe's features in the combined state
        feature_count = state.shape[1] // len(self.timeframes)

        for tf_idx in range(len(self.timeframes)):
            start_idx = tf_idx * feature_count
            end_idx = start_idx + feature_count

            # Extract this timeframe's features (view into normalized_state)
            tf_features = normalized_state[:, start_idx:end_idx]

            # Normalize OHLC data by the mean close price in the window
            # This makes price movements relative rather than absolute
            price_idx = 3  # Assuming close price is at index 3
            if price_idx < tf_features.shape[1]:
                reference_price = np.mean(tf_features[:, price_idx])
                if reference_price != 0:
                    # Normalize price-related columns (OHLC)
                    for i in range(4):  # First 4 columns are OHLC
                        if i < tf_features.shape[1]:
                            normalized_state[:, start_idx + i] = tf_features[:, i] / reference_price

            # Normalize volume using mean and std
            vol_idx = 4  # Assuming volume is at index 4
            if vol_idx < tf_features.shape[1]:
                vol_mean = np.mean(tf_features[:, vol_idx])
                vol_std = np.std(tf_features[:, vol_idx])
                if vol_std > 0:
                    normalized_state[:, start_idx + vol_idx] = (tf_features[:, vol_idx] - vol_mean) / vol_std
                else:
                    normalized_state[:, start_idx + vol_idx] = 0

            # Other features (technical indicators) - normalize with min-max scaling
            for i in range(5, feature_count):
                if i < tf_features.shape[1]:
                    feature_min = np.min(tf_features[:, i])
                    feature_max = np.max(tf_features[:, i])
                    if feature_max > feature_min:
                        normalized_state[:, start_idx + i] = (tf_features[:, i] - feature_min) / (feature_max - feature_min)
                    else:
                        normalized_state[:, start_idx + i] = 0

        return normalized_state
    def update_learning_metrics(self, episode_reward, best_reward_threshold=0.01):
        """Update learning metrics and perform learning rate adjustments if needed"""
        # Update average reward with exponential moving average
        if self.avg_reward == 0:
            self.avg_reward = episode_reward
        else:
            self.avg_reward = 0.95 * self.avg_reward + 0.05 * episode_reward

        # Check if we're making sufficient progress
        if episode_reward > (1 + best_reward_threshold) * self.best_reward:
            self.best_reward = episode_reward
            self.no_improvement_count = 0
            return True  # Improved
        else:
            self.no_improvement_count += 1

            # If no improvement for a while, adjust learning rate
            if self.no_improvement_count >= 10:
                current_lr = self.optimizer.param_groups[0]['lr']
                new_lr = current_lr * 0.5
                if new_lr >= 1e-6:  # Don't reduce below minimum threshold
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = new_lr
                    logger.info(f"Reducing learning rate from {current_lr} to {new_lr}")
                self.no_improvement_count = 0

            return False  # No improvement

    def save(self, path: str):
        """Save model and agent state"""
        os.makedirs(os.path.dirname(path), exist_ok=True)

        # Save policy network
        self.policy_net.save(f"{path}_policy")

        # Save target network
        self.target_net.save(f"{path}_target")

        # Save agent state
        state = {
            'epsilon': self.epsilon,
            'update_count': self.update_count,
            'losses': self.losses,
            'optimizer_state': self.optimizer.state_dict(),
            'best_reward': self.best_reward,
            'avg_reward': self.avg_reward
        }

        torch.save(state, f"{path}_agent_state.pt")
        logger.info(f"Agent state saved to {path}_agent_state.pt")

    def load(self, path: str):
        """Load model and agent state"""
        # Load policy network
        self.policy_net.load(f"{path}_policy")

        # Load target network
        self.target_net.load(f"{path}_target")

        # Load agent state
        try:
            agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
            self.epsilon = agent_state['epsilon']
            self.update_count = agent_state['update_count']
            self.losses = agent_state['losses']
            self.optimizer.load_state_dict(agent_state['optimizer_state'])

            # Load additional metrics if they exist
            if 'best_reward' in agent_state:
                self.best_reward = agent_state['best_reward']
            if 'avg_reward' in agent_state:
                self.avg_reward = agent_state['avg_reward']

            logger.info(f"Agent state loaded from {path}_agent_state.pt")
        except FileNotFoundError:
            logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
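

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how this agent is typically wired into a
# training loop. The state shape, reward values, and random transitions below
# are hypothetical placeholders for a real trading environment, and running
# this requires the repo's NN.models package imported above to be available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Hypothetical 2D state: 3 timeframes x 5 features (OHLCV) per row
    state_shape = (3, 5)
    agent = DQNAgent(state_shape=state_shape, n_actions=3, batch_size=32)

    # Fill the replay buffer with random transitions as a smoke test
    for _ in range(agent.batch_size * 2):
        state = np.random.rand(*state_shape).astype(np.float32)
        next_state = np.random.rand(*state_shape).astype(np.float32)
        action = agent.act(state)                 # epsilon-greedy action
        reward = float(np.random.randn() * 0.01)  # placeholder reward
        agent.remember(state, action, reward, next_state, done=False)

    # One training step; returns the (possibly averaged) loss for this batch
    loss = agent.replay()
    print(f"replay loss: {loss:.6f}")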