misc
@@ -54,6 +54,7 @@ class DQNAgent:
self.epsilon = epsilon
self.epsilon_min = epsilon_min
self.epsilon_decay = epsilon_decay
self.epsilon_start = epsilon # Store initial epsilon value for resets/bumps
self.buffer_size = buffer_size
self.batch_size = batch_size
self.target_update = target_update
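For reference, these epsilon fields define a multiplicative exploration schedule; the decay itself is applied later during replay as `self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)`. A minimal standalone sketch of how that schedule evolves, assuming example values for the constructor arguments (not taken from the commit):

def decayed_epsilon(epsilon_start, epsilon_min, epsilon_decay, steps):
    """Epsilon after `steps` multiplicative decay updates, floored at epsilon_min."""
    return max(epsilon_min, epsilon_start * (epsilon_decay ** steps))

print(decayed_epsilon(1.0, 0.05, 0.999, 1000))  # assumed values; roughly 0.368 after 1000 updates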
@@ -127,6 +128,28 @@ class DQNAgent:
self.best_reward = -float('inf')
self.no_improvement_count = 0

# Confidence tracking
self.confidence_history = []
self.avg_confidence = 0.0
self.max_confidence = 0.0
self.min_confidence = 1.0

# Trade action fee and confidence thresholds
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
self.minimum_action_confidence = 0.5 # Minimum confidence to consider trading
self.recent_actions = [] # Track recent actions to avoid oscillations

# Violent move detection
self.price_history = []
self.volatility_window = 20 # Window size for volatility calculation
self.volatility_threshold = 0.0015 # Threshold for considering a move "violent"
self.post_violent_move = False # Flag for recent violent move
self.violent_move_cooldown = 0 # Cooldown after violent move

# Feature integration
self.last_hidden_features = None # Store last extracted features
self.feature_history = [] # Store history of features for analysis

# Check if mixed precision training should be used
self.use_mixed_precision = False
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
@@ -146,6 +169,7 @@ class DQNAgent:
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0]] # Default timeframes

logger.info(f"DQN Agent using device: {self.device}")
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")

def move_models_to_device(self, device=None):
"""Move models to the specified device (GPU/CPU)"""
@@ -189,8 +213,20 @@ class DQNAgent:
current_price = state[-1] # Last feature
next_price = next_state[-1]

# Calculate price change
price_change = (next_price - current_price) / current_price
# Calculate price change - avoid division by zero
if np.isscalar(current_price) and current_price != 0:
price_change = (next_price - current_price) / current_price
elif isinstance(current_price, np.ndarray):
# Handle array case - protect against division by zero
with np.errstate(divide='ignore', invalid='ignore'):
price_change = (next_price - current_price) / current_price
# Replace infinities and NaNs with zeros
if isinstance(price_change, np.ndarray):
price_change = np.nan_to_num(price_change, nan=0.0, posinf=0.0, neginf=0.0)
else:
price_change = 0.0 if np.isnan(price_change) or np.isinf(price_change) else price_change
else:
price_change = 0.0

# Check if this is a significant price movement
if abs(price_change) > 0.002: # Significant price change
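The branching above protects a single expression against zero, NaN, and array-valued prices. A compact helper with the same behaviour might look like the sketch below; `safe_pct_change` is a hypothetical name, not part of the commit:

import numpy as np

def safe_pct_change(current, nxt):
    """(nxt - current) / current with protection against zero, NaN, and inf results."""
    if np.isscalar(current):
        if current == 0 or np.isnan(current):
            return 0.0
        change = (nxt - current) / current
        return 0.0 if np.isnan(change) or np.isinf(change) else change
    current = np.asarray(current, dtype=float)
    nxt = np.asarray(nxt, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        change = (nxt - current) / current
    return np.nan_to_num(change, nan=0.0, posinf=0.0, neginf=0.0)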
@@ -264,9 +300,17 @@ class DQNAgent:

# Get predictions using the policy network
self.policy_net.eval() # Set to evaluation mode for inference
action_probs, extrema_pred, price_predictions = self.policy_net(state_tensor)
action_probs, extrema_pred, price_predictions, hidden_features = self.policy_net(state_tensor)
self.policy_net.train() # Back to training mode

# Store hidden features for integration
self.last_hidden_features = hidden_features.cpu().numpy()

# Track feature history (limited size)
self.feature_history.append(hidden_features.cpu().numpy())
if len(self.feature_history) > 100:
self.feature_history = self.feature_history[-100:]

# Get the predicted extrema class (0=bottom, 1=top, 2=neither)
extrema_class = extrema_pred.argmax(dim=1).item()
extrema_confidence = torch.softmax(extrema_pred, dim=1)[0, extrema_class].item()
@@ -336,17 +380,120 @@ class DQNAgent:
# Get the action with highest Q-value
action = action_probs.argmax().item()

# Calculate overall confidence in the action
q_values_softmax = F.softmax(action_probs, dim=1)[0]
action_confidence = q_values_softmax[action].item()

# Track confidence metrics
self.confidence_history.append(action_confidence)
if len(self.confidence_history) > 100:
self.confidence_history = self.confidence_history[-100:]

# Update confidence metrics
self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
self.max_confidence = max(self.max_confidence, action_confidence)
self.min_confidence = min(self.min_confidence, action_confidence)

# Log average confidence occasionally
if random.random() < 0.01: # 1% of the time
logger.info(f"Confidence metrics - Current: {action_confidence:.4f}, Avg: {self.avg_confidence:.4f}, " +
f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")
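The confidence tracked here is simply the softmax probability assigned to the chosen action's Q-value. A self-contained sketch of that calculation with made-up Q-values (not from the commit):

import torch
import torch.nn.functional as F

q_values = torch.tensor([[0.2, 1.3, 0.9]])   # made-up Q-values for BUY, SELL, HOLD
action = q_values.argmax(dim=1).item()       # greedy action index
probs = F.softmax(q_values, dim=1)[0]        # normalize Q-values into a distribution
confidence = probs[action].item()            # probability mass on the chosen action
print(action, round(confidence, 3))

Note that a softmax over raw Q-values is a heuristic confidence proxy rather than a calibrated probability.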
# Track price for violent move detection
try:
# Extract current price from state (assuming it's in the last position)
if len(state.shape) > 1: # For 2D state
current_price = state[-1, -1]
else: # For 1D state
current_price = state[-1]

self.price_history.append(current_price)
if len(self.price_history) > self.volatility_window:
self.price_history = self.price_history[-self.volatility_window:]

# Detect violent price moves if we have enough price history
if len(self.price_history) >= 5:
# Calculate short-term volatility
recent_prices = self.price_history[-5:]

# Make sure we're working with scalar values, not arrays
if isinstance(recent_prices[0], np.ndarray):
# If prices are arrays, extract the last value (current price)
recent_prices = [p[-1] if isinstance(p, np.ndarray) and p.size > 0 else p for p in recent_prices]

# Calculate price changes with protection against division by zero
price_changes = []
for i in range(1, len(recent_prices)):
if recent_prices[i-1] != 0 and not np.isnan(recent_prices[i-1]) and not np.isnan(recent_prices[i]):
change = (recent_prices[i] - recent_prices[i-1]) / recent_prices[i-1]
price_changes.append(change)
else:
price_changes.append(0.0)

# Calculate volatility as sum of absolute price changes
volatility = sum([abs(change) for change in price_changes])

# Check if we've had a violent move
if volatility > self.volatility_threshold:
logger.info(f"Violent price move detected! Volatility: {volatility:.6f}")
self.post_violent_move = True
self.violent_move_cooldown = 10 # Set cooldown period

# Handle post-violent move period
if self.post_violent_move:
if self.violent_move_cooldown > 0:
self.violent_move_cooldown -= 1
# Increase confidence threshold temporarily after violent moves
effective_threshold = self.minimum_action_confidence * 1.1
logger.info(f"Post-violent move period: {self.violent_move_cooldown} steps remaining. " +
f"Using higher confidence threshold: {effective_threshold:.4f}")
else:
self.post_violent_move = False
logger.info("Post-violent move period ended")
except Exception as e:
logger.warning(f"Error in violent move detection: {str(e)}")
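For intuition, the same volatility measure (sum of absolute one-step returns over the last five prices) applied to a made-up price series, as a standalone sketch outside the agent:

def rolling_volatility(prices):
    """Sum of absolute one-step returns over a price window (zero-safe)."""
    changes = [abs((cur - prev) / prev) if prev != 0 else 0.0
               for prev, cur in zip(prices[:-1], prices[1:])]
    return sum(changes)

window = [100.0, 100.05, 99.98, 100.12, 100.30]   # made-up prices
print(rolling_volatility(window[-5:]) > 0.0015)   # compare against volatility_threshold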
# Apply trade action fee to buy/sell actions but not to hold
# This creates a threshold that must be exceeded to justify a trade
action_values = action_probs.clone()

# If BUY or SELL, apply fee by reducing the Q-value
if action == 0 or action == 1: # BUY or SELL
# Check if confidence is above minimum threshold
effective_threshold = self.minimum_action_confidence
if self.post_violent_move:
effective_threshold *= 1.1 # Higher threshold after violent moves

if action_confidence < effective_threshold:
# If confidence is below threshold, force HOLD action
logger.info(f"Action {action} confidence {action_confidence:.4f} below threshold {effective_threshold}, forcing HOLD")
action = 2 # HOLD
else:
# Apply trade action fee to ensure we only trade when there's clear benefit
fee_adjusted_action_values = action_values.clone()
fee_adjusted_action_values[0, 0] -= self.trade_action_fee # Reduce BUY value
fee_adjusted_action_values[0, 1] -= self.trade_action_fee # Reduce SELL value
# Hold value remains unchanged

# Re-determine the action based on fee-adjusted values
fee_adjusted_action = fee_adjusted_action_values.argmax().item()

# If the fee changes our decision, log this
if fee_adjusted_action != action:
logger.info(f"Trade action fee changed decision from {action} to {fee_adjusted_action}")
action = fee_adjusted_action
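The fee logic subtracts a fixed penalty from the BUY and SELL Q-values, so a trade is only kept when it beats HOLD by more than the fee. A toy illustration with made-up numbers (not part of the commit):

import torch

trade_action_fee = 0.0005
q = torch.tensor([[0.5003, 0.2000, 0.5001]])   # made-up Q-values for BUY, SELL, HOLD

fee_adjusted = q.clone()
fee_adjusted[0, 0] -= trade_action_fee         # penalize BUY
fee_adjusted[0, 1] -= trade_action_fee         # penalize SELL

print(q.argmax().item(), fee_adjusted.argmax().item())   # BUY (0) before the fee, HOLD (2) after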
# Adjust action based on extrema and price predictions
# Prioritize short-term movement for trading decisions
if immediate_conf > 0.8: # Only adjust for strong signals
if immediate_direction == 2: # UP prediction
# Bias toward BUY for strong up predictions
if action != 0 and random.random() < 0.3 * immediate_conf:
if action != 0 and action != 2 and random.random() < 0.3 * immediate_conf:
logger.info(f"Adjusting action to BUY based on immediate UP prediction")
action = 0 # BUY
elif immediate_direction == 0: # DOWN prediction
# Bias toward SELL for strong down predictions
if action != 1 and random.random() < 0.3 * immediate_conf:
if action != 1 and action != 2 and random.random() < 0.3 * immediate_conf:
logger.info(f"Adjusting action to SELL based on immediate DOWN prediction")
action = 1 # SELL

@@ -354,333 +501,217 @@ class DQNAgent:
if extrema_confidence > 0.8: # Only adjust for strong signals
if extrema_class == 0: # Bottom detected
# Bias toward BUY at bottoms
if action != 0 and random.random() < 0.3 * extrema_confidence:
if action != 0 and action != 2 and random.random() < 0.3 * extrema_confidence:
logger.info(f"Adjusting action to BUY based on bottom detection")
action = 0 # BUY
elif extrema_class == 1: # Top detected
# Bias toward SELL at tops
if action != 1 and random.random() < 0.3 * extrema_confidence:
if action != 1 and action != 2 and random.random() < 0.3 * extrema_confidence:
logger.info(f"Adjusting action to SELL based on top detection")
action = 1 # SELL

# Finally, avoid action oscillation by checking recent history
if len(self.recent_actions) >= 2:
last_action = self.recent_actions[-1]
if action != last_action and action != 2 and last_action != 2:
# We're switching between BUY and SELL too quickly
# Only allow this if we have very high confidence
if action_confidence < 0.85:
logger.info(f"Preventing oscillation from {last_action} to {action}, forcing HOLD")
action = 2 # HOLD

# Update recent actions list
self.recent_actions.append(action)
if len(self.recent_actions) > 5:
self.recent_actions = self.recent_actions[-5:]

return action
def replay(self, use_prioritized=True) -> float:
"""Experience replay - learn from stored experiences

Args:
use_prioritized: Whether to use prioritized experience replay

Returns:
float: Training loss
"""
# Check if we have enough samples
if len(self.memory) < self.batch_size:
def replay(self, experiences=None):
"""Train the model using experiences from memory"""

# Don't train if not in training mode
if not self.training:
return 0.0

# Check if mixed precision should be disabled
if 'DISABLE_MIXED_PRECISION' in os.environ:
self.use_mixed_precision = False
# If no experiences provided, sample from memory
if experiences is None:
# Skip if memory is too small
if len(self.memory) < self.batch_size:
return 0.0

# Sample from memory with or without prioritization
if use_prioritized and len(self.positive_memory) > self.batch_size // 4:
# Use prioritized sampling: mix normal samples with positive reward samples
positive_batch_size = min(self.batch_size // 4, len(self.positive_memory))
regular_batch_size = self.batch_size - positive_batch_size

# Get positive examples
positive_batch = random.sample(self.positive_memory, positive_batch_size)

# Get regular examples
regular_batch = random.sample(self.memory, regular_batch_size)

# Combine batches
minibatch = positive_batch + regular_batch
else:
# Use regular uniform sampling
minibatch = random.sample(self.memory, self.batch_size)
# Sample random mini-batch from memory
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
experiences = [self.memory[i] for i in indices]

# Extract batches with proper tensor conversion
states = np.vstack([self._normalize_state(x[0]) for x in minibatch])
actions = np.array([x[1] for x in minibatch])
rewards = np.array([x[2] for x in minibatch])
next_states = np.vstack([self._normalize_state(x[3]) for x in minibatch])
dones = np.array([x[4] for x in minibatch], dtype=np.float32)

# Convert to torch tensors and move to device
states_tensor = torch.FloatTensor(states).to(self.device)
actions_tensor = torch.LongTensor(actions).to(self.device)
rewards_tensor = torch.FloatTensor(rewards).to(self.device)
next_states_tensor = torch.FloatTensor(next_states).to(self.device)
dones_tensor = torch.FloatTensor(dones).to(self.device)

# First training step with mixed precision if available
# Choose appropriate replay method
if self.use_mixed_precision:
loss = self._replay_mixed_precision(
states_tensor, actions_tensor, rewards_tensor,
next_states_tensor, dones_tensor
)
# Convert experiences to tensors for mixed precision
states = torch.FloatTensor(np.array([e[0] for e in experiences])).to(self.device)
actions = torch.LongTensor(np.array([e[1] for e in experiences])).to(self.device)
rewards = torch.FloatTensor(np.array([e[2] for e in experiences])).to(self.device)
next_states = torch.FloatTensor(np.array([e[3] for e in experiences])).to(self.device)
dones = torch.FloatTensor(np.array([e[4] for e in experiences])).to(self.device)

# Use mixed precision replay
loss = self._replay_mixed_precision(states, actions, rewards, next_states, dones)
else:
loss = self._replay_standard(
states_tensor, actions_tensor, rewards_tensor,
next_states_tensor, dones_tensor
)
# Pass experiences directly to standard replay method
loss = self._replay_standard(experiences)
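The removed sampling path above filled a quarter of each batch from a positive-reward buffer and the rest uniformly from the main replay memory. A standalone sketch of that mixing idea (hypothetical buffer arguments, assuming len(memory) >= batch_size):

import random

def sample_mixed_batch(memory, positive_memory, batch_size):
    """Mix up to batch_size // 4 positive-reward experiences with uniform samples."""
    positive_n = min(batch_size // 4, len(positive_memory))
    regular_n = batch_size - positive_n
    batch = random.sample(positive_memory, positive_n) + random.sample(memory, regular_n)
    random.shuffle(batch)  # avoid ordering artifacts within the batch
    return batch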
# Training focus selector - randomly focus on one of the specialized training types
training_focus = random.random()

# Occasionally train specifically on extrema points
if training_focus < 0.3 and hasattr(self, 'extrema_memory') and len(self.extrema_memory) >= self.batch_size // 2:
# Sample from extrema memory
extrema_batch_size = min(self.batch_size // 2, len(self.extrema_memory))
extrema_batch = random.sample(self.extrema_memory, extrema_batch_size)

# Extract batches with proper tensor conversion
extrema_states = np.vstack([self._normalize_state(x[0]) for x in extrema_batch])
extrema_actions = np.array([x[1] for x in extrema_batch])
extrema_rewards = np.array([x[2] for x in extrema_batch])
extrema_next_states = np.vstack([self._normalize_state(x[3]) for x in extrema_batch])
extrema_dones = np.array([x[4] for x in extrema_batch], dtype=np.float32)

# Convert to torch tensors and move to device
extrema_states_tensor = torch.FloatTensor(extrema_states).to(self.device)
extrema_actions_tensor = torch.LongTensor(extrema_actions).to(self.device)
extrema_rewards_tensor = torch.FloatTensor(extrema_rewards).to(self.device)
extrema_next_states_tensor = torch.FloatTensor(extrema_next_states).to(self.device)
extrema_dones_tensor = torch.FloatTensor(extrema_dones).to(self.device)

# Additional training step focused on extrema points (with smaller learning rate)
original_lr = self.optimizer.param_groups[0]['lr']
# Temporarily reduce learning rate for fine-tuning on extrema
for param_group in self.optimizer.param_groups:
param_group['lr'] = original_lr * 0.5

# Train on extrema
if self.use_mixed_precision:
extrema_loss = self._replay_mixed_precision(
extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
extrema_next_states_tensor, extrema_dones_tensor
)
else:
extrema_loss = self._replay_standard(
extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
extrema_next_states_tensor, extrema_dones_tensor
)

# Restore original learning rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = original_lr

logger.info(f"Extra training on extrema points: loss={extrema_loss:.4f}")

# Average the loss
loss = (loss + extrema_loss) / 2

# Occasionally train specifically on price movement data
elif training_focus >= 0.3 and training_focus < 0.6 and hasattr(self, 'price_movement_memory') and len(self.price_movement_memory) >= self.batch_size // 2:
# Sample from price movement memory
price_batch_size = min(self.batch_size // 2, len(self.price_movement_memory))
price_batch = random.sample(self.price_movement_memory, price_batch_size)

# Extract batches with proper tensor conversion
price_states = np.vstack([self._normalize_state(x[0]) for x in price_batch])
price_actions = np.array([x[1] for x in price_batch])
price_rewards = np.array([x[2] for x in price_batch])
price_next_states = np.vstack([self._normalize_state(x[3]) for x in price_batch])
price_dones = np.array([x[4] for x in price_batch], dtype=np.float32)

# Convert to torch tensors and move to device
price_states_tensor = torch.FloatTensor(price_states).to(self.device)
price_actions_tensor = torch.LongTensor(price_actions).to(self.device)
price_rewards_tensor = torch.FloatTensor(price_rewards).to(self.device)
price_next_states_tensor = torch.FloatTensor(price_next_states).to(self.device)
price_dones_tensor = torch.FloatTensor(price_dones).to(self.device)

# Additional training step focused on price movements (with smaller learning rate)
original_lr = self.optimizer.param_groups[0]['lr']
# Temporarily reduce learning rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = original_lr * 0.5

# Train on price movement data
if self.use_mixed_precision:
price_loss = self._replay_mixed_precision(
price_states_tensor, price_actions_tensor, price_rewards_tensor,
price_next_states_tensor, price_dones_tensor
)
else:
price_loss = self._replay_standard(
price_states_tensor, price_actions_tensor, price_rewards_tensor,
price_next_states_tensor, price_dones_tensor
)

# Restore original learning rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = original_lr

logger.info(f"Extra training on price movement data: loss={price_loss:.4f}")

# Average the loss
loss = (loss + price_loss) / 2

# Store and return loss
# Store loss for monitoring
self.losses.append(loss)
return loss

def _replay_standard(self, states, actions, rewards, next_states, dones):
"""Standard precision training step"""
# Zero gradients
self.optimizer.zero_grad()

# Get current Q values and extrema predictions
current_q_values, current_extrema_pred, current_price_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

# Get next Q values from target network
with torch.no_grad():
next_q_values, next_extrema_pred, next_price_pred = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]

# Check for dimension mismatch and fix it
if rewards.shape[0] != next_q_values.shape[0]:
# Log the shape mismatch for debugging
logger.warning(f"Shape mismatch detected in standard replay: rewards {rewards.shape}, next_q_values {next_q_values.shape}")
# Use the smaller size to prevent index errors
min_size = min(rewards.shape[0], next_q_values.shape[0])
rewards = rewards[:min_size]
dones = dones[:min_size]
next_q_values = next_q_values[:min_size]
current_q_values = current_q_values[:min_size]

target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

# Compute Q-value loss (primary task)
q_loss = nn.MSELoss()(current_q_values, target_q_values)
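The two lines above are the standard one-step DQN target and its mean-squared-error loss; in LaTeX notation, with \gamma the discount factor and d_i the done flag:

y_i = r_i + \gamma (1 - d_i) \max_{a'} Q_{\mathrm{target}}(s'_i, a'), \qquad
\mathcal{L}_Q = \frac{1}{N} \sum_{i=1}^{N} \bigl( Q_{\mathrm{policy}}(s_i, a_i) - y_i \bigr)^2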
# Initialize combined loss with Q-value loss
loss = q_loss

# Try to extract price from current and next states
try:
# Extract price feature from sequence data (if available)
if len(states.shape) == 3: # [batch, seq, features]
current_prices = states[:, -1, -1] # Last timestep, last feature
next_prices = next_states[:, -1, -1]
else: # [batch, features]
current_prices = states[:, -1] # Last feature
next_prices = next_states[:, -1]

# Compute price changes for different timeframes
immediate_changes = (next_prices - current_prices) / current_prices

# Create price direction labels - simplified for training
# 0 = down, 1 = sideways, 2 = up
immediate_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 # Default: sideways
midterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
longterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1

# Immediate term direction (1s, 1m)
immediate_up = (immediate_changes > 0.0005)
immediate_down = (immediate_changes < -0.0005)
immediate_labels[immediate_up] = 2 # Up
immediate_labels[immediate_down] = 0 # Down

# For mid and long term, we can only approximate during training
# In a real system, we'd need historical data to validate these
# Here we'll use the immediate term with increasing thresholds as approximation

# Mid-term (1h) - use slightly higher threshold
midterm_up = (immediate_changes > 0.001)
midterm_down = (immediate_changes < -0.001)
midterm_labels[midterm_up] = 2 # Up
midterm_labels[midterm_down] = 0 # Down

# Long-term (1d) - use even higher threshold
longterm_up = (immediate_changes > 0.002)
longterm_down = (immediate_changes < -0.002)
longterm_labels[longterm_up] = 2 # Up
longterm_labels[longterm_down] = 0 # Down

# Generate target values for price change regression
# For simplicity, we'll use the immediate change and scaled versions for longer timeframes
price_value_targets = torch.zeros((min_size, 4), device=self.device)
price_value_targets[:, 0] = immediate_changes
price_value_targets[:, 1] = immediate_changes * 2.0 # Approximate 1h change
price_value_targets[:, 2] = immediate_changes * 4.0 # Approximate 1d change
price_value_targets[:, 3] = immediate_changes * 6.0 # Approximate 1w change

# Calculate loss for price direction prediction (classification)
if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= min_size:
# Slice predictions to match the adjusted batch size
immediate_pred = current_price_pred['immediate'][:min_size]
midterm_pred = current_price_pred['midterm'][:min_size]
longterm_pred = current_price_pred['longterm'][:min_size]
price_values_pred = current_price_pred['values'][:min_size]

# Compute losses for each task
immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels)
midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels)
longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels)

# MSE loss for price value regression
price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets)

# Combine all price prediction losses
price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss

# Create extrema labels (same as before)
extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither

# Identify potential bottoms (significant negative change)
bottoms = (immediate_changes < -0.003)
extrema_labels[bottoms] = 0

# Identify potential tops (significant positive change)
tops = (immediate_changes > 0.003)
extrema_labels[tops] = 1

# Calculate extrema prediction loss
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
current_extrema_pred = current_extrema_pred[:min_size]
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)

# Combined loss with all components
# Primary task: Q-value learning (RL objective)
# Secondary tasks: extrema detection and price prediction (supervised objectives)
loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss

# Log loss components occasionally
if random.random() < 0.01: # Log 1% of the time
logger.info(
f"Training losses: Q-loss={q_loss.item():.4f}, "
f"Extrema-loss={extrema_loss.item():.4f}, "
f"Price-loss={price_loss.item():.4f}, "
f"Imm-loss={immediate_loss.item():.4f}, "
f"Mid-loss={midterm_loss.item():.4f}, "
f"Long-loss={longterm_loss.item():.4f}"
)
except Exception as e:
# Fallback if price extraction fails
logger.warning(f"Failed to calculate price prediction loss: {str(e)}. Using only Q-value loss.")
# Just use Q-value loss
loss = q_loss

# Backward pass and optimize
loss.backward()

# Gradient clipping to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
self.optimizer.step()

# Update target network if needed
self.update_count += 1
if self.update_count % self.target_update == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())

# Track and decay epsilon
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

return loss.item()
# Randomly decide if we should train on extrema points from special memory
if random.random() < 0.3 and len(self.extrema_memory) >= self.batch_size:
# Train specifically on extrema memory examples
extrema_indices = np.random.choice(len(self.extrema_memory), size=min(self.batch_size, len(self.extrema_memory)), replace=False)
extrema_batch = [self.extrema_memory[i] for i in extrema_indices]

# Extract tensors from extrema batch
extrema_states = torch.FloatTensor(np.array([e[0] for e in extrema_batch])).to(self.device)
extrema_actions = torch.LongTensor(np.array([e[1] for e in extrema_batch])).to(self.device)
extrema_rewards = torch.FloatTensor(np.array([e[2] for e in extrema_batch])).to(self.device)
extrema_next_states = torch.FloatTensor(np.array([e[3] for e in extrema_batch])).to(self.device)
extrema_dones = torch.FloatTensor(np.array([e[4] for e in extrema_batch])).to(self.device)

# Use a slightly reduced learning rate for extrema training
old_lr = self.optimizer.param_groups[0]['lr']
self.optimizer.param_groups[0]['lr'] = old_lr * 0.8

# Train on extrema memory
if self.use_mixed_precision:
extrema_loss = self._replay_mixed_precision(extrema_states, extrema_actions, extrema_rewards, extrema_next_states, extrema_dones)
else:
extrema_loss = self._replay_standard(extrema_batch)

# Reset learning rate
self.optimizer.param_groups[0]['lr'] = old_lr

# Log extrema loss
logger.info(f"Extra training on extrema points, loss: {extrema_loss:.4f}")

# Randomly train on price movement examples (similar to extrema)
if random.random() < 0.3 and len(self.price_movement_memory) >= self.batch_size:
# Train specifically on price movement memory examples
price_indices = np.random.choice(len(self.price_movement_memory), size=min(self.batch_size, len(self.price_movement_memory)), replace=False)
price_batch = [self.price_movement_memory[i] for i in price_indices]

# Extract tensors from price movement batch
price_states = torch.FloatTensor(np.array([e[0] for e in price_batch])).to(self.device)
price_actions = torch.LongTensor(np.array([e[1] for e in price_batch])).to(self.device)
price_rewards = torch.FloatTensor(np.array([e[2] for e in price_batch])).to(self.device)
price_next_states = torch.FloatTensor(np.array([e[3] for e in price_batch])).to(self.device)
price_dones = torch.FloatTensor(np.array([e[4] for e in price_batch])).to(self.device)

# Use a slightly reduced learning rate for price movement training
old_lr = self.optimizer.param_groups[0]['lr']
self.optimizer.param_groups[0]['lr'] = old_lr * 0.75

# Train on price movement memory
if self.use_mixed_precision:
price_loss = self._replay_mixed_precision(price_states, price_actions, price_rewards, price_next_states, price_dones)
else:
price_loss = self._replay_standard(price_batch)

# Reset learning rate
self.optimizer.param_groups[0]['lr'] = old_lr

# Log price movement loss
logger.info(f"Extra training on price movement examples, loss: {price_loss:.4f}")

return loss
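Both specialized passes above repeat the same pattern: scale the learning rate down, train on a focused batch, then restore it. A small helper for that pattern could look like this hypothetical sketch (not part of the commit):

from contextlib import contextmanager

@contextmanager
def scaled_learning_rate(optimizer, factor):
    """Temporarily multiply every param group's learning rate by `factor`."""
    original = [group['lr'] for group in optimizer.param_groups]
    try:
        for group in optimizer.param_groups:
            group['lr'] *= factor
        yield
    finally:
        for group, lr in zip(optimizer.param_groups, original):
            group['lr'] = lr

# Usage sketch:
# with scaled_learning_rate(self.optimizer, 0.8):
#     extrema_loss = self._replay_standard(extrema_batch)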
def _replay_standard(self, experiences=None):
"""Standard training step without mixed precision"""
try:
# Use experiences if provided, otherwise sample from memory
if experiences is None:
# If memory is too small, skip training
if len(self.memory) < self.batch_size:
return 0.0

# Sample random mini-batch from memory
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
batch = [self.memory[i] for i in indices]
experiences = batch

# Unpack experiences
states, actions, rewards, next_states, dones = zip(*experiences)

# Convert to PyTorch tensors
states = torch.FloatTensor(np.array(states)).to(self.device)
actions = torch.LongTensor(np.array(actions)).to(self.device)
rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
dones = torch.FloatTensor(np.array(dones)).to(self.device)

# Get current Q values
current_q_values, current_extrema_pred, current_price_pred, hidden_features = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

# Get next Q values with target network
with torch.no_grad():
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]

# Check for dimension mismatch between rewards and next_q_values
if rewards.shape[0] != next_q_values.shape[0]:
logger.warning(f"Shape mismatch detected in standard replay: rewards {rewards.shape}, next_q_values {next_q_values.shape}")
# Use the smaller size to prevent index error
min_size = min(rewards.shape[0], next_q_values.shape[0])
rewards = rewards[:min_size]
dones = dones[:min_size]
next_q_values = next_q_values[:min_size]
current_q_values = current_q_values[:min_size]

# Calculate target Q values
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

# Compute loss for Q value
q_loss = self.criterion(current_q_values, target_q_values)

# Try to compute extrema loss if possible
try:
# Get the target classes from extrema predictions
extrema_targets = torch.argmax(current_extrema_pred, dim=1).long()

# Compute extrema loss using cross-entropy - this is an auxiliary task
extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)

# Combined loss with emphasis on Q-learning
total_loss = q_loss + 0.1 * extrema_loss
except Exception as e:
logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
total_loss = q_loss

# Reset gradients
self.optimizer.zero_grad()

# Backward pass
total_loss.backward()

# Clip gradients to avoid exploding gradients
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)

# Update weights
self.optimizer.step()

# Update target network if needed
self.update_count += 1
if self.update_count % self.target_update == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())

# Return loss
return total_loss.item()
except Exception as e:
logger.error(f"Error in replay standard: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return 0.0

def _replay_mixed_precision(self, states, actions, rewards, next_states, dones):
"""Mixed precision training step for better GPU performance"""
@@ -696,12 +727,12 @@ class DQNAgent:
# Forward pass with amp autocasting
with torch.cuda.amp.autocast():
# Get current Q values and extrema predictions
current_q_values, current_extrema_pred, current_price_pred = self.policy_net(states)
current_q_values, current_extrema_pred, current_price_pred, hidden_features = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

# Get next Q values from target network
with torch.no_grad():
next_q_values, next_extrema_pred, next_price_pred = self.target_net(next_states)
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]

# Check for dimension mismatch and fix it
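The mixed-precision path wraps the forward pass in torch.cuda.amp.autocast(); the backward/step half of such a loop is usually driven by a GradScaler, roughly as in the generic sketch below (a common pattern, not the exact code of this commit):

import torch

scaler = torch.cuda.amp.GradScaler()

def amp_step(model, optimizer, loss_fn, inputs, targets):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler.unscale_(optimizer)      # unscale before gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()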
@@ -733,7 +764,7 @@ class DQNAgent:
current_prices = states[:, -1] # Last feature
next_prices = next_states[:, -1]

# Compute price changes for different timeframes
# Calculate price change for different timeframes
immediate_changes = (next_prices - current_prices) / current_prices

# Create price direction labels - simplified for training