gogo2/NN/models/dqn_agent.py
2025-05-26 16:02:40 +03:00

1171 lines
59 KiB
Python

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import sys
import logging
import torch.nn.functional as F
# Put the grandparent of this file's directory on sys.path (os.path.dirname applied three
# times to this file's absolute path). For gogo2/NN/models/dqn_agent.py that is the project
# root, which lets absolute project imports such as `from NN.models.enhanced_cnn import
# EnhancedCNN` (used in DQNAgent.__init__) resolve.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
# Module-level logger used throughout DQNAgent
logger = logging.getLogger(__name__)
class DQNAgent:
"""
Deep Q-Network agent for trading
Uses Enhanced CNN model as the base network with GPU support for improved performance
"""
def __init__(self,
             state_shape: Tuple[int, ...],
             n_actions: int,
             learning_rate: float = 0.0005,  # Reduced learning rate for more stability
             gamma: float = 0.97,  # Slightly reduced discount factor
             epsilon: float = 1.0,
             epsilon_min: float = 0.05,  # Increased minimum epsilon for more exploration
             epsilon_decay: float = 0.9975,  # Slower decay rate
             buffer_size: int = 20000,  # Increased memory size
             batch_size: int = 128,  # Larger batch size
             target_update: int = 5,  # More frequent target updates
             device=None):  # Device for computations
    """
    Create the agent: policy/target EnhancedCNN networks, optimiser, replay memories
    and all prediction/confidence/volatility tracking state.

    Args:
        state_shape: Shape of one observation. A tuple with more than one entry is
            kept as-is (multi-dimensional state); a 1-tuple or a plain int is reduced
            to a single feature count.
        n_actions: Number of discrete actions (act() treats 0/1/2 as BUY/SELL/HOLD).
        learning_rate: Adam learning rate of the policy network.
        gamma: Discount factor of the TD target.
        epsilon: Initial exploration rate (also kept as ``epsilon_start``).
        epsilon_min: Lower bound for epsilon.
        epsilon_decay: Multiplicative epsilon decay applied per training step.
        buffer_size: Maximum size of the main replay memory; the specialised
            memories are capped at ``buffer_size // 4``.
        batch_size: Replay mini-batch size.
        target_update: The target network is synchronised with the policy network
            every ``target_update`` optimisation steps.
        device: torch device (or device string) to compute on; defaults to CUDA
            when available, otherwise CPU.
    """
    # Extract state dimensions
    if isinstance(state_shape, tuple) and len(state_shape) > 1:
        # Multi-dimensional state (like image or sequence)
        self.state_dim = state_shape
    elif isinstance(state_shape, tuple):
        # 1-tuple: reduce to a plain feature count
        self.state_dim = state_shape[0]
    else:
        self.state_dim = state_shape

    # Store parameters
    self.n_actions = n_actions
    self.learning_rate = learning_rate
    self.gamma = gamma
    self.epsilon = epsilon
    self.epsilon_min = epsilon_min
    self.epsilon_decay = epsilon_decay
    self.epsilon_start = epsilon  # Store initial epsilon value for resets/bumps
    self.buffer_size = buffer_size
    self.batch_size = batch_size
    self.target_update = target_update

    # Set device for computation (default to GPU if available)
    if device is None:
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        self.device = device
    # 'cuda' / 'cpu' / ... regardless of whether a string or a torch.device was given
    device_type = torch.device(self.device).type

    # Initialize models with Enhanced CNN architecture for better performance
    from NN.models.enhanced_cnn import EnhancedCNN
    self.policy_net = EnhancedCNN(self.state_dim, self.n_actions)
    self.target_net = EnhancedCNN(self.state_dim, self.n_actions)
    # Initialize the target network with the same weights as the policy network
    self.target_net.load_state_dict(self.policy_net.state_dict())
    # act() and the replay methods send their input tensors to self.device, so the
    # network weights must live there as well (no-op if they already do).
    self.policy_net = self.policy_net.to(self.device)
    self.target_net = self.target_net.to(self.device)
    # The target network is only used for inference
    self.target_net.eval()

    # Optimization components
    self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
    self.criterion = nn.MSELoss()

    # Experience replay memory
    self.memory = []
    self.positive_memory = []  # Special memory for storing good experiences
    self.update_count = 0

    # Extrema detection tracking
    self.last_extrema_pred = {
        'class': 2,  # Default to "neither" (not extrema)
        'confidence': 0.0,
        'raw': None
    }
    self.extrema_memory = []  # Special memory for storing extrema points

    # Price prediction tracking (direction: 0=down, 1=sideways, 2=up)
    self.last_price_pred = {
        'immediate': {
            'direction': 1,  # Default to "sideways"
            'confidence': 0.0,
            'change': 0.0
        },
        'midterm': {
            'direction': 1,  # Default to "sideways"
            'confidence': 0.0,
            'change': 0.0
        },
        'longterm': {
            'direction': 1,  # Default to "sideways"
            'confidence': 0.0,
            'change': 0.0
        }
    }

    # Store separate memory for price direction examples
    self.price_movement_memory = []  # For storing examples of clear price movements

    # Performance tracking
    self.losses = []
    self.avg_reward = 0.0
    self.best_reward = -float('inf')
    self.no_improvement_count = 0

    # Confidence tracking
    self.confidence_history = []
    self.avg_confidence = 0.0
    self.max_confidence = 0.0
    self.min_confidence = 1.0

    # Trade action fee and confidence thresholds
    self.trade_action_fee = 0.0005  # Small fee to discourage unnecessary trading
    self.minimum_action_confidence = 0.3  # Minimum confidence to consider trading (lowered from 0.5)
    self.recent_actions = []  # Track recent actions to avoid oscillations

    # Violent move detection
    self.price_history = []
    self.volatility_window = 20  # Window size for volatility calculation
    self.volatility_threshold = 0.0015  # Threshold for considering a move "violent"
    self.post_violent_move = False  # Flag for recent violent move
    self.violent_move_cooldown = 0  # Cooldown after violent move

    # Feature integration
    self.last_hidden_features = None  # Store last extracted features
    self.feature_history = []  # Store history of features for analysis

    # Real-time tick features integration
    self.realtime_tick_features = None  # Latest tick features from tick processor
    self.tick_feature_weight = 0.3  # Weight for tick features in decision making

    # Mixed precision (CUDA autocast + GradScaler) only applies when this agent
    # actually computes on a CUDA device. Checking torch.cuda.is_available() alone
    # would enable it for an agent explicitly placed on the CPU.
    self.use_mixed_precision = False
    if (device_type == 'cuda'
            and torch.cuda.is_available()
            and hasattr(torch.cuda, 'amp')
            and 'DISABLE_MIXED_PRECISION' not in os.environ):
        self.use_mixed_precision = True
        self.scaler = torch.cuda.amp.GradScaler()
        logger.info("Mixed precision training enabled")
    else:
        logger.info("Mixed precision training disabled")

    # Track if we're in training mode
    self.training = True

    # For compatibility with old code
    self.state_size = np.prod(state_shape)
    self.action_size = n_actions
    self.memory_size = buffer_size
    self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3]  # Default timeframes

    logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
    logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
    logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")

    # Log model parameters
    total_params = sum(p.numel() for p in self.policy_net.parameters())
    logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
def move_models_to_device(self, device=None):
    """
    Transfer the policy and target networks onto the agent's device.

    Args:
        device: Optional new device. When given it replaces ``self.device``
            before the transfer.

    Returns:
        bool: True if both networks were moved. False if ``.to()`` raised; the
        error is logged, not re-raised.
    """
    if device is not None:
        self.device = device

    try:
        for net_attr in ('policy_net', 'target_net'):
            setattr(self, net_attr, getattr(self, net_attr).to(self.device))
        logger.info(f"Moved models to {self.device}")
    except Exception as exc:
        logger.error(f"Failed to move models to {self.device}: {str(exc)}")
        return False
    return True
def remember(self, state: np.ndarray, action: int, reward: float,
next_state: np.ndarray, done: bool, is_extrema: bool = False):
"""
Store experience in memory with prioritization
Args:
state: Current state
action: Action taken
reward: Reward received
next_state: Next state
done: Whether episode is done
is_extrema: Whether this is a local extrema sample (for specialized learning)
"""
experience = (state, action, reward, next_state, done)
# Always add to main memory
self.memory.append(experience)
# Try to extract price change to analyze the experience
try:
# Extract price feature from sequence data (if available)
if len(state.shape) > 1: # 2D state [timeframes, features]
current_price = state[-1, -1] # Last timeframe, last feature
next_price = next_state[-1, -1]
else: # 1D state
current_price = state[-1] # Last feature
next_price = next_state[-1]
# Calculate price change - avoid division by zero
if np.isscalar(current_price) and current_price != 0:
price_change = (next_price - current_price) / current_price
elif isinstance(current_price, np.ndarray):
# Handle array case - protect against division by zero
with np.errstate(divide='ignore', invalid='ignore'):
price_change = (next_price - current_price) / current_price
# Replace infinities and NaNs with zeros
if isinstance(price_change, np.ndarray):
price_change = np.nan_to_num(price_change, nan=0.0, posinf=0.0, neginf=0.0)
else:
price_change = 0.0 if np.isnan(price_change) or np.isinf(price_change) else price_change
else:
price_change = 0.0
# Check if this is a significant price movement
if abs(price_change) > 0.002: # Significant price change
# Store in price movement memory
self.price_movement_memory.append(experience)
# Log significant price movements
direction = "UP" if price_change > 0 else "DOWN"
logger.info(f"Stored significant {direction} price movement: {price_change:.4f}")
# For clear price movements, also duplicate in main memory to learn more
if abs(price_change) > 0.005: # Very significant movement
for _ in range(2): # Add 2 extra copies
self.memory.append(experience)
except Exception as e:
# Skip price movement analysis if it fails
pass
# Check if this is an extrema point based on our extrema detection head
if hasattr(self, 'last_extrema_pred') and self.last_extrema_pred['class'] != 2:
# Class 0 = bottom, 1 = top, 2 = neither
# Only consider high confidence predictions
if self.last_extrema_pred['confidence'] > 0.7:
self.extrema_memory.append(experience)
# Log this special experience
extrema_type = "BOTTOM" if self.last_extrema_pred['class'] == 0 else "TOP"
logger.info(f"Stored {extrema_type} experience with reward {reward:.4f}")
# For tops and bottoms, also duplicate the experience in memory to learn more from it
for _ in range(2): # Add 2 extra copies
self.memory.append(experience)
# Explicitly marked extrema points also go to extrema memory
elif is_extrema:
self.extrema_memory.append(experience)
# Store positive experiences separately for prioritized replay
if reward > 0:
self.positive_memory.append(experience)
# For very good rewards, duplicate to learn more from them
if reward > 0.1:
for _ in range(min(int(reward * 10), 5)): # Cap at 5 extra copies for very high rewards
self.positive_memory.append(experience)
# Keep memory size under control
if len(self.memory) > self.buffer_size:
# Keep more recent experiences
self.memory = self.memory[-self.buffer_size:]
# Keep specialized memories under control too
if len(self.positive_memory) > self.buffer_size // 4:
self.positive_memory = self.positive_memory[-(self.buffer_size // 4):]
if len(self.extrema_memory) > self.buffer_size // 4:
self.extrema_memory = self.extrema_memory[-(self.buffer_size // 4):]
if len(self.price_movement_memory) > self.buffer_size // 4:
self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):]
def act(self, state: np.ndarray, explore=True) -> int:
    """
    Choose an action with epsilon-greedy exploration plus rule-based post-processing.

    With probability ``self.epsilon`` (only when ``explore`` is true) a uniformly
    random action is returned immediately, and none of the tracking below happens.
    Otherwise the policy network is evaluated and its greedy action is filtered, in
    order, through:

    1. a minimum-confidence check and a trade-fee re-check for BUY/SELL;
    2. probabilistic nudges from the immediate price-direction head;
    3. probabilistic nudges from the extrema head;
    4. an anti-oscillation rule against flipping directly between BUY and SELL.

    Side effects: updates ``last_hidden_features``, ``feature_history``,
    ``last_extrema_pred``, ``last_price_pred``, the confidence statistics,
    ``price_history``, the violent-move state and ``recent_actions``. Always leaves
    ``policy_net`` in training mode.

    Args:
        state: Raw observation as a numpy array (1D and multi-dimensional states are
            handled when extracting the price).
        explore: If False, epsilon-greedy exploration is skipped.

    Returns:
        int: The chosen action. Within this method 0 = BUY, 1 = SELL, 2 = HOLD.
    """
    if explore and random.random() < self.epsilon:
        return random.randrange(self.n_actions)
    with torch.no_grad():
        # Augment the observation with real-time tick features
        # (_enhance_state_with_tick_features is defined elsewhere in this class)
        enhanced_state = self._enhance_state_with_tick_features(state)
        # Ensure state is normalized before inference
        state_tensor = self._normalize_state(enhanced_state)
        # Add a batch dimension of 1 and move to the agent's device
        state_tensor = torch.FloatTensor(state_tensor).unsqueeze(0).to(self.device)
        # Get predictions using the policy network
        self.policy_net.eval() # Set to evaluation mode for inference
        action_probs, extrema_pred, price_predictions, hidden_features, advanced_predictions = self.policy_net(state_tensor)
        self.policy_net.train() # Back to training mode (unconditionally, whatever the previous mode was)
        # Store hidden features for integration
        self.last_hidden_features = hidden_features.cpu().numpy()
        # Track feature history (keep only the last 100 entries)
        self.feature_history.append(hidden_features.cpu().numpy())
        if len(self.feature_history) > 100:
            self.feature_history = self.feature_history[-100:]
        # Get the predicted extrema class (0=bottom, 1=top, 2=neither)
        extrema_class = extrema_pred.argmax(dim=1).item()
        extrema_confidence = torch.softmax(extrema_pred, dim=1)[0, extrema_class].item()
        # Log extrema prediction for significant signals
        if extrema_confidence > 0.7 and extrema_class != 2: # Only log strong top/bottom signals
            extrema_type = "BOTTOM" if extrema_class == 0 else "TOP" if extrema_class == 1 else "NEITHER"
            logger.info(f"High confidence {extrema_type} detected! Confidence: {extrema_confidence:.4f}")
        # Process price predictions: softmax over the direction logits of each horizon
        price_immediate = torch.softmax(price_predictions['immediate'], dim=1)
        price_midterm = torch.softmax(price_predictions['midterm'], dim=1)
        price_longterm = torch.softmax(price_predictions['longterm'], dim=1)
        price_values = price_predictions['values']
        # Get predicted direction for each timeframe (0=down, 1=sideways, 2=up)
        immediate_direction = price_immediate.argmax(dim=1).item()
        midterm_direction = price_midterm.argmax(dim=1).item()
        longterm_direction = price_longterm.argmax(dim=1).item()
        # Get confidence levels (softmax probability of the chosen direction)
        immediate_conf = price_immediate[0, immediate_direction].item()
        midterm_conf = price_midterm[0, midterm_direction].item()
        longterm_conf = price_longterm[0, longterm_direction].item()
        # Get predicted price change percentages (at least 3 entries are read below)
        price_changes = price_values[0].tolist()
        # Log significant price movement predictions
        # (only the first three timeframe labels are used by the loop below)
        timeframes = ["1s/1m", "1h", "1d", "1w"]
        directions = ["DOWN", "SIDEWAYS", "UP"]
        for i, (direction, conf) in enumerate([
            (immediate_direction, immediate_conf),
            (midterm_direction, midterm_conf),
            (longterm_direction, longterm_conf)
        ]):
            if conf > 0.7 and direction != 1: # Only log high confidence non-sideways predictions
                logger.info(f"Price prediction: {timeframes[i]} -> {directions[direction]}, "
                            f"Confidence: {conf:.4f}, Expected change: {price_changes[i]:.2f}%")
        # Store predictions for the environment to use (remember() also reads last_extrema_pred)
        self.last_extrema_pred = {
            'class': extrema_class,
            'confidence': extrema_confidence,
            'raw': extrema_pred.cpu().numpy()
        }
        self.last_price_pred = {
            'immediate': {
                'direction': immediate_direction,
                'confidence': immediate_conf,
                'change': price_changes[0]
            },
            'midterm': {
                'direction': midterm_direction,
                'confidence': midterm_conf,
                'change': price_changes[1]
            },
            'longterm': {
                'direction': longterm_direction,
                'confidence': longterm_conf,
                'change': price_changes[2]
            }
        }
        # Get the action with highest Q-value
        action = action_probs.argmax().item()
        # "Confidence" is the softmax of the Q-values at the chosen action
        q_values_softmax = F.softmax(action_probs, dim=1)[0]
        action_confidence = q_values_softmax[action].item()
        # Track confidence metrics (rolling window of 100)
        self.confidence_history.append(action_confidence)
        if len(self.confidence_history) > 100:
            self.confidence_history = self.confidence_history[-100:]
        # Update confidence metrics (min/max are all-time, avg is over the window)
        self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
        self.max_confidence = max(self.max_confidence, action_confidence)
        self.min_confidence = min(self.min_confidence, action_confidence)
        # Log average confidence occasionally
        if random.random() < 0.01: # 1% of the time
            logger.info(f"Confidence metrics - Current: {action_confidence:.4f}, Avg: {self.avg_confidence:.4f}, " +
                        f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")
        # Track price for violent move detection.
        # NOTE(review): reads the raw `state` (not the enhanced/normalised one) and
        # assumes its last feature is the price - confirm against the feature layout.
        try:
            # Extract current price from state (assuming it's in the last position)
            if len(state.shape) > 1: # For 2D state
                current_price = state[-1, -1]
            else: # For 1D state
                current_price = state[-1]
            self.price_history.append(current_price)
            if len(self.price_history) > self.volatility_window:
                self.price_history = self.price_history[-self.volatility_window:]
            # Detect violent price moves if we have enough price history
            if len(self.price_history) >= 5:
                # Calculate short-term volatility over the last 5 prices
                recent_prices = self.price_history[-5:]
                # Make sure we're working with scalar values, not arrays
                if isinstance(recent_prices[0], np.ndarray):
                    # If prices are arrays, extract the last value (current price)
                    recent_prices = [p[-1] if isinstance(p, np.ndarray) and p.size > 0 else p for p in recent_prices]
                # Calculate price changes with protection against division by zero
                # (this shadows the price_changes list used for predictions above,
                # which is no longer needed at this point)
                price_changes = []
                for i in range(1, len(recent_prices)):
                    if recent_prices[i-1] != 0 and not np.isnan(recent_prices[i-1]) and not np.isnan(recent_prices[i]):
                        change = (recent_prices[i] - recent_prices[i-1]) / recent_prices[i-1]
                        price_changes.append(change)
                    else:
                        price_changes.append(0.0)
                # Calculate volatility as sum of absolute price changes
                volatility = sum([abs(change) for change in price_changes])
                # Check if we've had a violent move
                if volatility > self.volatility_threshold:
                    logger.info(f"Violent price move detected! Volatility: {volatility:.6f}")
                    self.post_violent_move = True
                    self.violent_move_cooldown = 10 # Set cooldown period
            # Handle post-violent move period (cooldown counts act() calls, not time)
            if self.post_violent_move:
                if self.violent_move_cooldown > 0:
                    self.violent_move_cooldown -= 1
                    # Only used for the log message; the threshold actually applied is
                    # recomputed in the BUY/SELL check below
                    effective_threshold = self.minimum_action_confidence * 1.1
                    logger.info(f"Post-violent move period: {self.violent_move_cooldown} steps remaining. " +
                                f"Using higher confidence threshold: {effective_threshold:.4f}")
                else:
                    self.post_violent_move = False
                    logger.info("Post-violent move period ended")
        except Exception as e:
            logger.warning(f"Error in violent move detection: {str(e)}")
        # Apply trade action fee to buy/sell actions but not to hold
        # This creates a threshold that must be exceeded to justify a trade
        action_values = action_probs.clone()
        # If BUY or SELL, apply fee by reducing the Q-value
        if action == 0 or action == 1: # BUY or SELL
            # Check if confidence is above minimum threshold
            effective_threshold = self.minimum_action_confidence
            if self.post_violent_move:
                effective_threshold *= 1.1 # Higher threshold after violent moves
            if action_confidence < effective_threshold:
                # If confidence is below threshold, force HOLD action
                logger.info(f"Action {action} confidence {action_confidence:.4f} below threshold {effective_threshold}, forcing HOLD")
                action = 2 # HOLD
            else:
                # Apply trade action fee to ensure we only trade when there's clear benefit
                # (assumes index 2 exists and is HOLD; its value is left untouched)
                fee_adjusted_action_values = action_values.clone()
                fee_adjusted_action_values[0, 0] -= self.trade_action_fee # Reduce BUY value
                fee_adjusted_action_values[0, 1] -= self.trade_action_fee # Reduce SELL value
                # Hold value remains unchanged
                # Re-determine the action based on fee-adjusted values
                fee_adjusted_action = fee_adjusted_action_values.argmax().item()
                # If the fee changes our decision, log this
                if fee_adjusted_action != action:
                    logger.info(f"Trade action fee changed decision from {action} to {fee_adjusted_action}")
                    action = fee_adjusted_action
        # Adjust action based on extrema and price predictions
        # Prioritize short-term movement for trading decisions.
        # NOTE(review): `action != 0 and action != 2` means the BUY nudge only fires when
        # the current action is SELL (and vice versa below); HOLD is never promoted to a
        # trade here. Confirm this is intended.
        if immediate_conf > 0.8: # Only adjust for strong signals
            if immediate_direction == 2: # UP prediction
                # Bias toward BUY for strong up predictions
                if action != 0 and action != 2 and random.random() < 0.3 * immediate_conf:
                    logger.info(f"Adjusting action to BUY based on immediate UP prediction")
                    action = 0 # BUY
            elif immediate_direction == 0: # DOWN prediction
                # Bias toward SELL for strong down predictions
                if action != 1 and action != 2 and random.random() < 0.3 * immediate_conf:
                    logger.info(f"Adjusting action to SELL based on immediate DOWN prediction")
                    action = 1 # SELL
        # Also consider extrema detection for action adjustment (same SELL<->BUY-only caveat)
        if extrema_confidence > 0.8: # Only adjust for strong signals
            if extrema_class == 0: # Bottom detected
                # Bias toward BUY at bottoms
                if action != 0 and action != 2 and random.random() < 0.3 * extrema_confidence:
                    logger.info(f"Adjusting action to BUY based on bottom detection")
                    action = 0 # BUY
            elif extrema_class == 1: # Top detected
                # Bias toward SELL at tops
                if action != 1 and action != 2 and random.random() < 0.3 * extrema_confidence:
                    logger.info(f"Adjusting action to SELL based on top detection")
                    action = 1 # SELL
        # Finally, avoid action oscillation by checking recent history
        # (requires at least 2 recorded actions, but only the most recent one is compared)
        if len(self.recent_actions) >= 2:
            last_action = self.recent_actions[-1]
            if action != last_action and action != 2 and last_action != 2:
                # We're switching between BUY and SELL too quickly
                # Only allow this if we have very high confidence
                if action_confidence < 0.85:
                    logger.info(f"Preventing oscillation from {last_action} to {action}, forcing HOLD")
                    action = 2 # HOLD
        # Update recent actions list (keep the last 5)
        self.recent_actions.append(action)
        if len(self.recent_actions) > 5:
            self.recent_actions = self.recent_actions[-5:]
        return action
def replay(self, experiences=None):
"""Train the model using experiences from memory"""
# Don't train if not in training mode
if not self.training:
return 0.0
# If no experiences provided, sample from memory
if experiences is None:
# Skip if memory is too small
if len(self.memory) < self.batch_size:
return 0.0
# Sample random mini-batch from memory
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
experiences = [self.memory[i] for i in indices]
# Choose appropriate replay method
if self.use_mixed_precision:
# Convert experiences to tensors for mixed precision
states = torch.FloatTensor(np.array([e[0] for e in experiences])).to(self.device)
actions = torch.LongTensor(np.array([e[1] for e in experiences])).to(self.device)
rewards = torch.FloatTensor(np.array([e[2] for e in experiences])).to(self.device)
next_states = torch.FloatTensor(np.array([e[3] for e in experiences])).to(self.device)
dones = torch.FloatTensor(np.array([e[4] for e in experiences])).to(self.device)
# Use mixed precision replay
loss = self._replay_mixed_precision(states, actions, rewards, next_states, dones)
else:
# Pass experiences directly to standard replay method
loss = self._replay_standard(experiences)
# Store loss for monitoring
self.losses.append(loss)
# Track and decay epsilon
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
# Randomly decide if we should train on extrema points from special memory
if random.random() < 0.3 and len(self.extrema_memory) >= self.batch_size:
# Train specifically on extrema memory examples
extrema_indices = np.random.choice(len(self.extrema_memory), size=min(self.batch_size, len(self.extrema_memory)), replace=False)
extrema_batch = [self.extrema_memory[i] for i in extrema_indices]
# Extract tensors from extrema batch
extrema_states = torch.FloatTensor(np.array([e[0] for e in extrema_batch])).to(self.device)
extrema_actions = torch.LongTensor(np.array([e[1] for e in extrema_batch])).to(self.device)
extrema_rewards = torch.FloatTensor(np.array([e[2] for e in extrema_batch])).to(self.device)
extrema_next_states = torch.FloatTensor(np.array([e[3] for e in extrema_batch])).to(self.device)
extrema_dones = torch.FloatTensor(np.array([e[4] for e in extrema_batch])).to(self.device)
# Use a slightly reduced learning rate for extrema training
old_lr = self.optimizer.param_groups[0]['lr']
self.optimizer.param_groups[0]['lr'] = old_lr * 0.8
# Train on extrema memory
if self.use_mixed_precision:
extrema_loss = self._replay_mixed_precision(extrema_states, extrema_actions, extrema_rewards, extrema_next_states, extrema_dones)
else:
extrema_loss = self._replay_standard(extrema_batch)
# Reset learning rate
self.optimizer.param_groups[0]['lr'] = old_lr
# Log extrema loss
logger.info(f"Extra training on extrema points, loss: {extrema_loss:.4f}")
# Randomly train on price movement examples (similar to extrema)
if random.random() < 0.3 and len(self.price_movement_memory) >= self.batch_size:
# Train specifically on price movement memory examples
price_indices = np.random.choice(len(self.price_movement_memory), size=min(self.batch_size, len(self.price_movement_memory)), replace=False)
price_batch = [self.price_movement_memory[i] for i in price_indices]
# Extract tensors from price movement batch
price_states = torch.FloatTensor(np.array([e[0] for e in price_batch])).to(self.device)
price_actions = torch.LongTensor(np.array([e[1] for e in price_batch])).to(self.device)
price_rewards = torch.FloatTensor(np.array([e[2] for e in price_batch])).to(self.device)
price_next_states = torch.FloatTensor(np.array([e[3] for e in price_batch])).to(self.device)
price_dones = torch.FloatTensor(np.array([e[4] for e in price_batch])).to(self.device)
# Use a slightly reduced learning rate for price movement training
old_lr = self.optimizer.param_groups[0]['lr']
self.optimizer.param_groups[0]['lr'] = old_lr * 0.75
# Train on price movement memory
if self.use_mixed_precision:
price_loss = self._replay_mixed_precision(price_states, price_actions, price_rewards, price_next_states, price_dones)
else:
price_loss = self._replay_standard(price_batch)
# Reset learning rate
self.optimizer.param_groups[0]['lr'] = old_lr
# Log price movement loss
logger.info(f"Extra training on price movement examples, loss: {price_loss:.4f}")
return loss
def _replay_standard(self, experiences=None):
"""Standard training step without mixed precision"""
try:
# Use experiences if provided, otherwise sample from memory
if experiences is None:
# If memory is too small, skip training
if len(self.memory) < self.batch_size:
return 0.0
# Sample random mini-batch from memory
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
batch = [self.memory[i] for i in indices]
experiences = batch
# Unpack experiences
states, actions, rewards, next_states, dones = zip(*experiences)
# Convert to PyTorch tensors
states = torch.FloatTensor(np.array(states)).to(self.device)
actions = torch.LongTensor(np.array(actions)).to(self.device)
rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
dones = torch.FloatTensor(np.array(dones)).to(self.device)
# Get current Q values
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Get next Q values with target network
with torch.no_grad():
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
# Check for dimension mismatch between rewards and next_q_values
if rewards.shape[0] != next_q_values.shape[0]:
logger.warning(f"Shape mismatch detected in standard replay: rewards {rewards.shape}, next_q_values {next_q_values.shape}")
# Use the smaller size to prevent index error
min_size = min(rewards.shape[0], next_q_values.shape[0])
rewards = rewards[:min_size]
dones = dones[:min_size]
next_q_values = next_q_values[:min_size]
current_q_values = current_q_values[:min_size]
# Calculate target Q values
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
# Compute loss for Q value
q_loss = self.criterion(current_q_values, target_q_values)
# Try to compute extrema loss if possible
try:
# Get the target classes from extrema predictions
extrema_targets = torch.argmax(current_extrema_pred, dim=1).long()
# Compute extrema loss using cross-entropy - this is an auxiliary task
extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)
# Combined loss with emphasis on Q-learning
total_loss = q_loss + 0.1 * extrema_loss
except Exception as e:
logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
total_loss = q_loss
# Reset gradients
self.optimizer.zero_grad()
# Backward pass
total_loss.backward()
# Clip gradients to avoid exploding gradients
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
# Update weights
self.optimizer.step()
# Update target network if needed
self.update_count += 1
if self.update_count % self.target_update == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
# Return loss
return total_loss.item()
except Exception as e:
logger.error(f"Error in replay standard: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return 0.0
def _replay_mixed_precision(self, states, actions, rewards, next_states, dones):
"""Mixed precision training step for better GPU performance"""
# Check if mixed precision should be explicitly disabled
if 'DISABLE_MIXED_PRECISION' in os.environ:
logger.info("Mixed precision explicitly disabled by environment variable")
return self._replay_standard(states, actions, rewards, next_states, dones)
try:
# Zero gradients
self.optimizer.zero_grad()
# Forward pass with amp autocasting
with torch.cuda.amp.autocast():
# Get current Q values and extrema predictions
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Get next Q values from target network
with torch.no_grad():
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
# Check for dimension mismatch and fix it
if rewards.shape[0] != next_q_values.shape[0]:
# Log the shape mismatch for debugging
logger.warning(f"Shape mismatch detected: rewards {rewards.shape}, next_q_values {next_q_values.shape}")
# Use the smaller size to prevent index errors
min_size = min(rewards.shape[0], next_q_values.shape[0])
rewards = rewards[:min_size]
dones = dones[:min_size]
next_q_values = next_q_values[:min_size]
current_q_values = current_q_values[:min_size]
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
# Compute Q-value loss (primary task)
q_loss = nn.MSELoss()(current_q_values, target_q_values)
# Initialize loss with q_loss
loss = q_loss
# Try to extract price from current and next states
try:
# Extract price feature from sequence data (if available)
if len(states.shape) == 3: # [batch, seq, features]
current_prices = states[:, -1, -1] # Last timestep, last feature
next_prices = next_states[:, -1, -1]
else: # [batch, features]
current_prices = states[:, -1] # Last feature
next_prices = next_states[:, -1]
# Calculate price change for different timeframes
immediate_changes = (next_prices - current_prices) / current_prices
# Get the actual batch size for this calculation
actual_batch_size = states.shape[0]
# Create price direction labels - simplified for training
# 0 = down, 1 = sideways, 2 = up
immediate_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 1 # Default: sideways
midterm_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 1
longterm_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 1
# Immediate term direction (1s, 1m)
immediate_up = (immediate_changes > 0.0005)
immediate_down = (immediate_changes < -0.0005)
immediate_labels[immediate_up] = 2 # Up
immediate_labels[immediate_down] = 0 # Down
# For mid and long term, we can only approximate during training
# In a real system, we'd need historical data to validate these
# Here we'll use the immediate term with increasing thresholds as approximation
# Mid-term (1h) - use slightly higher threshold
midterm_up = (immediate_changes > 0.001)
midterm_down = (immediate_changes < -0.001)
midterm_labels[midterm_up] = 2 # Up
midterm_labels[midterm_down] = 0 # Down
# Long-term (1d) - use even higher threshold
longterm_up = (immediate_changes > 0.002)
longterm_down = (immediate_changes < -0.002)
longterm_labels[longterm_up] = 2 # Up
longterm_labels[longterm_down] = 0 # Down
# Generate target values for price change regression
# For simplicity, we'll use the immediate change and scaled versions for longer timeframes
price_value_targets = torch.zeros((actual_batch_size, 4), device=self.device)
price_value_targets[:, 0] = immediate_changes
price_value_targets[:, 1] = immediate_changes * 2.0 # Approximate 1h change
price_value_targets[:, 2] = immediate_changes * 4.0 # Approximate 1d change
price_value_targets[:, 3] = immediate_changes * 6.0 # Approximate 1w change
# Calculate loss for price direction prediction (classification)
if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= actual_batch_size:
# Slice predictions to match the adjusted batch size
immediate_pred = current_price_pred['immediate'][:actual_batch_size]
midterm_pred = current_price_pred['midterm'][:actual_batch_size]
longterm_pred = current_price_pred['longterm'][:actual_batch_size]
price_values_pred = current_price_pred['values'][:actual_batch_size]
# Compute losses for each task
immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels)
midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels)
longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels)
# MSE loss for price value regression
price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets)
# Combine all price prediction losses
price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss
# Create extrema labels (same as before)
extrema_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 2 # Default: neither
# Identify potential bottoms (significant negative change)
bottoms = (immediate_changes < -0.003)
extrema_labels[bottoms] = 0
# Identify potential tops (significant positive change)
tops = (immediate_changes > 0.003)
extrema_labels[tops] = 1
# Calculate extrema prediction loss
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= actual_batch_size:
current_extrema_pred = current_extrema_pred[:actual_batch_size]
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
# Combined loss with all components
# Primary task: Q-value learning (RL objective)
# Secondary tasks: extrema detection and price prediction (supervised objectives)
loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss
# Log loss components occasionally
if random.random() < 0.01: # Log 1% of the time
logger.info(
f"Mixed precision losses: Q-loss={q_loss.item():.4f}, "
f"Extrema-loss={extrema_loss.item():.4f}, "
f"Price-loss={price_loss.item():.4f}"
)
except Exception as e:
# Fallback if price extraction fails
logger.warning(f"Failed to calculate price prediction loss: {str(e)}. Using only Q-value loss.")
# Just use Q-value loss
loss = q_loss
# Backward pass with scaled gradients
self.scaler.scale(loss).backward()
# Gradient clipping on scaled gradients
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
# Update with scaler
self.scaler.step(self.optimizer)
self.scaler.update()
# Update target network if needed
self.update_count += 1
if self.update_count % self.target_update == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
# Track and decay epsilon
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
return loss.item()
except Exception as e:
logger.error(f"Error in mixed precision training: {str(e)}")
logger.warning("Falling back to standard precision training")
# Fall back to standard training
return self._replay_standard(states, actions, rewards, next_states, dones)
def train_on_extrema(self, states, actions, rewards, next_states, dones):
    """
    Special training function specifically for extrema points.

    Every state and next state is passed through ``self._normalize_state``.
    The results are stacked along a new leading batch axis and converted to
    tensors on ``self.device``. They are then handed to
    ``self._replay_mixed_precision`` when ``self.use_mixed_precision`` is
    truthy, otherwise to ``self._replay_standard``.

    Args:
        states: Batch of states at extrema points
        actions: Batch of actions (converted to int64 indices)
        rewards: Batch of rewards
        next_states: Batch of next states
        dones: Batch of done flags (bools or numbers; cast to float32)

    Returns:
        float: Training loss from the selected replay routine, or 0.0 when the
        batch is empty.
    """
    if len(states) == 0:
        # Nothing to train on; stacking an empty list would raise ValueError.
        return 0.0

    # np.stack (not np.vstack) keeps exactly one entry per sample. For 2-D
    # per-sample states of shape (T, F), vstack would flatten the batch to
    # (batch * T, F) and misalign states with actions/rewards/dones.
    batch_states = np.stack([self._normalize_state(s) for s in np.asarray(states)])
    batch_next_states = np.stack([self._normalize_state(s) for s in np.asarray(next_states)])

    device = self.device
    states_tensor = torch.as_tensor(batch_states, dtype=torch.float32, device=device)
    actions_tensor = torch.as_tensor(np.asarray(actions), dtype=torch.long, device=device)
    rewards_tensor = torch.as_tensor(np.asarray(rewards, dtype=np.float32), device=device)
    next_states_tensor = torch.as_tensor(batch_next_states, dtype=torch.float32, device=device)
    # dones frequently arrive as bools (possibly already as a bool ndarray);
    # always hand the replay routines a float32 tensor.
    dones_tensor = torch.as_tensor(np.asarray(dones, dtype=np.float32), device=device)

    # Choose training method based on precision mode
    replay = self._replay_mixed_precision if self.use_mixed_precision else self._replay_standard
    return replay(states_tensor, actions_tensor, rewards_tensor, next_states_tensor, dones_tensor)
def _normalize_state(self, state: np.ndarray) -> np.ndarray:
    """Normalize the state data to prevent numerical issues.

    The input is first coerced to a float32 array:

    * objects exposing a callable ``timestamp()`` become epoch seconds;
    * other non-numeric entries become 0 (this includes NumPy booleans,
      which are not ``int``/``float``/``np.number`` instances);
    * non-finite values are replaced (NaN -> 0, +inf -> 1, -inf -> -1).

    1-D (or 0-D) states are then min-max scaled as a whole. They are left
    unscaled when constant.

    2-D (and higher-rank) states use axis 1 as the feature axis. That axis is
    split into ``len(self.timeframes)`` equal-width groups, or one group when
    the list is empty. Within each group:

    * columns 0-3 (assumed OHLC) are divided by the mean of column 3
      (assumed close price);
    * column 4 (assumed volume) is z-scored;
    * the remaining columns are min-max scaled per column, with constant
      columns becoming 0.

    Trailing columns that do not fill a whole group are only cleaned, not
    rescaled.

    Args:
        state: Raw state array (anything ``np.asarray`` accepts).

    Returns:
        np.ndarray: float32 array with the same shape as ``state``.
    """
    raw = np.asarray(state)

    if raw.dtype.kind in 'iuf':
        # Plain integer/float data: one vectorized cast, no per-element checks.
        normalized = raw.astype(np.float32)
    else:
        def _coerce(value):
            stamp = getattr(value, 'timestamp', None)
            if callable(stamp):
                # Convert timestamp to float (seconds since epoch)
                return float(stamp())
            if isinstance(value, (int, float, np.number)):
                return value
            # Set non-numeric data to 0
            return 0.0

        normalized = np.array(
            [_coerce(v) for v in raw.ravel()], dtype=np.float32
        ).reshape(raw.shape)

    # Sanitize after the cast. np.nan_to_num is a no-op on object arrays, and
    # float64 values beyond float32 range only become inf during the cast.
    normalized = np.nan_to_num(normalized, nan=0.0, posinf=1.0, neginf=-1.0)

    if normalized.size == 0:
        # np.min/np.max would raise on empty input.
        return normalized

    if normalized.ndim <= 1:
        # Simple min-max normalization for 1D state
        lo = np.min(normalized)
        hi = np.max(normalized)
        if hi > lo:
            normalized = (normalized - lo) / (hi - lo)
        return normalized

    # Guard against an empty timeframe list (previously ZeroDivisionError).
    num_groups = max(len(self.timeframes), 1)
    group_width = normalized.shape[1] // num_groups
    # Read untouched values from a copy while writing results in place.
    source = normalized.copy()

    for group in range(num_groups):
        start = group * group_width
        block = source[:, start:start + group_width]
        target = normalized[:, start:start + group_width]  # view into result
        width = block.shape[1]

        # Price columns (OHLC assumed at 0-3), made relative to the MEAN close
        # (column 3) over axis 0 so movements are relative rather than absolute.
        if width > 3:
            reference_price = np.mean(block[:, 3])
            if reference_price != 0:
                target[:, :4] = block[:, :4] / reference_price

        # Volume (assumed at column 4): z-score, or 0 when constant.
        if width > 4:
            volume = block[:, 4]
            vol_std = np.std(volume)
            if vol_std > 0:
                target[:, 4] = (volume - np.mean(volume)) / vol_std
            else:
                target[:, 4] = 0

        # Other features (technical indicators): per-column min-max scaling,
        # constant columns mapped to 0.
        if width > 5:
            rest = block[:, 5:]
            col_min = rest.min(axis=0)
            span = rest.max(axis=0) - col_min
            safe_span = np.where(span > 0, span, 1)
            target[:, 5:] = np.where(span > 0, (rest - col_min) / safe_span, 0)

    return normalized
def update_realtime_tick_features(self, tick_features):
    """Store the most recent tick-processor payload on the agent.

    A ``None`` payload is ignored and the previously stored value is kept.
    Otherwise the payload is assigned to ``self.realtime_tick_features``. A
    debug line is logged when its ``'confidence'`` entry exceeds 0.8. Any
    error, such as a payload without a ``get`` method, is reported through
    ``logger.error`` instead of propagating.
    """
    try:
        if tick_features is None:
            return
        self.realtime_tick_features = tick_features
        # Only high-confidence updates are worth a log line.
        if not tick_features.get('confidence', 0) > 0.8:
            return
        logger.debug(f"High-confidence tick features updated: confidence={tick_features['confidence']:.3f}")
    except Exception as e:
        logger.error(f"Error updating real-time tick features: {e}")
def _enhance_state_with_tick_features(self, state: np.ndarray) -> np.ndarray:
    """Blend the cached real-time tick features into ``state``.

    Takes the first 3 ``'neural_features'``, the first ``'volume_features'``
    entry and the first ``'microstructure_features'`` entry. A source with
    too few values contributes zeros instead. The 5 resulting values are
    scaled by ``self.tick_feature_weight``.

    * 1-D state: the 5 weighted values are appended, so the result is 5
      elements longer than the input.
    * 2-D state: the values are zero-padded or truncated to the row width and
      added to the LAST row only. The input array is not modified.

    Returns ``state`` unchanged when no tick features are cached or when any
    error occurs; errors are logged.
    """
    try:
        tick = self.realtime_tick_features
        if tick is None:
            return state

        def _leading(key, count):
            # First `count` values of tick[key], or zeros if there are fewer.
            values = tick.get(key, np.array([]))
            return values[:count] if len(values) >= count else np.zeros(count)

        weighted = np.concatenate([
            _leading('neural_features', 3),
            _leading('volume_features', 1),
            _leading('microstructure_features', 1),
        ]) * self.tick_feature_weight

        if state.ndim == 1:
            return np.concatenate([state, weighted])

        # Unpacking enforces a 2-D state; other ranks land in the except below.
        _, row_width = state.shape
        if len(weighted) < row_width:
            padded = np.zeros(row_width)
            padded[:len(weighted)] = weighted
            weighted = padded
        elif len(weighted) > row_width:
            weighted = weighted[:row_width]

        result = state.copy()
        result[-1, :] += weighted  # most recent row only
        return result
    except Exception as e:
        logger.error(f"Error enhancing state with tick features: {e}")
        return state
def update_learning_metrics(self, episode_reward, best_reward_threshold=0.01):
    """Update learning metrics and perform learning rate adjustments if needed.

    ``self.avg_reward`` is an exponential moving average (weight 0.05 on the
    new episode). A value of exactly 0 is treated as "not initialised yet" and
    replaced by ``episode_reward``.

    An episode counts as an improvement when it beats ``self.best_reward`` by
    the relative margin ``best_reward_threshold * |best_reward|``. An infinite
    ``best_reward`` acts as a sentinel: any finite reward beats ``-inf`` and
    nothing beats ``+inf``.

    After 10 consecutive non-improving episodes, every optimizer param
    group's learning rate is halved, as long as the halved value stays at or
    above 1e-6.

    Args:
        episode_reward: Total reward of the finished episode.
        best_reward_threshold: Relative margin required to count as improvement.

    Returns:
        bool: True if the episode set a new best reward, False otherwise.
    """
    # Update average reward with exponential moving average
    if self.avg_reward == 0:
        self.avg_reward = episode_reward
    else:
        self.avg_reward = 0.95 * self.avg_reward + 0.05 * episode_reward

    best = self.best_reward
    if np.isinf(best):
        # Keep the ±inf sentinel semantics of the old (1 + t) * best formula.
        improvement_bar = best
    else:
        # Use |best| so the bar always lies ABOVE best. The former
        # (1 + t) * best lowered the bar for negative rewards, so a worse
        # episode counted as an improvement and decreased best_reward.
        improvement_bar = best + best_reward_threshold * abs(best)

    if episode_reward > improvement_bar:
        self.best_reward = episode_reward
        self.no_improvement_count = 0
        return True  # Improved

    self.no_improvement_count += 1
    # If no improvement for a while, adjust learning rate
    if self.no_improvement_count >= 10:
        current_lr = self.optimizer.param_groups[0]['lr']
        new_lr = current_lr * 0.5
        if new_lr >= 1e-6:  # Don't reduce below minimum threshold
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = new_lr
            logger.info(f"Reducing learning rate from {current_lr} to {new_lr}")
            self.no_improvement_count = 0
    return False  # No improvement
def save(self, path: str):
    """Save model and agent state.

    ``path`` is a file-name prefix. The networks are saved through their own
    ``save`` methods as ``{path}_policy`` and ``{path}_target``. The agent's
    bookkeeping (epsilon, update count, losses, optimizer state, best and
    average reward) is written with ``torch.save`` to
    ``{path}_agent_state.pt``. The parent directory is created when ``path``
    contains one.
    """
    directory = os.path.dirname(path)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when path actually has one (e.g. "checkpoints/dqn", not plain "dqn").
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Save policy network
    self.policy_net.save(f"{path}_policy")
    # Save target network
    self.target_net.save(f"{path}_target")

    # Save agent state
    state = {
        'epsilon': self.epsilon,
        'update_count': self.update_count,
        'losses': self.losses,
        'optimizer_state': self.optimizer.state_dict(),
        'best_reward': self.best_reward,
        'avg_reward': self.avg_reward
    }
    torch.save(state, f"{path}_agent_state.pt")
    logger.info(f"Agent state saved to {path}_agent_state.pt")
def load(self, path: str):
    """Restore the networks and agent bookkeeping saved under the prefix ``path``.

    Loads ``{path}_policy`` into the policy network first, then ``{path}_target``
    into the target network, using the networks' own ``load`` methods. It then
    reads ``{path}_agent_state.pt`` with ``torch.load`` mapped to
    ``self.device``. From that file, epsilon, update count, losses and
    optimizer state are always restored. ``best_reward`` and ``avg_reward``
    are restored only when present, since older checkpoints may lack them.

    A missing agent-state file is logged as a warning and the current
    in-memory values are kept.
    """
    for network, suffix in ((self.policy_net, "_policy"), (self.target_net, "_target")):
        network.load(f"{path}{suffix}")

    try:
        saved = torch.load(f"{path}_agent_state.pt", map_location=self.device)
        self.epsilon = saved['epsilon']
        self.update_count = saved['update_count']
        self.losses = saved['losses']
        self.optimizer.load_state_dict(saved['optimizer_state'])
        # Optional metrics: only overwrite when the checkpoint carries them.
        for key in ('best_reward', 'avg_reward'):
            if key in saved:
                setattr(self, key, saved[key])
        logger.info(f"Agent state loaded from {path}_agent_state.pt")
    except FileNotFoundError:
        logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")