initial model changes to fix performance

Dobromir Popov
2025-04-02 14:03:20 +03:00
parent aec536d007
commit 70eb7bba9b
8 changed files with 1619 additions and 279 deletions


@@ -8,6 +8,7 @@ from typing import Tuple, List
import os
import sys
import logging
import torch.nn.functional as F
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
@@ -20,71 +21,124 @@ logger = logging.getLogger(__name__)
class DQNAgent:
"""
Deep Q-Network agent for trading
Uses CNN model as the base network
Uses CNN model as the base network with GPU support
"""
def __init__(self,
state_shape: Tuple[int, ...],
n_actions: int,
learning_rate: float = 0.0005, # Reduced learning rate for more stability
gamma: float = 0.97, # Slightly reduced discount factor
epsilon: float = 1.0,
epsilon_min: float = 0.05, # Increased minimum epsilon for more exploration
epsilon_decay: float = 0.9975, # Slower decay rate
buffer_size: int = 20000, # Increased memory size
batch_size: int = 128, # Larger batch size
target_update: int = 5, # More frequent target updates
device=None): # Device for computations
# Extract state dimensions
if isinstance(state_shape, tuple) and len(state_shape) > 1:
# Multi-dimensional state (like image or sequence)
self.state_dim = state_shape
else:
# 1D state
if isinstance(state_shape, tuple):
self.state_dim = state_shape[0]
else:
self.state_dim = state_shape
# Store parameters
self.n_actions = n_actions
self.learning_rate = learning_rate
self.gamma = gamma
self.epsilon = epsilon
self.epsilon_min = epsilon_min
self.epsilon_decay = epsilon_decay
self.memory_size = memory_size
self.buffer_size = buffer_size
self.batch_size = batch_size
self.target_update = target_update
# Set device for computation (defaults to CUDA when available, otherwise CPU)
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = device
# Initialize models with appropriate architecture based on state shape
if isinstance(self.state_dim, tuple) and len(self.state_dim) > 1:
# For image-like states (from RL environment with CNN)
from NN.models.simple_cnn import SimpleCNN
self.policy_net = SimpleCNN(self.state_dim, self.n_actions)
self.target_net = SimpleCNN(self.state_dim, self.n_actions)
else:
# For 1D state vectors (most environments)
from NN.models.simple_mlp import SimpleMLP
self.policy_net = SimpleMLP(self.state_dim, self.n_actions)
self.target_net = SimpleMLP(self.state_dim, self.n_actions)
# Initialize the target network with the same weights as the policy network
self.target_net.load_state_dict(self.policy_net.state_dict())
# Set models to eval mode (important for batch norm, dropout)
self.target_net.eval()
# Optimization components
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
self.criterion = nn.MSELoss()
# Experience replay memory
self.memory = []
self.positive_memory = [] # Special memory for storing good experiences
self.update_count = 0
# Extrema detection tracking
self.last_extrema_pred = {
'class': 2, # Default to "neither" (not extrema)
'confidence': 0.0,
'raw': None
}
self.extrema_memory = [] # Special memory for storing extrema points
# Performance tracking
self.losses = []
self.avg_reward = 0.0
self.best_reward = -float('inf')
self.no_improvement_count = 0
# Check if mixed precision training should be used
self.use_mixed_precision = False
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
self.use_mixed_precision = True
self.scaler = torch.cuda.amp.GradScaler()
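# GradScaler scales the loss before backward so small FP16 gradients do not underflow to zero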
logger.info("Mixed precision training enabled")
else:
logger.info("Mixed precision training disabled")
# Track if we're in training mode
self.training = True
# For compatibility with old code
self.state_size = np.prod(state_shape)
self.action_size = n_actions
self.memory_size = buffer_size
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0]] # Default timeframes
logger.info(f"DQN Agent using device: {self.device}")
def move_models_to_device(self, device=None):
"""Move models to the specified device (GPU/CPU)"""
if device is not None:
self.device = device
try:
self.policy_net = self.policy_net.to(self.device)
self.target_net = self.target_net.to(self.device)
logger.info(f"Moved models to {self.device}")
return True
except Exception as e:
logger.error(f"Failed to move models to {self.device}: {str(e)}")
return False
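# Example (hypothetical usage): agent.move_models_to_device(torch.device('cuda:0'))
# moves both networks onto a specific GPU before training resumes.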
def remember(self, state: np.ndarray, action: int, reward: float,
next_state: np.ndarray, done: bool, is_extrema: bool = False):
"""
@@ -103,25 +157,472 @@ class DQNAgent:
# Always add to main memory
self.memory.append(experience)
# Check if this is an extrema point based on our extrema detection head
if hasattr(self, 'last_extrema_pred') and self.last_extrema_pred['class'] != 2:
# Class 0 = bottom, 1 = top, 2 = neither
# Only consider high confidence predictions
if self.last_extrema_pred['confidence'] > 0.7:
self.extrema_memory.append(experience)
# Log this special experience
extrema_type = "BOTTOM" if self.last_extrema_pred['class'] == 0 else "TOP"
logger.info(f"Stored {extrema_type} experience with reward {reward:.4f}")
# For tops and bottoms, also duplicate the experience in memory to learn more from it
for _ in range(2): # Add 2 extra copies
self.memory.append(experience)
# Explicitly marked extrema points also go to extrema memory
elif is_extrema:
self.extrema_memory.append(experience)
# Store positive experiences separately for prioritized replay
if reward > 0:
self.positive_memory.append(experience)
# For very good rewards, duplicate to learn more from them
if reward > 0.1:
for _ in range(min(int(reward * 10), 5)): # Cap at 5 extra copies for very high rewards
self.positive_memory.append(experience)
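# e.g. reward 0.25 -> int(2.5) = 2 extra copies; rewards of 0.5 or more hit the 5-copy cap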
# Keep memory size under control
if len(self.memory) > self.buffer_size:
# Keep more recent experiences
self.memory = self.memory[-self.buffer_size:]
# Keep specialized memories under control too
if len(self.positive_memory) > self.buffer_size // 4:
self.positive_memory = self.positive_memory[-(self.buffer_size // 4):]
if len(self.extrema_memory) > self.buffer_size // 4:
self.extrema_memory = self.extrema_memory[-(self.buffer_size // 4):]
def act(self, state: np.ndarray, explore=True) -> int:
"""Choose action using epsilon-greedy policy with explore flag"""
if explore and random.random() < self.epsilon:
return random.randrange(self.n_actions)
with torch.no_grad():
# Ensure state is normalized before inference
state_tensor = self._normalize_state(state)
state_tensor = torch.FloatTensor(state_tensor).unsqueeze(0).to(self.device)
# Get predictions using the policy network
self.policy_net.eval() # Set to evaluation mode for inference
action_probs, extrema_pred = self.policy_net(state_tensor)
self.policy_net.train() # Back to training mode
# Get the predicted extrema class (0=bottom, 1=top, 2=neither)
extrema_class = extrema_pred.argmax(dim=1).item()
extrema_confidence = torch.softmax(extrema_pred, dim=1)[0, extrema_class].item()
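# Softmax turns the extrema head logits into probabilities; the winning class's probability serves as the confidence score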
# Log extrema prediction for significant signals
if extrema_confidence > 0.7 and extrema_class != 2: # Only log strong top/bottom signals
extrema_type = "BOTTOM" if extrema_class == 0 else "TOP" if extrema_class == 1 else "NEITHER"
logger.info(f"High confidence {extrema_type} detected! Confidence: {extrema_confidence:.4f}")
# Store extrema prediction for the environment to use
self.last_extrema_pred = {
'class': extrema_class,
'confidence': extrema_confidence,
'raw': extrema_pred.cpu().numpy()
}
# Get the action with highest Q-value
action = action_probs.argmax().item()
# Adjust action based on extrema prediction (with some probability)
if extrema_confidence > 0.8: # Only adjust for strong signals
if extrema_class == 0: # Bottom detected
# Bias toward BUY at bottoms
if action != 0 and random.random() < 0.3 * extrema_confidence:
logger.info(f"Adjusting action to BUY based on bottom detection")
action = 0 # BUY
elif extrema_class == 1: # Top detected
# Bias toward SELL at tops
if action != 1 and random.random() < 0.3 * extrema_confidence:
logger.info(f"Adjusting action to SELL based on top detection")
action = 1 # SELL
return action
def replay(self, use_prioritized=True) -> float:
"""Experience replay - learn from stored experiences
Args:
use_prioritized: Whether to use prioritized experience replay
Returns:
float: Training loss
"""
# Check if we have enough samples
if len(self.memory) < self.batch_size:
return 0.0
# Check if mixed precision should be disabled
if 'DISABLE_MIXED_PRECISION' in os.environ:
self.use_mixed_precision = False
# Sample from memory with or without prioritization
if use_prioritized and len(self.positive_memory) > self.batch_size // 4:
# Use prioritized sampling: mix normal samples with positive reward samples
positive_batch_size = min(self.batch_size // 4, len(self.positive_memory))
regular_batch_size = self.batch_size - positive_batch_size
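# Up to a quarter of the batch comes from positive-reward experiences; the rest is drawn uniformly from the main buffer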
# Get positive examples
positive_batch = random.sample(self.positive_memory, positive_batch_size)
# Get regular examples
regular_batch = random.sample(self.memory, regular_batch_size)
# Combine batches
minibatch = positive_batch + regular_batch
else:
# Use regular uniform sampling
minibatch = random.sample(self.memory, self.batch_size)
# Extract batches with proper tensor conversion
states = np.vstack([self._normalize_state(x[0]) for x in minibatch])
actions = np.array([x[1] for x in minibatch])
rewards = np.array([x[2] for x in minibatch])
next_states = np.vstack([self._normalize_state(x[3]) for x in minibatch])
dones = np.array([x[4] for x in minibatch], dtype=np.float32)
# Convert to torch tensors and move to device
states_tensor = torch.FloatTensor(states).to(self.device)
actions_tensor = torch.LongTensor(actions).to(self.device)
rewards_tensor = torch.FloatTensor(rewards).to(self.device)
next_states_tensor = torch.FloatTensor(next_states).to(self.device)
dones_tensor = torch.FloatTensor(dones).to(self.device)
# First training step with mixed precision if available
if self.use_mixed_precision:
loss = self._replay_mixed_precision(
states_tensor, actions_tensor, rewards_tensor,
next_states_tensor, dones_tensor
)
else:
loss = self._replay_standard(
states_tensor, actions_tensor, rewards_tensor,
next_states_tensor, dones_tensor
)
# Occasionally train specifically on extrema points, if we have enough
if hasattr(self, 'extrema_memory') and len(self.extrema_memory) >= self.batch_size // 2:
if random.random() < 0.3: # 30% chance to do extra extrema training
# Sample from extrema memory
extrema_batch_size = min(self.batch_size // 2, len(self.extrema_memory))
extrema_batch = random.sample(self.extrema_memory, extrema_batch_size)
# Extract batches with proper tensor conversion
extrema_states = np.vstack([self._normalize_state(x[0]) for x in extrema_batch])
extrema_actions = np.array([x[1] for x in extrema_batch])
extrema_rewards = np.array([x[2] for x in extrema_batch])
extrema_next_states = np.vstack([self._normalize_state(x[3]) for x in extrema_batch])
extrema_dones = np.array([x[4] for x in extrema_batch], dtype=np.float32)
# Convert to torch tensors and move to device
extrema_states_tensor = torch.FloatTensor(extrema_states).to(self.device)
extrema_actions_tensor = torch.LongTensor(extrema_actions).to(self.device)
extrema_rewards_tensor = torch.FloatTensor(extrema_rewards).to(self.device)
extrema_next_states_tensor = torch.FloatTensor(extrema_next_states).to(self.device)
extrema_dones_tensor = torch.FloatTensor(extrema_dones).to(self.device)
# Additional training step focused on extrema points (with smaller learning rate)
original_lr = self.optimizer.param_groups[0]['lr']
# Temporarily reduce learning rate for fine-tuning on extrema
for param_group in self.optimizer.param_groups:
param_group['lr'] = original_lr * 0.5
# Train on extrema
if self.use_mixed_precision:
extrema_loss = self._replay_mixed_precision(
extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
extrema_next_states_tensor, extrema_dones_tensor
)
else:
extrema_loss = self._replay_standard(
extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
extrema_next_states_tensor, extrema_dones_tensor
)
# Restore original learning rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = original_lr
logger.info(f"Extra training on extrema points: loss={extrema_loss:.4f}")
# Average the loss
loss = (loss + extrema_loss) / 2
# Store and return loss
self.losses.append(loss)
return loss
def _replay_standard(self, states, actions, rewards, next_states, dones):
"""Standard precision training step"""
# Zero gradients
self.optimizer.zero_grad()
# Get current Q values and extrema predictions
current_q_values, current_extrema_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Get next Q values from target network
with torch.no_grad():
next_q_values, next_extrema_pred = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
# Effective batch size; also used for the extrema labels further below
min_size = min(rewards.shape[0], next_q_values.shape[0])
# Check for dimension mismatch and fix it
if rewards.shape[0] != next_q_values.shape[0]:
# Log the shape mismatch for debugging
logger.warning(f"Shape mismatch detected in standard replay: rewards {rewards.shape}, next_q_values {next_q_values.shape}")
# Use the smaller size to prevent index errors
min_size = min(rewards.shape[0], next_q_values.shape[0])
rewards = rewards[:min_size]
dones = dones[:min_size]
next_q_values = next_q_values[:min_size]
current_q_values = current_q_values[:min_size]
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
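# Bellman target: r + gamma * max_a' Q_target(s', a'); the (1 - done) factor removes the bootstrap term for terminal transitions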
# Compute Q-value loss (primary task)
q_loss = nn.MSELoss()(current_q_values, target_q_values)
# Create extrema labels from price movements (crude approximation)
# If the next state price is higher than current, we might be in an uptrend (not a bottom)
# If the next state price is lower than current, we might be in a downtrend (not a top)
# This is a simplified approximation; in real scenarios we'd want to use actual extrema detection
# Try to extract price from current and next states
# Assuming price is in the last feature
try:
# Extract price feature from sequence data (if available)
if len(states.shape) == 3: # [batch, seq, features]
current_prices = states[:, -1, -1] # Last timestep, last feature
next_prices = next_states[:, -1, -1]
else: # [batch, features]
current_prices = states[:, -1] # Last feature
next_prices = next_states[:, -1]
# Compute price changes
price_changes = (next_prices - current_prices) / current_prices
# Create crude extrema labels:
# 0 = bottom: Large negative price change followed by positive change
# 1 = top: Large positive price change followed by negative change
# 2 = neither: Small or inconsistent changes
# Classify based on price change magnitude
extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither
# Identify potential bottoms (significant negative change)
bottoms = (price_changes < -0.003)
extrema_labels[bottoms] = 0
# Identify potential tops (significant positive change)
tops = (price_changes > 0.003)
extrema_labels[tops] = 1
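# The 0.003 threshold corresponds to a 0.3% single-step price move, used as a crude proxy for a local bottom/top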
# Calculate extrema prediction loss (auxiliary task)
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
current_extrema_pred = current_extrema_pred[:min_size]
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
# Combined loss (primary + auxiliary with lower weight)
# Typically auxiliary tasks should have lower weight to not dominate the primary task
loss = q_loss + 0.3 * extrema_loss
# Log separate loss components occasionally
if random.random() < 0.01: # Log 1% of the time to avoid flood
logger.info(f"Training losses: Q-loss={q_loss.item():.4f}, Extrema-loss={extrema_loss.item():.4f}")
else:
# Fall back to just Q-value loss if extrema predictions aren't available
loss = q_loss
except Exception as e:
# Fallback if price extraction fails
logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
loss = q_loss
# Backward pass and optimize
loss.backward()
# Gradient clipping to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
self.optimizer.step()
# Update target network if needed
self.update_count += 1
if self.update_count % self.target_update == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
# Track and decay epsilon
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
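# With the defaults (decay 0.9975, floor 0.05) epsilon reaches its minimum after roughly ln(0.05)/ln(0.9975) ≈ 1200 replay calls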
return loss.item()
def _replay_mixed_precision(self, states, actions, rewards, next_states, dones):
"""Mixed precision training step for better GPU performance"""
# Check if mixed precision should be explicitly disabled
if 'DISABLE_MIXED_PRECISION' in os.environ:
logger.info("Mixed precision explicitly disabled by environment variable")
return self._replay_standard(states, actions, rewards, next_states, dones)
try:
# Zero gradients
self.optimizer.zero_grad()
# Forward pass with amp autocasting
with torch.cuda.amp.autocast():
# Get current Q values and extrema predictions
current_q_values, current_extrema_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Get next Q values from target network
with torch.no_grad():
next_q_values, next_extrema_pred = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
# Effective batch size; also used for the extrema labels further below
min_size = min(rewards.shape[0], next_q_values.shape[0])
# Check for dimension mismatch and fix it
if rewards.shape[0] != next_q_values.shape[0]:
# Log the shape mismatch for debugging
logger.warning(f"Shape mismatch detected: rewards {rewards.shape}, next_q_values {next_q_values.shape}")
# Use the smaller size to prevent index errors
min_size = min(rewards.shape[0], next_q_values.shape[0])
rewards = rewards[:min_size]
dones = dones[:min_size]
next_q_values = next_q_values[:min_size]
current_q_values = current_q_values[:min_size]
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
# Compute Q-value loss (primary task)
q_loss = nn.MSELoss()(current_q_values, target_q_values)
# Create extrema labels from price movements (crude approximation)
# Try to extract price from current and next states
try:
# Extract price feature from sequence data (if available)
if len(states.shape) == 3: # [batch, seq, features]
current_prices = states[:, -1, -1] # Last timestep, last feature
next_prices = next_states[:, -1, -1]
else: # [batch, features]
current_prices = states[:, -1] # Last feature
next_prices = next_states[:, -1]
# Compute price changes
price_changes = (next_prices - current_prices) / current_prices
# Create crude extrema labels:
# 0 = bottom: Large negative price change followed by positive change
# 1 = top: Large positive price change followed by negative change
# 2 = neither: Small or inconsistent changes
# Classify based on price change magnitude
extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither
# Identify potential bottoms (significant negative change)
bottoms = (price_changes < -0.003)
extrema_labels[bottoms] = 0
# Identify potential tops (significant positive change)
tops = (price_changes > 0.003)
extrema_labels[tops] = 1
# Calculate extrema prediction loss (auxiliary task)
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
current_extrema_pred = current_extrema_pred[:min_size]
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
# Combined loss (primary + auxiliary with lower weight)
loss = q_loss + 0.3 * extrema_loss
# Log separate loss components occasionally
if random.random() < 0.01: # Log 1% of the time to avoid flood
logger.info(f"Mixed precision training losses: Q-loss={q_loss.item():.4f}, Extrema-loss={extrema_loss.item():.4f}")
else:
# Fall back to just Q-value loss
loss = q_loss
except Exception as e:
# Fallback if price extraction fails
logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
loss = q_loss
# Backward pass with scaled gradients
self.scaler.scale(loss).backward()
# Gradient clipping on scaled gradients
self.scaler.unscale_(self.optimizer)
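# Unscaling first ensures the clipping threshold is applied to the true gradient magnitudes rather than the scaled ones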
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
# Update with scaler
self.scaler.step(self.optimizer)
self.scaler.update()
# Update target network if needed
self.update_count += 1
if self.update_count % self.target_update == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
# Track and decay epsilon
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
return loss.item()
except Exception as e:
logger.error(f"Error in mixed precision training: {str(e)}")
logger.warning("Falling back to standard precision training")
# Fall back to standard training
return self._replay_standard(states, actions, rewards, next_states, dones)
def train_on_extrema(self, states, actions, rewards, next_states, dones):
"""
Special training function specifically for extrema points
Args:
states: Batch of states at extrema points
actions: Batch of actions
rewards: Batch of rewards
next_states: Batch of next states
dones: Batch of done flags
Returns:
float: Training loss
"""
# Convert to numpy arrays if not already
if not isinstance(states, np.ndarray):
states = np.array(states)
if not isinstance(actions, np.ndarray):
actions = np.array(actions)
if not isinstance(rewards, np.ndarray):
rewards = np.array(rewards)
if not isinstance(next_states, np.ndarray):
next_states = np.array(next_states)
if not isinstance(dones, np.ndarray):
dones = np.array(dones, dtype=np.float32)
# Normalize states
states = np.vstack([self._normalize_state(s) for s in states])
next_states = np.vstack([self._normalize_state(s) for s in next_states])
# Convert to torch tensors and move to device
states_tensor = torch.FloatTensor(states).to(self.device)
actions_tensor = torch.LongTensor(actions).to(self.device)
rewards_tensor = torch.FloatTensor(rewards).to(self.device)
next_states_tensor = torch.FloatTensor(next_states).to(self.device)
dones_tensor = torch.FloatTensor(dones).to(self.device)
# Choose training method based on precision mode
if self.use_mixed_precision:
return self._replay_mixed_precision(
states_tensor, actions_tensor, rewards_tensor,
next_states_tensor, dones_tensor
)
else:
return self._replay_standard(
states_tensor, actions_tensor, rewards_tensor,
next_states_tensor, dones_tensor
)
def _normalize_state(self, state: np.ndarray) -> np.ndarray:
"""Normalize the state data to prevent numerical issues"""
@@ -211,148 +712,6 @@ class DQNAgent:
return normalized_state
def replay(self, use_prioritized=True) -> float:
"""
Train on a batch of experiences with prioritized sampling
Args:
use_prioritized: Whether to use prioritized replay
Returns:
float: Loss value
"""
if len(self.memory) < self.batch_size:
return 0.0
# Sample batch with prioritization
batch = []
if use_prioritized and len(self.positive_memory) > 0 and len(self.extrema_memory) > 0:
# Prioritized sampling from different memory types
positive_count = min(self.batch_size // 4, len(self.positive_memory))
extrema_count = min(self.batch_size // 4, len(self.extrema_memory))
regular_count = self.batch_size - positive_count - extrema_count
positive_samples = random.sample(list(self.positive_memory), positive_count)
extrema_samples = random.sample(list(self.extrema_memory), extrema_count)
regular_samples = random.sample(list(self.memory), regular_count)
batch = positive_samples + extrema_samples + regular_samples
else:
# Standard sampling
batch = random.sample(self.memory, self.batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
# Normalize states before training
normalized_states = np.array([self._normalize_state(state) for state in states])
normalized_next_states = np.array([self._normalize_state(state) for state in next_states])
# Convert to tensors and move to device
states_tensor = torch.FloatTensor(normalized_states).to(self.device)
actions_tensor = torch.LongTensor(actions).to(self.device)
rewards_tensor = torch.FloatTensor(rewards).to(self.device)
next_states_tensor = torch.FloatTensor(normalized_next_states).to(self.device)
dones_tensor = torch.FloatTensor(dones).to(self.device)
# Get current Q values
current_q_values, extrema_pred = self.policy_net(states_tensor)
current_q_values = current_q_values.gather(1, actions_tensor.unsqueeze(1))
# Get next Q values from target network (Double DQN approach)
with torch.no_grad():
# Get actions from policy network
next_actions, _ = self.policy_net(next_states_tensor)
next_actions = next_actions.max(1)[1].unsqueeze(1)
# Get Q values from target network for those actions
next_q_values, _ = self.target_net(next_states_tensor)
next_q_values = next_q_values.gather(1, next_actions).squeeze(1)
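# Double DQN: the policy network picks the next action while the target network scores it, reducing Q-value overestimation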
# Compute target Q values
target_q_values = rewards_tensor + (1 - dones_tensor) * self.gamma * next_q_values
# Clamp target values to prevent extreme values
target_q_values = torch.clamp(target_q_values, -100, 100)
# Compute Huber loss (more robust to outliers than MSE)
loss = nn.SmoothL1Loss()(current_q_values.squeeze(), target_q_values)
# Optimize
self.optimizer.zero_grad()
loss.backward()
# Apply gradient clipping to prevent exploding gradients
nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
self.optimizer.step()
# Update target network if needed
self.update_count += 1
if self.update_count % self.target_update == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
# Decay epsilon
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
return loss.item()
def train_on_extrema(self, states, actions, rewards, next_states, dones):
"""
Special training method focused on extrema patterns
Args:
states: Array of states near extrema points
actions: Correct actions to take (buy at bottoms, sell at tops)
rewards: Rewards for each action
next_states: Next states
dones: Done flags
"""
if len(states) == 0:
return 0.0
# Normalize states
normalized_states = np.array([self._normalize_state(state) for state in states])
normalized_next_states = np.array([self._normalize_state(state) for state in next_states])
# Convert to tensors
states_tensor = torch.FloatTensor(normalized_states).to(self.device)
actions_tensor = torch.LongTensor(actions).to(self.device)
rewards_tensor = torch.FloatTensor(rewards).to(self.device)
next_states_tensor = torch.FloatTensor(normalized_next_states).to(self.device)
dones_tensor = torch.FloatTensor(dones).to(self.device)
# Forward pass
current_q_values, extrema_pred = self.policy_net(states_tensor)
current_q_values = current_q_values.gather(1, actions_tensor.unsqueeze(1))
# Get next Q values (Double DQN approach)
with torch.no_grad():
next_actions, _ = self.policy_net(next_states_tensor)
next_actions = next_actions.max(1)[1].unsqueeze(1)
next_q_values, _ = self.target_net(next_states_tensor)
next_q_values = next_q_values.gather(1, next_actions).squeeze(1)
target_q_values = rewards_tensor + (1 - dones_tensor) * self.gamma * next_q_values
# Clamp target values
target_q_values = torch.clamp(target_q_values, -100, 100)
# Use Huber loss for extrema training
q_loss = nn.SmoothL1Loss()(current_q_values.squeeze(), target_q_values)
# Full loss
loss = q_loss
# Optimize
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
self.optimizer.step()
return loss.item()
def update_learning_metrics(self, episode_reward, best_reward_threshold=0.01):
"""Update learning metrics and perform learning rate adjustments if needed"""
# Update average reward with exponential moving average