diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py index 96488e9..d5a6991 100644 --- a/NN/models/dqn_agent.py +++ b/NN/models/dqn_agent.py @@ -99,6 +99,28 @@ class DQNAgent: } self.extrema_memory = [] # Special memory for storing extrema points + # Price prediction tracking + self.last_price_pred = { + 'immediate': { + 'direction': 1, # Default to "sideways" + 'confidence': 0.0, + 'change': 0.0 + }, + 'midterm': { + 'direction': 1, # Default to "sideways" + 'confidence': 0.0, + 'change': 0.0 + }, + 'longterm': { + 'direction': 1, # Default to "sideways" + 'confidence': 0.0, + 'change': 0.0 + } + } + + # Store separate memory for price direction examples + self.price_movement_memory = [] # For storing examples of clear price movements + # Performance tracking self.losses = [] self.avg_reward = 0.0 @@ -157,6 +179,36 @@ class DQNAgent: # Always add to main memory self.memory.append(experience) + # Try to extract price change to analyze the experience + try: + # Extract price feature from sequence data (if available) + if len(state.shape) > 1: # 2D state [timeframes, features] + current_price = state[-1, -1] # Last timeframe, last feature + next_price = next_state[-1, -1] + else: # 1D state + current_price = state[-1] # Last feature + next_price = next_state[-1] + + # Calculate price change + price_change = (next_price - current_price) / current_price + + # Check if this is a significant price movement + if abs(price_change) > 0.002: # Significant price change + # Store in price movement memory + self.price_movement_memory.append(experience) + + # Log significant price movements + direction = "UP" if price_change > 0 else "DOWN" + logger.info(f"Stored significant {direction} price movement: {price_change:.4f}") + + # For clear price movements, also duplicate in main memory to learn more + if abs(price_change) > 0.005: # Very significant movement + for _ in range(2): # Add 2 extra copies + self.memory.append(experience) + except Exception as e: + # Skip price movement analysis if it fails + pass + # Check if this is an extrema point based on our extrema detection head if hasattr(self, 'last_extrema_pred') and self.last_extrema_pred['class'] != 2: # Class 0 = bottom, 1 = top, 2 = neither @@ -196,6 +248,9 @@ class DQNAgent: if len(self.extrema_memory) > self.buffer_size // 4: self.extrema_memory = self.extrema_memory[-(self.buffer_size // 4):] + + if len(self.price_movement_memory) > self.buffer_size // 4: + self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):] def act(self, state: np.ndarray, explore=True) -> int: """Choose action using epsilon-greedy policy with explore flag""" @@ -209,7 +264,7 @@ class DQNAgent: # Get predictions using the policy network self.policy_net.eval() # Set to evaluation mode for inference - action_probs, extrema_pred = self.policy_net(state_tensor) + action_probs, extrema_pred, price_predictions = self.policy_net(state_tensor) self.policy_net.train() # Back to training mode # Get the predicted extrema class (0=bottom, 1=top, 2=neither) @@ -221,17 +276,81 @@ class DQNAgent: extrema_type = "BOTTOM" if extrema_class == 0 else "TOP" if extrema_class == 1 else "NEITHER" logger.info(f"High confidence {extrema_type} detected! 
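For reference, a minimal sketch of the storage rule the remember() hunk above adds, pulled out as a standalone helper. The thresholds (0.002, 0.005) and the two extra copies mirror the diff; the helper and argument names are hypothetical.

# --- illustrative sketch, not part of the patch ---
def store_with_price_priority(memory, price_movement_memory, experience,
                              state, next_state,
                              significant=0.002, very_significant=0.005):
    """Append experience; oversample it when the step's price move is large."""
    memory.append(experience)
    # Price is assumed to sit in the last feature of the (1D or 2D) state.
    current_price = state[-1, -1] if state.ndim > 1 else state[-1]
    next_price = next_state[-1, -1] if next_state.ndim > 1 else next_state[-1]
    if current_price == 0:
        return  # avoid division by zero on malformed rows
    change = (next_price - current_price) / current_price
    if abs(change) > significant:
        price_movement_memory.append(experience)
    if abs(change) > very_significant:
        memory.extend([experience] * 2)  # duplicate very strong moves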
Confidence: {extrema_confidence:.4f}") - # Store extrema prediction for the environment to use + # Process price predictions + price_immediate = torch.softmax(price_predictions['immediate'], dim=1) + price_midterm = torch.softmax(price_predictions['midterm'], dim=1) + price_longterm = torch.softmax(price_predictions['longterm'], dim=1) + price_values = price_predictions['values'] + + # Get predicted direction for each timeframe (0=down, 1=sideways, 2=up) + immediate_direction = price_immediate.argmax(dim=1).item() + midterm_direction = price_midterm.argmax(dim=1).item() + longterm_direction = price_longterm.argmax(dim=1).item() + + # Get confidence levels + immediate_conf = price_immediate[0, immediate_direction].item() + midterm_conf = price_midterm[0, midterm_direction].item() + longterm_conf = price_longterm[0, longterm_direction].item() + + # Get predicted price change percentages + price_changes = price_values[0].tolist() + + # Log significant price movement predictions + timeframes = ["1s/1m", "1h", "1d", "1w"] + directions = ["DOWN", "SIDEWAYS", "UP"] + + for i, (direction, conf) in enumerate([ + (immediate_direction, immediate_conf), + (midterm_direction, midterm_conf), + (longterm_direction, longterm_conf) + ]): + if conf > 0.7 and direction != 1: # Only log high confidence non-sideways predictions + logger.info(f"Price prediction: {timeframes[i]} -> {directions[direction]}, " + f"Confidence: {conf:.4f}, Expected change: {price_changes[i]:.2f}%") + + # Store predictions for environment to use self.last_extrema_pred = { 'class': extrema_class, 'confidence': extrema_confidence, 'raw': extrema_pred.cpu().numpy() } + self.last_price_pred = { + 'immediate': { + 'direction': immediate_direction, + 'confidence': immediate_conf, + 'change': price_changes[0] + }, + 'midterm': { + 'direction': midterm_direction, + 'confidence': midterm_conf, + 'change': price_changes[1] + }, + 'longterm': { + 'direction': longterm_direction, + 'confidence': longterm_conf, + 'change': price_changes[2] + } + } + # Get the action with highest Q-value action = action_probs.argmax().item() - # Adjust action based on extrema prediction (with some probability) + # Adjust action based on extrema and price predictions + # Prioritize short-term movement for trading decisions + if immediate_conf > 0.8: # Only adjust for strong signals + if immediate_direction == 2: # UP prediction + # Bias toward BUY for strong up predictions + if action != 0 and random.random() < 0.3 * immediate_conf: + logger.info(f"Adjusting action to BUY based on immediate UP prediction") + action = 0 # BUY + elif immediate_direction == 0: # DOWN prediction + # Bias toward SELL for strong down predictions + if action != 1 and random.random() < 0.3 * immediate_conf: + logger.info(f"Adjusting action to SELL based on immediate DOWN prediction") + action = 1 # SELL + + # Also consider extrema detection for action adjustment if extrema_confidence > 0.8: # Only adjust for strong signals if extrema_class == 0: # Bottom detected # Bias toward BUY at bottoms @@ -307,53 +426,102 @@ class DQNAgent: next_states_tensor, dones_tensor ) - # Occasionally train specifically on extrema points, if we have enough - if hasattr(self, 'extrema_memory') and len(self.extrema_memory) >= self.batch_size // 2: - if random.random() < 0.3: # 30% chance to do extra extrema training - # Sample from extrema memory - extrema_batch_size = min(self.batch_size // 2, len(self.extrema_memory)) - extrema_batch = random.sample(self.extrema_memory, extrema_batch_size) + # Training focus 
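The action-adjustment logic above, condensed into one helper to make the probability gating explicit. The action codes (0=BUY, 1=SELL), the 0.8 confidence gate, and the 0.3 scaling all come from the diff; the helper itself is hypothetical.

# --- illustrative sketch, not part of the patch ---
import random

def bias_action(action, direction, confidence, gate=0.8, strength=0.3):
    """Nudge the greedy action toward a high-confidence predicted direction."""
    if confidence <= gate:
        return action
    if direction == 2 and action != 0 and random.random() < strength * confidence:
        return 0  # BUY on a confident UP prediction
    if direction == 0 and action != 1 and random.random() < strength * confidence:
        return 1  # SELL on a confident DOWN prediction
    return action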
selector - randomly focus on one of the specialized training types + training_focus = random.random() + + # Occasionally train specifically on extrema points + if training_focus < 0.3 and hasattr(self, 'extrema_memory') and len(self.extrema_memory) >= self.batch_size // 2: + # Sample from extrema memory + extrema_batch_size = min(self.batch_size // 2, len(self.extrema_memory)) + extrema_batch = random.sample(self.extrema_memory, extrema_batch_size) + + # Extract batches with proper tensor conversion + extrema_states = np.vstack([self._normalize_state(x[0]) for x in extrema_batch]) + extrema_actions = np.array([x[1] for x in extrema_batch]) + extrema_rewards = np.array([x[2] for x in extrema_batch]) + extrema_next_states = np.vstack([self._normalize_state(x[3]) for x in extrema_batch]) + extrema_dones = np.array([x[4] for x in extrema_batch], dtype=np.float32) + + # Convert to torch tensors and move to device + extrema_states_tensor = torch.FloatTensor(extrema_states).to(self.device) + extrema_actions_tensor = torch.LongTensor(extrema_actions).to(self.device) + extrema_rewards_tensor = torch.FloatTensor(extrema_rewards).to(self.device) + extrema_next_states_tensor = torch.FloatTensor(extrema_next_states).to(self.device) + extrema_dones_tensor = torch.FloatTensor(extrema_dones).to(self.device) + + # Additional training step focused on extrema points (with smaller learning rate) + original_lr = self.optimizer.param_groups[0]['lr'] + # Temporarily reduce learning rate for fine-tuning on extrema + for param_group in self.optimizer.param_groups: + param_group['lr'] = original_lr * 0.5 - # Extract batches with proper tensor conversion - extrema_states = np.vstack([self._normalize_state(x[0]) for x in extrema_batch]) - extrema_actions = np.array([x[1] for x in extrema_batch]) - extrema_rewards = np.array([x[2] for x in extrema_batch]) - extrema_next_states = np.vstack([self._normalize_state(x[3]) for x in extrema_batch]) - extrema_dones = np.array([x[4] for x in extrema_batch], dtype=np.float32) + # Train on extrema + if self.use_mixed_precision: + extrema_loss = self._replay_mixed_precision( + extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor, + extrema_next_states_tensor, extrema_dones_tensor + ) + else: + extrema_loss = self._replay_standard( + extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor, + extrema_next_states_tensor, extrema_dones_tensor + ) + + # Restore original learning rate + for param_group in self.optimizer.param_groups: + param_group['lr'] = original_lr + + logger.info(f"Extra training on extrema points: loss={extrema_loss:.4f}") + + # Average the loss + loss = (loss + extrema_loss) / 2 + + # Occasionally train specifically on price movement data + elif training_focus >= 0.3 and training_focus < 0.6 and hasattr(self, 'price_movement_memory') and len(self.price_movement_memory) >= self.batch_size // 2: + # Sample from price movement memory + price_batch_size = min(self.batch_size // 2, len(self.price_movement_memory)) + price_batch = random.sample(self.price_movement_memory, price_batch_size) + + # Extract batches with proper tensor conversion + price_states = np.vstack([self._normalize_state(x[0]) for x in price_batch]) + price_actions = np.array([x[1] for x in price_batch]) + price_rewards = np.array([x[2] for x in price_batch]) + price_next_states = np.vstack([self._normalize_state(x[3]) for x in price_batch]) + price_dones = np.array([x[4] for x in price_batch], dtype=np.float32) + + # Convert to torch tensors and move to device + 
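The halve-then-restore learning-rate dance around these fine-tuning steps is easy to get wrong if the training call raises. A context manager keeps the restore on all paths; scaled_lr is a hypothetical helper, not part of the patch.

# --- illustrative sketch, not part of the patch ---
from contextlib import contextmanager

@contextmanager
def scaled_lr(optimizer, factor=0.5):
    """Temporarily scale all parameter-group learning rates."""
    originals = [group['lr'] for group in optimizer.param_groups]
    for group in optimizer.param_groups:
        group['lr'] *= factor
    try:
        yield
    finally:
        for group, lr in zip(optimizer.param_groups, originals):
            group['lr'] = lr

# usage: with scaled_lr(self.optimizer): loss = self._replay_standard(...)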
price_states_tensor = torch.FloatTensor(price_states).to(self.device) + price_actions_tensor = torch.LongTensor(price_actions).to(self.device) + price_rewards_tensor = torch.FloatTensor(price_rewards).to(self.device) + price_next_states_tensor = torch.FloatTensor(price_next_states).to(self.device) + price_dones_tensor = torch.FloatTensor(price_dones).to(self.device) + + # Additional training step focused on price movements (with smaller learning rate) + original_lr = self.optimizer.param_groups[0]['lr'] + # Temporarily reduce learning rate + for param_group in self.optimizer.param_groups: + param_group['lr'] = original_lr * 0.5 - # Convert to torch tensors and move to device - extrema_states_tensor = torch.FloatTensor(extrema_states).to(self.device) - extrema_actions_tensor = torch.LongTensor(extrema_actions).to(self.device) - extrema_rewards_tensor = torch.FloatTensor(extrema_rewards).to(self.device) - extrema_next_states_tensor = torch.FloatTensor(extrema_next_states).to(self.device) - extrema_dones_tensor = torch.FloatTensor(extrema_dones).to(self.device) - - # Additional training step focused on extrema points (with smaller learning rate) - original_lr = self.optimizer.param_groups[0]['lr'] - # Temporarily reduce learning rate for fine-tuning on extrema - for param_group in self.optimizer.param_groups: - param_group['lr'] = original_lr * 0.5 - - # Train on extrema - if self.use_mixed_precision: - extrema_loss = self._replay_mixed_precision( - extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor, - extrema_next_states_tensor, extrema_dones_tensor - ) - else: - extrema_loss = self._replay_standard( - extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor, - extrema_next_states_tensor, extrema_dones_tensor - ) - - # Restore original learning rate - for param_group in self.optimizer.param_groups: - param_group['lr'] = original_lr - - logger.info(f"Extra training on extrema points: loss={extrema_loss:.4f}") - - # Average the loss - loss = (loss + extrema_loss) / 2 + # Train on price movement data + if self.use_mixed_precision: + price_loss = self._replay_mixed_precision( + price_states_tensor, price_actions_tensor, price_rewards_tensor, + price_next_states_tensor, price_dones_tensor + ) + else: + price_loss = self._replay_standard( + price_states_tensor, price_actions_tensor, price_rewards_tensor, + price_next_states_tensor, price_dones_tensor + ) + + # Restore original learning rate + for param_group in self.optimizer.param_groups: + param_group['lr'] = original_lr + + logger.info(f"Extra training on price movement data: loss={price_loss:.4f}") + + # Average the loss + loss = (loss + price_loss) / 2 # Store and return loss self.losses.append(loss) @@ -365,12 +533,12 @@ class DQNAgent: self.optimizer.zero_grad() # Get current Q values and extrema predictions - current_q_values, current_extrema_pred = self.policy_net(states) + current_q_values, current_extrema_pred, current_price_pred = self.policy_net(states) current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1) # Get next Q values from target network with torch.no_grad(): - next_q_values, next_extrema_pred = self.target_net(next_states) + next_q_values, next_extrema_pred, next_price_pred = self.target_net(next_states) next_q_values = next_q_values.max(1)[0] # Check for dimension mismatch and fix it @@ -389,13 +557,10 @@ class DQNAgent: # Compute Q-value loss (primary task) q_loss = nn.MSELoss()(current_q_values, target_q_values) - # Create extrema labels from price movements (crude 
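The core DQN objective computed in _replay_standard, shown in isolation. The networks here return (q_values, extrema, price_predictions) per the diff, so [0] selects the Q-values; gamma and the (1 - dones) masking of terminal states are assumed to match the agent's unchanged target computation.

# --- illustrative sketch, not part of the patch ---
import torch
import torch.nn as nn

def dqn_q_loss(policy_net, target_net, states, actions, rewards,
               next_states, dones, gamma=0.99):
    q = policy_net(states)[0].gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_q = target_net(next_states)[0].max(1)[0]
    target = rewards + (1.0 - dones) * gamma * next_q
    return nn.MSELoss()(q, target)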
approximation) - # If the next state price is higher than current, we might be in an uptrend (not a bottom) - # If the next state price is lower than current, we might be in a downtrend (not a top) - # This is a simplified approximation; in real scenarios we'd want to use actual extrema detection + # Initialize combined loss with Q-value loss + loss = q_loss # Try to extract price from current and next states - # Assuming price is in the last feature try: # Extract price feature from sequence data (if available) if len(states.shape) == 3: # [batch, seq, features] @@ -405,43 +570,99 @@ class DQNAgent: current_prices = states[:, -1] # Last feature next_prices = next_states[:, -1] - # Compute price changes - price_changes = (next_prices - current_prices) / current_prices + # Compute price changes for different timeframes + immediate_changes = (next_prices - current_prices) / current_prices - # Create crude extrema labels: - # 0 = bottom: Large negative price change followed by positive change - # 1 = top: Large positive price change followed by negative change - # 2 = neither: Small or inconsistent changes + # Create price direction labels - simplified for training + # 0 = down, 1 = sideways, 2 = up + immediate_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 # Default: sideways + midterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 + longterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 - # Classify based on price change magnitude - extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither + # Immediate term direction (1s, 1m) + immediate_up = (immediate_changes > 0.0005) + immediate_down = (immediate_changes < -0.0005) + immediate_labels[immediate_up] = 2 # Up + immediate_labels[immediate_down] = 0 # Down - # Identify potential bottoms (significant negative change) - bottoms = (price_changes < -0.003) - extrema_labels[bottoms] = 0 + # For mid and long term, we can only approximate during training + # In a real system, we'd need historical data to validate these + # Here we'll use the immediate term with increasing thresholds as approximation - # Identify potential tops (significant positive change) - tops = (price_changes > 0.003) - extrema_labels[tops] = 1 + # Mid-term (1h) - use slightly higher threshold + midterm_up = (immediate_changes > 0.001) + midterm_down = (immediate_changes < -0.001) + midterm_labels[midterm_up] = 2 # Up + midterm_labels[midterm_down] = 0 # Down - # Calculate extrema prediction loss (auxiliary task) - if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size: - current_extrema_pred = current_extrema_pred[:min_size] - extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels) + # Long-term (1d) - use even higher threshold + longterm_up = (immediate_changes > 0.002) + longterm_down = (immediate_changes < -0.002) + longterm_labels[longterm_up] = 2 # Up + longterm_labels[longterm_down] = 0 # Down + + # Generate target values for price change regression + # For simplicity, we'll use the immediate change and scaled versions for longer timeframes + price_value_targets = torch.zeros((min_size, 4), device=self.device) + price_value_targets[:, 0] = immediate_changes + price_value_targets[:, 1] = immediate_changes * 2.0 # Approximate 1h change + price_value_targets[:, 2] = immediate_changes * 4.0 # Approximate 1d change + price_value_targets[:, 3] = immediate_changes * 6.0 # Approximate 1w change + + # Calculate loss 
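The three labeling passes above differ only by threshold, so they reduce to one vectorized helper (0=down, 1=sideways, 2=up). Thresholds follow the diff; note that all three label sets derive from the same one-step change, as the comments above concede.

# --- illustrative sketch, not part of the patch ---
import torch

def direction_labels(changes: torch.Tensor, threshold: float) -> torch.Tensor:
    labels = torch.ones_like(changes, dtype=torch.long)  # default: sideways
    labels[changes > threshold] = 2   # up
    labels[changes < -threshold] = 0  # down
    return labels

# immediate_labels = direction_labels(immediate_changes, 0.0005)
# midterm_labels   = direction_labels(immediate_changes, 0.001)
# longterm_labels  = direction_labels(immediate_changes, 0.002)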
for price direction prediction (classification) + if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= min_size: + # Slice predictions to match the adjusted batch size + immediate_pred = current_price_pred['immediate'][:min_size] + midterm_pred = current_price_pred['midterm'][:min_size] + longterm_pred = current_price_pred['longterm'][:min_size] + price_values_pred = current_price_pred['values'][:min_size] - # Combined loss (primary + auxiliary with lower weight) - # Typically auxiliary tasks should have lower weight to not dominate the primary task - loss = q_loss + 0.3 * extrema_loss + # Compute losses for each task + immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels) + midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels) + longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels) - # Log separate loss components occasionally - if random.random() < 0.01: # Log 1% of the time to avoid flood - logger.info(f"Training losses: Q-loss={q_loss.item():.4f}, Extrema-loss={extrema_loss.item():.4f}") - else: - # Fall back to just Q-value loss if extrema predictions aren't available - loss = q_loss + # MSE loss for price value regression + price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets) + + # Combine all price prediction losses + price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss + + # Create extrema labels (same as before) + extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither + + # Identify potential bottoms (significant negative change) + bottoms = (immediate_changes < -0.003) + extrema_labels[bottoms] = 0 + + # Identify potential tops (significant positive change) + tops = (immediate_changes > 0.003) + extrema_labels[tops] = 1 + + # Calculate extrema prediction loss + if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size: + current_extrema_pred = current_extrema_pred[:min_size] + extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels) + + # Combined loss with all components + # Primary task: Q-value learning (RL objective) + # Secondary tasks: extrema detection and price prediction (supervised objectives) + loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss + + # Log loss components occasionally + if random.random() < 0.01: # Log 1% of the time + logger.info( + f"Training losses: Q-loss={q_loss.item():.4f}, " + f"Extrema-loss={extrema_loss.item():.4f}, " + f"Price-loss={price_loss.item():.4f}, " + f"Imm-loss={immediate_loss.item():.4f}, " + f"Mid-loss={midterm_loss.item():.4f}, " + f"Long-loss={longterm_loss.item():.4f}" + ) except Exception as e: # Fallback if price extraction fails - logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.") + logger.warning(f"Failed to calculate price prediction loss: {str(e)}. 
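The full loss the branch above assembles, written out once. Keeping every auxiliary weight below 1.0 is what keeps Q-learning the primary objective; the weights are exactly those in the diff.

# --- illustrative sketch, not part of the patch ---
def combined_loss(q_loss, extrema_loss,
                  immediate_loss, midterm_loss, longterm_loss, value_loss,
                  aux_weight=0.3):
    price_loss = (immediate_loss + 0.7 * midterm_loss
                  + 0.5 * longterm_loss + 0.3 * value_loss)
    return q_loss + aux_weight * extrema_loss + aux_weight * price_loss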
Using only Q-value loss.") + # Just use Q-value loss loss = q_loss # Backward pass and optimize @@ -475,12 +696,12 @@ class DQNAgent: # Forward pass with amp autocasting with torch.cuda.amp.autocast(): # Get current Q values and extrema predictions - current_q_values, current_extrema_pred = self.policy_net(states) + current_q_values, current_extrema_pred, current_price_pred = self.policy_net(states) current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1) # Get next Q values from target network with torch.no_grad(): - next_q_values, next_extrema_pred = self.target_net(next_states) + next_q_values, next_extrema_pred, next_price_pred = self.target_net(next_states) next_q_values = next_q_values.max(1)[0] # Check for dimension mismatch and fix it @@ -499,7 +720,9 @@ class DQNAgent: # Compute Q-value loss (primary task) q_loss = nn.MSELoss()(current_q_values, target_q_values) - # Create extrema labels from price movements (crude approximation) + # Initialize loss with q_loss + loss = q_loss + # Try to extract price from current and next states try: # Extract price feature from sequence data (if available) @@ -510,42 +733,96 @@ class DQNAgent: current_prices = states[:, -1] # Last feature next_prices = next_states[:, -1] - # Compute price changes - price_changes = (next_prices - current_prices) / current_prices + # Compute price changes for different timeframes + immediate_changes = (next_prices - current_prices) / current_prices - # Create crude extrema labels: - # 0 = bottom: Large negative price change followed by positive change - # 1 = top: Large positive price change followed by negative change - # 2 = neither: Small or inconsistent changes + # Create price direction labels - simplified for training + # 0 = down, 1 = sideways, 2 = up + immediate_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 # Default: sideways + midterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 + longterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 - # Classify based on price change magnitude - extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither + # Immediate term direction (1s, 1m) + immediate_up = (immediate_changes > 0.0005) + immediate_down = (immediate_changes < -0.0005) + immediate_labels[immediate_up] = 2 # Up + immediate_labels[immediate_down] = 0 # Down - # Identify potential bottoms (significant negative change) - bottoms = (price_changes < -0.003) - extrema_labels[bottoms] = 0 + # For mid and long term, we can only approximate during training + # In a real system, we'd need historical data to validate these + # Here we'll use the immediate term with increasing thresholds as approximation - # Identify potential tops (significant positive change) - tops = (price_changes > 0.003) - extrema_labels[tops] = 1 + # Mid-term (1h) - use slightly higher threshold + midterm_up = (immediate_changes > 0.001) + midterm_down = (immediate_changes < -0.001) + midterm_labels[midterm_up] = 2 # Up + midterm_labels[midterm_down] = 0 # Down - # Calculate extrema prediction loss (auxiliary task) - if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size: - current_extrema_pred = current_extrema_pred[:min_size] - extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels) + # Long-term (1d) - use even higher threshold + longterm_up = (immediate_changes > 0.002) + longterm_down = (immediate_changes < -0.002) + longterm_labels[longterm_up] = 2 # 
Up + longterm_labels[longterm_down] = 0 # Down + + # Generate target values for price change regression + # For simplicity, we'll use the immediate change and scaled versions for longer timeframes + price_value_targets = torch.zeros((min_size, 4), device=self.device) + price_value_targets[:, 0] = immediate_changes + price_value_targets[:, 1] = immediate_changes * 2.0 # Approximate 1h change + price_value_targets[:, 2] = immediate_changes * 4.0 # Approximate 1d change + price_value_targets[:, 3] = immediate_changes * 6.0 # Approximate 1w change + + # Calculate loss for price direction prediction (classification) + if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= min_size: + # Slice predictions to match the adjusted batch size + immediate_pred = current_price_pred['immediate'][:min_size] + midterm_pred = current_price_pred['midterm'][:min_size] + longterm_pred = current_price_pred['longterm'][:min_size] + price_values_pred = current_price_pred['values'][:min_size] - # Combined loss (primary + auxiliary with lower weight) - loss = q_loss + 0.3 * extrema_loss + # Compute losses for each task + immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels) + midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels) + longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels) - # Log separate loss components occasionally - if random.random() < 0.01: # Log 1% of the time to avoid flood - logger.info(f"Mixed precision training losses: Q-loss={q_loss.item():.4f}, Extrema-loss={extrema_loss.item():.4f}") - else: - # Fall back to just Q-value loss - loss = q_loss + # MSE loss for price value regression + price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets) + + # Combine all price prediction losses + price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss + + # Create extrema labels (same as before) + extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither + + # Identify potential bottoms (significant negative change) + bottoms = (immediate_changes < -0.003) + extrema_labels[bottoms] = 0 + + # Identify potential tops (significant positive change) + tops = (immediate_changes > 0.003) + extrema_labels[tops] = 1 + + # Calculate extrema prediction loss + if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size: + current_extrema_pred = current_extrema_pred[:min_size] + extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels) + + # Combined loss with all components + # Primary task: Q-value learning (RL objective) + # Secondary tasks: extrema detection and price prediction (supervised objectives) + loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss + + # Log loss components occasionally + if random.random() < 0.01: # Log 1% of the time + logger.info( + f"Mixed precision losses: Q-loss={q_loss.item():.4f}, " + f"Extrema-loss={extrema_loss.item():.4f}, " + f"Price-loss={price_loss.item():.4f}" + ) except Exception as e: # Fallback if price extraction fails - logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.") + logger.warning(f"Failed to calculate price prediction loss: {str(e)}. 
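For the mixed-precision path, the GradScaler ordering matters: backward on the scaled loss, unscale before clipping, then step and update. A compact sketch of that pattern (the same sequence the pre-training loop further down uses); scaler is assumed to be a torch.cuda.amp.GradScaler.

# --- illustrative sketch, not part of the patch ---
import torch

def amp_step(model, optimizer, scaler, compute_loss):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = compute_loss()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # so clipping sees true gradient magnitudes
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()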
Using only Q-value loss.") + # Just use Q-value loss loss = q_loss # Backward pass with scaled gradients diff --git a/NN/models/simple_cnn.py b/NN/models/simple_cnn.py index 19a8292..1bb91ea 100644 --- a/NN/models/simple_cnn.py +++ b/NN/models/simple_cnn.py @@ -125,6 +125,14 @@ class SimpleCNN(nn.Module): # Extrema detection head self.extrema_head = nn.Linear(256, 3) # 0=bottom, 1=top, 2=neither + + # Price prediction heads for different timeframes + self.price_pred_immediate = nn.Linear(256, 3) # Up, Down, Sideways for immediate term (1s, 1m) + self.price_pred_midterm = nn.Linear(256, 3) # Up, Down, Sideways for mid-term (1h) + self.price_pred_longterm = nn.Linear(256, 3) # Up, Down, Sideways for long-term (1d) + + # Regression heads for exact price prediction + self.price_pred_value = nn.Linear(256, 4) # Predicts % change for each timeframe (1s, 1m, 1h, 1d) def _check_rebuild_network(self, features): """Check if network needs to be rebuilt for different feature dimensions""" @@ -140,7 +148,7 @@ class SimpleCNN(nn.Module): def forward(self, x): """ Forward pass through the network - Returns both action values and extrema predictions + Returns action values, extrema predictions, and price movement predictions for multiple timeframes """ # Handle different input shapes if len(x.shape) == 2: # [batch_size, features] @@ -173,7 +181,50 @@ class SimpleCNN(nn.Module): # Extrema predictions extrema_pred = self.extrema_head(fc_out) - return action_values, extrema_pred + # Price movement predictions for different timeframes + price_immediate = self.price_pred_immediate(fc_out) # 1s, 1m + price_midterm = self.price_pred_midterm(fc_out) # 1h + price_longterm = self.price_pred_longterm(fc_out) # 1d + + # Regression values for exact price predictions (percentage changes) + price_values = self.price_pred_value(fc_out) + + # Return all predictions in a structured dictionary + price_predictions = { + 'immediate': price_immediate, + 'midterm': price_midterm, + 'longterm': price_longterm, + 'values': price_values + } + + return action_values, extrema_pred, price_predictions + + def save(self, path): + """Save model weights and architecture""" + os.makedirs(os.path.dirname(path), exist_ok=True) + torch.save({ + 'state_dict': self.state_dict(), + 'input_shape': self.input_shape, + 'n_actions': self.n_actions, + 'feature_dim': self.feature_dim + }, f"{path}.pt") + logger.info(f"Model saved to {path}.pt") + + def load(self, path): + """Load model weights and architecture""" + try: + checkpoint = torch.load(f"{path}.pt", map_location=self.device) + self.input_shape = checkpoint['input_shape'] + self.n_actions = checkpoint['n_actions'] + self.feature_dim = checkpoint['feature_dim'] + self._build_network() + self.load_state_dict(checkpoint['state_dict']) + self.to(self.device) + logger.info(f"Model loaded from {path}.pt") + return True + except Exception as e: + logger.error(f"Error loading model: {str(e)}") + return False class CNNModelPyTorch(nn.Module): """ diff --git a/NN/train_rl.py b/NN/train_rl.py index 807b5a8..6cff79c 100644 --- a/NN/train_rl.py +++ b/NN/train_rl.py @@ -9,6 +9,9 @@ import sys import pandas as pd import gym import json +import random +import torch.nn as nn +import contextlib # Add parent directory to path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -49,19 +52,36 @@ class RLTradingEnvironment(gym.Env): Reinforcement Learning environment for trading with technical indicators from multiple timeframes """ - def __init__(self, features_1m, features_5m, 
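The four heads SimpleCNN gains above, grouped as a detachable module to show the output contract in one place. The 256 feature width and the 3/3/3/4 output sizes come from the diff; the class itself is hypothetical.

# --- illustrative sketch, not part of the patch ---
import torch.nn as nn

class PricePredictionHeads(nn.Module):
    def __init__(self, feature_dim=256):
        super().__init__()
        self.immediate = nn.Linear(feature_dim, 3)  # down / sideways / up
        self.midterm = nn.Linear(feature_dim, 3)
        self.longterm = nn.Linear(feature_dim, 3)
        self.values = nn.Linear(feature_dim, 4)     # % change per timeframe

    def forward(self, shared_features):
        return {
            'immediate': self.immediate(shared_features),
            'midterm': self.midterm(shared_features),
            'longterm': self.longterm(shared_features),
            'values': self.values(shared_features),
        }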
features_15m, window_size=20, trading_fee=0.0025, min_trade_interval=15): + def __init__(self, features_1m, features_1h=None, features_1d=None, window_size=20, trading_fee=0.0025, min_trade_interval=15): super().__init__() # Initialize attributes before parent class self.window_size = window_size self.num_features = features_1m.shape[1] - 1 # Exclude close price - self.num_timeframes = 3 # 1m, 5m, 15m + + # Count available timeframes + self.num_timeframes = 1 # Always have 1m + if features_1h is not None: + self.num_timeframes += 1 + if features_1d is not None: + self.num_timeframes += 1 + self.feature_dim = self.num_features * self.num_timeframes # Store features from different timeframes self.features_1m = features_1m - self.features_5m = features_5m - self.features_15m = features_15m + self.features_1h = features_1h + self.features_1d = features_1d + + # Create synthetic 1s data from 1m (for demo purposes) + self.features_1s = self._create_synthetic_1s_data(features_1m) + + # If higher timeframes are missing, create synthetic data + if self.features_1h is None: + self.features_1h = self._create_synthetic_hourly_data(features_1m) + + if self.features_1d is None: + self.features_1d = self._create_synthetic_daily_data(features_1h) # Trading parameters self.initial_balance = 1.0 @@ -83,6 +103,45 @@ class RLTradingEnvironment(gym.Env): # Callback for visualization or external monitoring self.action_callback = None + def _create_synthetic_1s_data(self, features_1m): + """Create synthetic 1-second data from 1-minute data""" + # Simple approach: duplicate each 1m candle for 60 seconds with some noise + num_samples = features_1m.shape[0] + synthetic_1s = np.zeros((num_samples * 60, features_1m.shape[1])) + + for i in range(num_samples): + for j in range(60): + idx = i * 60 + j + if idx < synthetic_1s.shape[0]: + # Copy the 1m data with small random noise + synthetic_1s[idx] = features_1m[i] * (1 + np.random.normal(0, 0.0001, features_1m.shape[1])) + + return synthetic_1s + + def _create_synthetic_hourly_data(self, features_1m): + """Create synthetic hourly data from minute data""" + # Group by hour, taking every 60th candle + num_samples = features_1m.shape[0] // 60 + synthetic_1h = np.zeros((num_samples, features_1m.shape[1])) + + for i in range(num_samples): + if i * 60 < features_1m.shape[0]: + synthetic_1h[i] = features_1m[i * 60] + + return synthetic_1h + + def _create_synthetic_daily_data(self, features_1h): + """Create synthetic daily data from hourly data""" + # Group by day, taking every 24th candle + num_samples = features_1h.shape[0] // 24 + synthetic_1d = np.zeros((num_samples, features_1h.shape[1])) + + for i in range(num_samples): + if i * 24 < features_1h.shape[0]: + synthetic_1d[i] = features_1h[i * 24] + + return synthetic_1d + def reset(self): """Reset the environment to initial state""" self.balance = self.initial_balance @@ -104,35 +163,43 @@ class RLTradingEnvironment(gym.Env): Combine features from multiple timeframes, reshaped for the CNN. 
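The nested loop in _create_synthetic_1s_data above can be vectorized; this does the same duplicate-with-noise upsampling in two NumPy calls. The 60x repetition and the 0.0001 noise scale match the diff.

# --- illustrative sketch, not part of the patch ---
import numpy as np

def synthetic_1s(features_1m: np.ndarray, noise_std: float = 0.0001) -> np.ndarray:
    repeated = np.repeat(features_1m, 60, axis=0)  # 60 "seconds" per 1m bar
    noise = 1.0 + np.random.normal(0.0, noise_std, repeated.shape)
    return repeated * noise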
""" # Calculate indices for each timeframe - idx_1m = self.current_step - idx_5m = idx_1m // 5 - idx_15m = idx_1m // 15 + idx_1m = min(self.current_step, self.features_1m.shape[0] - 1) + idx_1h = idx_1m // 60 # 60 minutes in an hour + idx_1d = idx_1h // 24 # 24 hours in a day + + # Cap indices to prevent out of bounds + idx_1h = min(idx_1h, self.features_1h.shape[0] - 1) + idx_1d = min(idx_1d, self.features_1d.shape[0] - 1) # Extract feature windows from each timeframe - window_1m = self.features_1m[idx_1m - self.window_size:idx_1m] + window_1m = self.features_1m[max(0, idx_1m - self.window_size):idx_1m] - # Handle 5m timeframe - start_5m = max(0, idx_5m - self.window_size) - window_5m = self.features_5m[start_5m:idx_5m] + # Handle hourly timeframe + start_1h = max(0, idx_1h - self.window_size) + window_1h = self.features_1h[start_1h:idx_1h] - # Handle 15m timeframe - start_15m = max(0, idx_15m - self.window_size) - window_15m = self.features_15m[start_15m:idx_15m] + # Handle daily timeframe + start_1d = max(0, idx_1d - self.window_size) + window_1d = self.features_1d[start_1d:idx_1d] - # Pad if needed (for 5m and 15m) - if len(window_5m) < self.window_size: - padding = np.zeros((self.window_size - len(window_5m), window_5m.shape[1])) - window_5m = np.vstack([padding, window_5m]) + # Pad if needed (for higher timeframes) + if len(window_1m) < self.window_size: + padding = np.zeros((self.window_size - len(window_1m), window_1m.shape[1])) + window_1m = np.vstack([padding, window_1m]) - if len(window_15m) < self.window_size: - padding = np.zeros((self.window_size - len(window_15m), window_15m.shape[1])) - window_15m = np.vstack([padding, window_15m]) + if len(window_1h) < self.window_size: + padding = np.zeros((self.window_size - len(window_1h), window_1h.shape[1])) + window_1h = np.vstack([padding, window_1h]) + + if len(window_1d) < self.window_size: + padding = np.zeros((self.window_size - len(window_1d), window_1d.shape[1])) + window_1d = np.vstack([padding, window_1d]) # Combine features from all timeframes combined_features = np.hstack([ window_1m.reshape(self.window_size, -1), - window_5m.reshape(self.window_size, -1), - window_15m.reshape(self.window_size, -1) + window_1h.reshape(self.window_size, -1), + window_1d.reshape(self.window_size, -1) ]) # Convert to float32 and handle any NaN values @@ -152,7 +219,14 @@ class RLTradingEnvironment(gym.Env): """ # Get current and next price current_price = self.features_1m[self.current_step, -1] # Close price is last column - next_price = self.features_1m[self.current_step + 1, -1] + + # Check if we're at the end of the data + if self.current_step + 1 >= len(self.features_1m): + next_price = current_price # Use current price if at the end + done = True + else: + next_price = self.features_1m[self.current_step + 1, -1] + done = False # Handle zero or negative prices if current_price <= 0: @@ -164,7 +238,6 @@ class RLTradingEnvironment(gym.Env): # Default reward is slightly negative to discourage inaction reward = -0.0001 - done = False profit_pct = None # Initialize profit_pct variable # Check if enough time has passed since last trade @@ -225,7 +298,7 @@ class RLTradingEnvironment(gym.Env): # Move to next step self.current_step += 1 - # Check if done + # Check if done (reached end of data) if self.current_step >= len(self.features_1m) - 1: done = True @@ -251,6 +324,26 @@ class RLTradingEnvironment(gym.Env): gain = (total_value - self.initial_balance) / self.initial_balance self.win_rate = self.wins / max(1, self.trades) + # Check if we have 
prediction data for future timeframes + future_price_1h = None + future_price_1d = None + + # Get hourly index + idx_1h = self.current_step // 60 + if idx_1h + 1 < len(self.features_1h): + hourly_close_idx = self.features_1h.shape[1] - 1 # Assuming close is last column + current_1h_price = self.features_1h[idx_1h, hourly_close_idx] + next_1h_price = self.features_1h[idx_1h + 1, hourly_close_idx] + future_price_1h = (next_1h_price - current_1h_price) / current_1h_price + + # Get daily index + idx_1d = idx_1h // 24 + if idx_1d + 1 < len(self.features_1d): + daily_close_idx = self.features_1d.shape[1] - 1 # Assuming close is last column + current_1d_price = self.features_1d[idx_1d, daily_close_idx] + next_1d_price = self.features_1d[idx_1d + 1, daily_close_idx] + future_price_1d = (next_1d_price - current_1d_price) / current_1d_price + info = { 'balance': self.balance, 'position': self.position, @@ -260,7 +353,9 @@ class RLTradingEnvironment(gym.Env): 'win_rate': self.win_rate, 'profit_pct': profit_pct if action == 1 and self.position == 0 else None, 'current_price': current_price, - 'next_price': next_price + 'next_price': next_price, + 'future_price_1h': future_price_1h, # Actual future hourly price change + 'future_price_1d': future_price_1d # Actual future daily price change } # Call the callback if it exists @@ -279,7 +374,8 @@ class RLTradingEnvironment(gym.Env): self.action_callback = callback def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent", - action_callback=None, episode_callback=None, symbol="BTC/USDT"): + action_callback=None, episode_callback=None, symbol="BTC/USDT", + pretrain_price_prediction_enabled=True, pretrain_epochs=10): """ Train a reinforcement learning agent for trading @@ -291,6 +387,8 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo action_callback: Callback function for monitoring actions episode_callback: Callback function for monitoring episodes symbol: Trading symbol to use + pretrain_price_prediction_enabled: Whether to pre-train price prediction + pretrain_epochs: Number of epochs for pre-training Returns: tuple: (trained agent, environment) @@ -396,6 +494,30 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo else: logger.info("No existing model found. Starting with a new model.") + # Pre-train price prediction if enabled and we have a new model + if pretrain_price_prediction_enabled: + if not os.path.exists(model_file) or input("Pre-train price prediction? 
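The future_price_1h/1d fields added to `info` above are simple next-candle lookups; isolated, the rule is a bounds check followed by a close-to-close change. The close column is assumed to be the last feature, as in the diff.

# --- illustrative sketch, not part of the patch ---
def realized_future_return(features, idx):
    """Close-to-close change of the candle after idx, or None at the data edge."""
    if idx + 1 >= len(features):
        return None
    close_col = features.shape[1] - 1
    current, nxt = features[idx, close_col], features[idx + 1, close_col]
    return (nxt - current) / current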
(y/n): ").lower() == 'y': + logger.info("Pre-training price prediction capability...") + # Attempt to load hourly and daily data for pre-training + try: + data_interface.add_timeframe('1h') + data_interface.add_timeframe('1d') + + # Run pre-training + agent = pretrain_price_prediction( + agent=agent, + data_interface=data_interface, + n_epochs=pretrain_epochs, + batch_size=128 + ) + + # Save the pre-trained model + agent.save(f"{save_path}_pretrained") + logger.info("Pre-trained model saved.") + except Exception as e: + logger.error(f"Error during pre-training: {e}") + logger.warning("Continuing with RL training without pre-training.") + # Create TensorBoard writer writer = SummaryWriter(log_dir=f'runs/dqn_{int(time.time())}') @@ -499,5 +621,379 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo return agent, env +def generate_price_prediction_training_data(data_1m, data_1h, data_1d, window_size=20): + """ + Generate labeled training data for price prediction at different timeframes + + Args: + data_1m: DataFrame with 1-minute data + data_1h: DataFrame with 1-hour data + data_1d: DataFrame with 1-day data + window_size: Size of the input window + + Returns: + tuple: (X, y_immediate, y_midterm, y_longterm, y_values) + - X: input features (window sequences) + - y_immediate: immediate direction labels (0=down, 1=sideways, 2=up) + - y_midterm: mid-term direction labels + - y_longterm: long-term direction labels + - y_values: actual percentage changes for each timeframe + """ + logger.info("Generating price prediction training data from historical prices") + + # Prepare data structures + X = [] + y_immediate = [] # 1m + y_midterm = [] # 1h + y_longterm = [] # 1d + y_values = [] # Actual percentage changes + + # Calculate future returns for labeling + data_1m['future_return_1m'] = data_1m['close'].pct_change(1).shift(-1) # Next candle + data_1m['future_return_10m'] = data_1m['close'].pct_change(10).shift(-10) # Next 10 candles + + # Add indices to align data + data_1m['index'] = range(len(data_1m)) + data_1h['index'] = range(len(data_1h)) + data_1d['index'] = range(len(data_1d)) + + # Define thresholds for direction labels + immediate_threshold = 0.0005 + midterm_threshold = 0.001 + longterm_threshold = 0.002 + + # Loop through 1m data to create training samples + max_idx = len(data_1m) - window_size - 10 # Ensure we have future data for labels + sample_indices = random.sample(range(window_size, max_idx), min(10000, max_idx - window_size)) + + for idx in sample_indices: + # Get window of 1m data + window_1m = data_1m.iloc[idx-window_size:idx].drop(['timestamp', 'future_return_1m', 'future_return_10m', 'index'], axis=1, errors='ignore') + + # Skip if window contains NaN + if window_1m.isnull().values.any(): + continue + + # Get future returns for labeling + future_return_1m = data_1m.iloc[idx]['future_return_1m'] + future_return_10m = data_1m.iloc[idx]['future_return_10m'] + + # Find corresponding row in 1h data (closest timestamp) + current_timestamp = data_1m.iloc[idx]['timestamp'] + + # Find 1h candle for mid-term prediction + if 'timestamp' in data_1h.columns: + # Find closest 1h candle + closest_1h_idx = data_1h['timestamp'].searchsorted(current_timestamp) + if closest_1h_idx >= len(data_1h): + closest_1h_idx = len(data_1h) - 1 + + # Get future 1h return (next candle) + if closest_1h_idx < len(data_1h) - 1: + future_return_1h = (data_1h.iloc[closest_1h_idx + 1]['close'] - data_1h.iloc[closest_1h_idx]['close']) / data_1h.iloc[closest_1h_idx]['close'] + else: + 
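The 1h/1d label lookup above hinges on timestamp alignment; condensed, it is searchsorted plus a clamp plus a next-candle return. This assumes the timestamp column is sorted ascending, which searchsorted requires.

# --- illustrative sketch, not part of the patch ---
import pandas as pd

def next_candle_return(df: pd.DataFrame, ts) -> float:
    """Return of the candle after the one covering ts; 0.0 at the data edge."""
    i = min(int(df['timestamp'].searchsorted(ts)), len(df) - 1)
    if i >= len(df) - 1:
        return 0.0
    prev_close = df['close'].iloc[i]
    return (df['close'].iloc[i + 1] - prev_close) / prev_close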
future_return_1h = 0 + else: + future_return_1h = future_return_10m # Fallback + + # Find 1d candle for long-term prediction + if 'timestamp' in data_1d.columns: + # Find closest 1d candle + closest_1d_idx = data_1d['timestamp'].searchsorted(current_timestamp) + if closest_1d_idx >= len(data_1d): + closest_1d_idx = len(data_1d) - 1 + + # Get future 1d return (next candle) + if closest_1d_idx < len(data_1d) - 1: + future_return_1d = (data_1d.iloc[closest_1d_idx + 1]['close'] - data_1d.iloc[closest_1d_idx]['close']) / data_1d.iloc[closest_1d_idx]['close'] + else: + future_return_1d = 0 + else: + future_return_1d = future_return_1h * 2 # Fallback + + # Create direction labels + # 0=down, 1=sideways, 2=up + + # Immediate (1m) + if future_return_1m > immediate_threshold: + immediate_label = 2 # UP + elif future_return_1m < -immediate_threshold: + immediate_label = 0 # DOWN + else: + immediate_label = 1 # SIDEWAYS + + # Mid-term (1h) + if future_return_1h > midterm_threshold: + midterm_label = 2 # UP + elif future_return_1h < -midterm_threshold: + midterm_label = 0 # DOWN + else: + midterm_label = 1 # SIDEWAYS + + # Long-term (1d) + if future_return_1d > longterm_threshold: + longterm_label = 2 # UP + elif future_return_1d < -longterm_threshold: + longterm_label = 0 # DOWN + else: + longterm_label = 1 # SIDEWAYS + + # Store data + X.append(window_1m.values) + y_immediate.append(immediate_label) + y_midterm.append(midterm_label) + y_longterm.append(longterm_label) + y_values.append([future_return_1m, future_return_1h, future_return_1d, future_return_1d * 1.5]) # Add weekly estimate + + # Convert to numpy arrays + X = np.array(X) + y_immediate = np.array(y_immediate) + y_midterm = np.array(y_midterm) + y_longterm = np.array(y_longterm) + y_values = np.array(y_values) + + logger.info(f"Generated {len(X)} price prediction training samples") + + # Log class distribution + for name, y in [("Immediate", y_immediate), ("Mid-term", y_midterm), ("Long-term", y_longterm)]: + down = (y == 0).sum() + sideways = (y == 1).sum() + up = (y == 2).sum() + logger.info(f"{name} direction distribution: DOWN={down} ({down/len(y)*100:.1f}%), " + f"SIDEWAYS={sideways} ({sideways/len(y)*100:.1f}%), " + f"UP={up} ({up/len(y)*100:.1f}%)") + + return X, y_immediate, y_midterm, y_longterm, y_values + +def pretrain_price_prediction(agent, data_interface, n_epochs=10, batch_size=128): + """ + Pre-train the agent's price prediction capability on historical data + + Args: + agent: DQNAgent instance to train + data_interface: DataInterface instance for accessing data + n_epochs: Number of epochs for training + batch_size: Batch size for training + + Returns: + The agent with pre-trained price prediction capabilities + """ + logger.info("Starting supervised pre-training of price prediction") + + try: + # Load data for all required timeframes + data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=10000) + data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000) + data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500) + + # Check if data is available + if data_1m is None: + logger.warning("1m data not available for pre-training") + return agent + + if data_1h is None: + logger.warning("1h data not available, using synthesized data") + # Create synthetic 1h data from 1m data + data_1h = data_1m.iloc[::60].reset_index(drop=True).copy() # Take every 60th record + + if data_1d is None: + logger.warning("1d data not available, using synthesized data") + # Create synthetic 1d data 
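Both synthetic-timeframe fallbacks above are the same stride trick, sketched once below. Note this subsamples rows rather than aggregating them, so a real resample (OHLC max/min/last, summed volume) would be more faithful.

# --- illustrative sketch, not part of the patch ---
import pandas as pd

def subsample_timeframe(df: pd.DataFrame, stride: int) -> pd.DataFrame:
    """Approximate a higher timeframe by taking every `stride`-th row."""
    return df.iloc[::stride].reset_index(drop=True).copy()

# data_1h = subsample_timeframe(data_1m, 60)   # every 60th 1m row
# data_1d = subsample_timeframe(data_1h, 24)   # every 24th 1h row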
from 1h data + data_1d = data_1h.iloc[::24].reset_index(drop=True).copy() # Take every 24th record + + # Add technical indicators to all data + data_1m = data_interface.add_technical_indicators(data_1m) + data_1h = data_interface.add_technical_indicators(data_1h) + data_1d = data_interface.add_technical_indicators(data_1d) + + # Generate labeled training data + X, y_immediate, y_midterm, y_longterm, y_values = generate_price_prediction_training_data( + data_1m, data_1h, data_1d, window_size=20 + ) + + # Split data into training and validation sets + from sklearn.model_selection import train_test_split + X_train, X_val, y_imm_train, y_imm_val, y_mid_train, y_mid_val, y_long_train, y_long_val, y_val_train, y_val_val = train_test_split( + X, y_immediate, y_midterm, y_longterm, y_values, test_size=0.2, random_state=42 + ) + + # Convert to torch tensors + X_train_tensor = torch.FloatTensor(X_train).to(agent.device) + y_imm_train_tensor = torch.LongTensor(y_imm_train).to(agent.device) + y_mid_train_tensor = torch.LongTensor(y_mid_train).to(agent.device) + y_long_train_tensor = torch.LongTensor(y_long_train).to(agent.device) + y_val_train_tensor = torch.FloatTensor(y_val_train).to(agent.device) + + X_val_tensor = torch.FloatTensor(X_val).to(agent.device) + y_imm_val_tensor = torch.LongTensor(y_imm_val).to(agent.device) + y_mid_val_tensor = torch.LongTensor(y_mid_val).to(agent.device) + y_long_val_tensor = torch.LongTensor(y_long_val).to(agent.device) + y_val_val_tensor = torch.FloatTensor(y_val_val).to(agent.device) + + # Calculate class weights for imbalanced data + from torch.nn.functional import one_hot + + # Function to calculate class weights + def get_class_weights(labels): + counts = np.bincount(labels) + if len(counts) < 3: # Ensure we have 3 classes + counts = np.append(counts, [0] * (3 - len(counts))) + weights = 1.0 / np.array(counts) + weights = weights / np.sum(weights) # Normalize + return weights + + imm_weights = torch.FloatTensor(get_class_weights(y_imm_train)).to(agent.device) + mid_weights = torch.FloatTensor(get_class_weights(y_mid_train)).to(agent.device) + long_weights = torch.FloatTensor(get_class_weights(y_long_train)).to(agent.device) + + # Create DataLoader for batch training + from torch.utils.data import TensorDataset, DataLoader + + train_dataset = TensorDataset( + X_train_tensor, y_imm_train_tensor, y_mid_train_tensor, + y_long_train_tensor, y_val_train_tensor + ) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + + # Set up loss functions with class weights + imm_criterion = nn.CrossEntropyLoss(weight=imm_weights) + mid_criterion = nn.CrossEntropyLoss(weight=mid_weights) + long_criterion = nn.CrossEntropyLoss(weight=long_weights) + value_criterion = nn.MSELoss() + + # Set up optimizer (separate from agent's optimizer) + pretrain_optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=0.0002) + pretrain_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + pretrain_optimizer, mode='min', factor=0.5, patience=3, verbose=True + ) + + # Set model to training mode + agent.policy_net.train() + + # Training loop + best_val_loss = float('inf') + patience = 5 + patience_counter = 0 + + for epoch in range(n_epochs): + # Training phase + train_loss = 0.0 + imm_correct, mid_correct, long_correct = 0, 0, 0 + total = 0 + + for X_batch, y_imm_batch, y_mid_batch, y_long_batch, y_val_batch in train_loader: + # Zero gradients + pretrain_optimizer.zero_grad() + + # Forward pass - we only need the price predictions + with 
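One caveat on get_class_weights above: np.bincount can report a zero count for an absent class, and 1.0/0 yields inf, which then poisons the normalization. A guarded variant, where the clamp is the only change:

# --- illustrative sketch, not part of the patch ---
import numpy as np

def safe_class_weights(labels, n_classes=3):
    counts = np.bincount(labels, minlength=n_classes).astype(float)
    counts[counts == 0] = 1.0  # avoid division by zero for empty classes
    weights = 1.0 / counts
    return weights / weights.sum()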
torch.cuda.amp.autocast() if agent.use_mixed_precision else contextlib.nullcontext(): + _, _, price_preds = agent.policy_net(X_batch) + + # Calculate losses for each prediction head + imm_loss = imm_criterion(price_preds['immediate'], y_imm_batch) + mid_loss = mid_criterion(price_preds['midterm'], y_mid_batch) + long_loss = long_criterion(price_preds['longterm'], y_long_batch) + value_loss = value_criterion(price_preds['values'], y_val_batch) + + # Combined loss (weighted by importance) + total_loss = imm_loss + 0.7 * mid_loss + 0.5 * long_loss + 0.3 * value_loss + + # Backward pass and optimize + if agent.use_mixed_precision: + agent.scaler.scale(total_loss).backward() + agent.scaler.unscale_(pretrain_optimizer) + torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0) + agent.scaler.step(pretrain_optimizer) + agent.scaler.update() + else: + total_loss.backward() + torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0) + pretrain_optimizer.step() + + # Accumulate metrics + train_loss += total_loss.item() + total += X_batch.size(0) + + # Calculate accuracy + _, imm_pred = torch.max(price_preds['immediate'], 1) + _, mid_pred = torch.max(price_preds['midterm'], 1) + _, long_pred = torch.max(price_preds['longterm'], 1) + + imm_correct += (imm_pred == y_imm_batch).sum().item() + mid_correct += (mid_pred == y_mid_batch).sum().item() + long_correct += (long_pred == y_long_batch).sum().item() + + # Calculate epoch metrics + train_loss /= len(train_loader) + imm_acc = imm_correct / total + mid_acc = mid_correct / total + long_acc = long_correct / total + + # Validation phase + agent.policy_net.eval() + val_loss = 0.0 + imm_val_correct, mid_val_correct, long_val_correct = 0, 0, 0 + + with torch.no_grad(): + # Forward pass on validation data + _, _, val_price_preds = agent.policy_net(X_val_tensor) + + # Calculate validation losses + val_imm_loss = imm_criterion(val_price_preds['immediate'], y_imm_val_tensor) + val_mid_loss = mid_criterion(val_price_preds['midterm'], y_mid_val_tensor) + val_long_loss = long_criterion(val_price_preds['longterm'], y_long_val_tensor) + val_value_loss = value_criterion(val_price_preds['values'], y_val_val_tensor) + + val_total_loss = val_imm_loss + 0.7 * val_mid_loss + 0.5 * val_long_loss + 0.3 * val_value_loss + val_loss = val_total_loss.item() + + # Calculate validation accuracy + _, imm_val_pred = torch.max(val_price_preds['immediate'], 1) + _, mid_val_pred = torch.max(val_price_preds['midterm'], 1) + _, long_val_pred = torch.max(val_price_preds['longterm'], 1) + + imm_val_correct = (imm_val_pred == y_imm_val_tensor).sum().item() + mid_val_correct = (mid_val_pred == y_mid_val_tensor).sum().item() + long_val_correct = (long_val_pred == y_long_val_tensor).sum().item() + + imm_val_acc = imm_val_correct / len(X_val_tensor) + mid_val_acc = mid_val_correct / len(X_val_tensor) + long_val_acc = long_val_correct / len(X_val_tensor) + + # Learning rate scheduling + pretrain_scheduler.step(val_loss) + + # Early stopping check + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + # Copy policy_net weights to target_net + agent.target_net.load_state_dict(agent.policy_net.state_dict()) + logger.info(f"Saved best model with validation loss: {val_loss:.4f}") + else: + patience_counter += 1 + if patience_counter >= patience: + logger.info(f"Early stopping triggered after {epoch+1} epochs") + break + + # Log progress + logger.info(f"Epoch {epoch+1}/{n_epochs}: " + f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, " + f"Imm 
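The early-stopping bookkeeping in the loop above, as a small reusable helper; patience=5 matches the diff, and the "improved" branch is where the diff snapshots policy_net into target_net.

# --- illustrative sketch, not part of the patch ---
class EarlyStopper:
    def __init__(self, patience: int = 5):
        self.best = float('inf')
        self.patience = patience
        self.counter = 0

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best:
            self.best = val_loss
            self.counter = 0  # improvement: snapshot best weights here
            return False
        self.counter += 1
        return self.counter >= self.patience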
Acc: {imm_acc:.4f}/{imm_val_acc:.4f}, " + f"Mid Acc: {mid_acc:.4f}/{mid_val_acc:.4f}, " + f"Long Acc: {long_acc:.4f}/{long_val_acc:.4f}") + + # Set model back to training mode for next epoch + agent.policy_net.train() + + logger.info("Price prediction pre-training complete") + return agent + + except Exception as e: + logger.error(f"Error during price prediction pre-training: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return agent + if __name__ == "__main__": train_rl() \ No newline at end of file
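End to end, the pieces in this patch compose as below. train_rl and its pre-training flags exist in the diff; note the interactive input() prompt inside train_rl still fires when a saved model exists, so a non-interactive run may want pretrain_price_prediction_enabled=False.

# --- illustrative sketch, not part of the patch ---
if __name__ == "__main__":
    agent, env = train_rl(
        num_episodes=5000,
        max_steps=2000,
        symbol="BTC/USDT",
        pretrain_price_prediction_enabled=True,  # supervised warm-up first
        pretrain_epochs=10,
    )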