# gogo2/NN/models/dqn_agent.py
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List, Dict, Any
import os
import sys
import logging
import torch.nn.functional as F
import time
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
# Import checkpoint management
from NN.training.model_manager import save_checkpoint, load_best_checkpoint
from NN.training.model_manager import create_model_manager
# Configure logger
logger = logging.getLogger(__name__)
class DQNNetwork(nn.Module):
"""
Configurable Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
Handles 7850 input features from multi-timeframe, multi-asset data
Architecture is configurable via config.yaml
"""
def __init__(self, input_dim: int, n_actions: int, config: dict = None):
super(DQNNetwork, self).__init__()
# Handle different input dimension formats
if isinstance(input_dim, (tuple, list)):
if len(input_dim) == 1:
self.input_size = input_dim[0]
else:
self.input_size = np.prod(input_dim) # Flatten multi-dimensional input
else:
self.input_size = input_dim
self.n_actions = n_actions
# Get network architecture from config or use defaults
if config and 'network_architecture' in config:
arch_config = config['network_architecture']
feature_layers = arch_config.get('feature_layers', [4096, 3072, 2048, 1536, 1024])
regime_head = arch_config.get('regime_head', [512, 256])
price_direction_head = arch_config.get('price_direction_head', [512, 256])
volatility_head = arch_config.get('volatility_head', [512, 128])
value_head = arch_config.get('value_head', [512, 256])
advantage_head = arch_config.get('advantage_head', [512, 256])
dropout_rate = arch_config.get('dropout_rate', 0.1)
use_layer_norm = arch_config.get('use_layer_norm', True)
else:
# Default reduced architecture (half the original size)
feature_layers = [4096, 3072, 2048, 1536, 1024]
regime_head = [512, 256]
price_direction_head = [512, 256]
volatility_head = [512, 128]
value_head = [512, 256]
advantage_head = [512, 256]
dropout_rate = 0.1
use_layer_norm = True
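        # Illustrative config shape for the keys read above -- a sketch only; the
        # real values are expected under network_architecture in config.yaml:
        #   config = {"network_architecture": {
        #       "feature_layers": [4096, 3072, 2048, 1536, 1024],
        #       "regime_head": [512, 256], "price_direction_head": [512, 256],
        #       "volatility_head": [512, 128], "value_head": [512, 256],
        #       "advantage_head": [512, 256],
        #       "dropout_rate": 0.1, "use_layer_norm": True}}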
# Build configurable feature extractor
feature_layers_list = []
prev_size = self.input_size
for layer_size in feature_layers:
feature_layers_list.append(nn.Linear(prev_size, layer_size))
if use_layer_norm:
feature_layers_list.append(nn.LayerNorm(layer_size))
feature_layers_list.append(nn.ReLU(inplace=True))
feature_layers_list.append(nn.Dropout(dropout_rate))
prev_size = layer_size
self.feature_extractor = nn.Sequential(*feature_layers_list)
self.feature_size = feature_layers[-1] # Final feature size
# Build configurable network heads
def build_head_layers(input_size, layer_sizes, output_size):
layers = []
prev_size = input_size
for layer_size in layer_sizes:
layers.append(nn.Linear(prev_size, layer_size))
if use_layer_norm:
layers.append(nn.LayerNorm(layer_size))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Dropout(dropout_rate))
prev_size = layer_size
layers.append(nn.Linear(prev_size, output_size))
return nn.Sequential(*layers)
# Market regime detection head
self.regime_head = build_head_layers(
self.feature_size, regime_head, 4 # trending, ranging, volatile, mixed
)
# Price direction prediction head - outputs direction and confidence
self.price_direction_head = build_head_layers(
self.feature_size, price_direction_head, 2 # [direction, confidence]
)
# Direction activation (tanh for -1 to 1)
self.direction_activation = nn.Tanh()
# Confidence activation (sigmoid for 0 to 1)
self.confidence_activation = nn.Sigmoid()
# Volatility prediction head
self.volatility_head = build_head_layers(
self.feature_size, volatility_head, 4 # predicted volatility for 4 timeframes
)
# Main Q-value head (dueling architecture)
self.value_head = build_head_layers(
self.feature_size, value_head, 1 # Single value for dueling architecture
)
# Advantage head (dueling architecture)
self.advantage_head = build_head_layers(
self.feature_size, advantage_head, n_actions # Action advantages
)
# Initialize weights
self._initialize_weights()
# Log parameter count
total_params = sum(p.numel() for p in self.parameters())
logger.info(f"DQN Network initialized with {total_params:,} parameters (target: 50M)")
def _initialize_weights(self):
"""Initialize network weights using Xavier initialization"""
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.LayerNorm):
nn.init.constant_(module.bias, 0)
nn.init.constant_(module.weight, 1.0)
def forward(self, x):
"""Forward pass through the network"""
# Ensure input is properly shaped
if x.dim() > 2:
x = x.view(x.size(0), -1) # Flatten if needed
elif x.dim() == 1:
x = x.unsqueeze(0) # Add batch dimension if needed
# Feature extraction
features = self.feature_extractor(x)
# Multiple prediction heads
regime_pred = self.regime_head(features)
price_direction_raw = self.price_direction_head(features)
# Apply separate activations to direction and confidence
direction = self.direction_activation(price_direction_raw[:, 0:1]) # -1 to 1
confidence = self.confidence_activation(price_direction_raw[:, 1:2]) # 0 to 1
price_direction_pred = torch.cat([direction, confidence], dim=1) # [batch, 2]
volatility_pred = self.volatility_head(features)
# Dueling Q-network
value = self.value_head(features)
advantage = self.advantage_head(features)
# Combine value and advantage for Q-values
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
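        # Output shapes for a batch of size B: q_values [B, n_actions], regime_pred [B, 4],
        # price_direction_pred [B, 2] (tanh direction, sigmoid confidence),
        # volatility_pred [B, 4], features [B, feature_size]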
return q_values, regime_pred, price_direction_pred, volatility_pred, features
def act(self, state, explore=True):
"""
        Greedy action selection from the network's Q-values.
        Note: epsilon-greedy exploration is applied by DQNAgent.act(); the
        `explore` flag is accepted for interface compatibility and is not used here.
        Args:
            state: Current state (numpy array or tensor)
            explore: Unused; kept for interface compatibility
        Returns:
            action_idx: Selected action index
            confidence: Max softmax probability of the selected action
            action_probs: Action probabilities (softmax over Q-values)
"""
# Convert state to tensor if needed
if isinstance(state, np.ndarray):
state = torch.FloatTensor(state)
# Move to device
device = next(self.parameters()).device
state = state.to(device)
# Ensure proper shape
if state.dim() == 1:
state = state.unsqueeze(0)
with torch.no_grad():
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
# Price direction predictions are processed in the agent's act method
# This is just the network forward pass
# Get action probabilities using softmax
action_probs = F.softmax(q_values, dim=1)
# Select action (greedy for inference)
action_idx = torch.argmax(q_values, dim=1).item()
# Calculate confidence as max probability
confidence = float(action_probs[0, action_idx].item())
# Convert probabilities to list
probs_list = action_probs.squeeze(0).cpu().numpy().tolist()
return action_idx, confidence, probs_list
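# Minimal usage sketch for DQNNetwork (illustrative only; the 7850-dim input and
# 3 actions follow the docstring and agent defaults above, not a guaranteed contract):
#   net = DQNNetwork(input_dim=7850, n_actions=3)
#   state = np.random.randn(7850).astype(np.float32)
#   action_idx, confidence, probs = net.act(state)            # greedy inference
#   q, regime, direction, vol, feats = net(torch.FloatTensor(state))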
class DQNAgent:
"""
    Deep Q-Network agent for trading.
    Uses the DQNNetwork defined above as policy and target networks, with GPU support,
    dueling heads, and double-DQN training for improved performance.
"""
def __init__(self,
state_shape: Tuple[int, ...],
n_actions: int = 3, # BUY=0, SELL=1, HOLD=2
learning_rate: float = 0.001,
epsilon: float = 1.0,
epsilon_min: float = 0.01,
epsilon_decay: float = 0.995,
buffer_size: int = 10000,
batch_size: int = 32,
target_update: int = 100,
priority_memory: bool = True,
device=None,
model_name: str = "dqn_agent",
enable_checkpoints: bool = True,
config: dict = None):
# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
self.training_integration = None # Removed dependency on utils.training_integration
self.episode_count = 0
self.best_reward = float('-inf')
self.reward_history = deque(maxlen=100)
self.checkpoint_frequency = 100 # Save checkpoint every 100 episodes
# Extract state dimensions
if isinstance(state_shape, tuple) and len(state_shape) > 1:
# Multi-dimensional state (like image or sequence)
self.state_dim = state_shape
else:
# 1D state
if isinstance(state_shape, tuple):
if len(state_shape) == 0:
self.state_dim = 1 # Safe default for empty tuple
else:
self.state_dim = state_shape[0]
else:
self.state_dim = state_shape
# Store parameters
self.n_actions = n_actions
self.learning_rate = learning_rate
self.epsilon = epsilon
self.epsilon_min = epsilon_min
self.epsilon_decay = epsilon_decay
self.buffer_size = buffer_size
self.batch_size = batch_size
self.target_update = target_update
# Set device for computation (default to GPU if available)
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = device
logger.info(f"DQN Agent using device: {self.device}")
# Initialize models with RL-specific network architecture
self.policy_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device)
self.target_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device)
# Ensure models are on the correct device
self.policy_net = self.policy_net.to(self.device)
self.target_net = self.target_net.to(self.device)
# Initialize the target network with the same weights as the policy network
self.target_net.load_state_dict(self.policy_net.state_dict())
# Set models to eval mode (important for batch norm, dropout)
self.target_net.eval()
# Optimization components
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
self.criterion = nn.MSELoss()
# Experience replay memory
self.memory = []
self.positive_memory = [] # Special memory for storing good experiences
self.update_count = 0
# Extrema detection tracking
self.last_extrema_pred = {
'class': 2, # Default to "neither" (not extrema)
'confidence': 0.0,
'raw': None
}
self.extrema_memory = []
# DQN hyperparameters
self.gamma = 0.99 # Discount factor
# Initialize avg_reward for dashboard compatibility
self.avg_reward = 0.0 # Average reward tracking for dashboard
# Market regime adaptation weights
self.market_regime_weights = {
'trending': 1.0,
'sideways': 0.8,
'volatile': 1.2,
'bullish': 1.1,
'bearish': 1.1
}
# Load best checkpoint if available
if self.enable_checkpoints:
self.load_best_checkpoint()
logger.info(f"DQN Agent initialized with checkpoint management: {enable_checkpoints}")
if enable_checkpoints:
logger.info(f"Model name: {model_name}, Checkpoint frequency: {self.checkpoint_frequency}")
        # Recent action / price / reward tracking
self.recent_actions = deque(maxlen=10)
self.recent_prices = deque(maxlen=20)
self.recent_rewards = deque(maxlen=100)
# Price direction tracking - stores direction and confidence
self.last_price_direction = {
'direction': 0.0, # Single value between -1 and 1
'confidence': 0.0 # Single value between 0 and 1
}
# Store separate memory for price direction examples
self.price_movement_memory = [] # For storing examples of clear price movements
# Performance tracking
self.losses = []
self.no_improvement_count = 0
# Confidence tracking
self.confidence_history = []
self.avg_confidence = 0.0
self.max_confidence = 0.0
self.min_confidence = 1.0
# Enhanced features from EnhancedDQNAgent
# Market adaptation capabilities
self.market_regime_weights = {
'trending': 1.2, # Higher confidence in trending markets
'ranging': 0.8, # Lower confidence in ranging markets
'volatile': 0.6 # Much lower confidence in volatile markets
}
# Dueling network support (requires enhanced network architecture)
self.use_dueling = True
# Prioritized experience replay parameters
self.use_prioritized_replay = priority_memory
self.alpha = 0.6 # Priority exponent
self.beta = 0.4 # Importance sampling exponent
self.beta_increment = 0.001
# Double DQN support
self.use_double_dqn = True
# Enhanced training features from EnhancedDQNAgent
self.target_update_freq = target_update # More descriptive name
self.training_steps = 0
self.gradient_clip_norm = 1.0 # Gradient clipping
# Enhanced statistics tracking
self.epsilon_history = []
self.td_errors = [] # Track TD errors for analysis
# Trade action fee and confidence thresholds
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
# Violent move detection
self.price_history = []
self.volatility_window = 20 # Window size for volatility calculation
self.volatility_threshold = 0.0015 # Threshold for considering a move "violent"
self.post_violent_move = False # Flag for recent violent move
self.violent_move_cooldown = 0 # Cooldown after violent move
# Feature integration
self.last_hidden_features = None # Store last extracted features
self.feature_history = [] # Store history of features for analysis
# Real-time tick features integration
self.realtime_tick_features = None # Latest tick features from tick processor
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
# Check if mixed precision training should be used
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
self.use_mixed_precision = True
self.scaler = torch.cuda.amp.GradScaler()
logger.info("Mixed precision training enabled")
else:
self.use_mixed_precision = False
logger.info("Mixed precision training disabled")
# Track if we're in training mode
self.training = True
# For compatibility with old code
self.state_size = np.prod(state_shape)
self.action_size = n_actions
self.memory_size = buffer_size
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] # Default timeframes
logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")
# Log model parameters
total_params = sum(p.numel() for p in self.policy_net.parameters())
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
# Position management for 2-action system
self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long)
self.position_entry_price = 0.0
self.position_entry_time = None
# Different thresholds for entry vs exit decisions - AGGRESSIVE for more training data
self.entry_confidence_threshold = 0.35 # Lower threshold for new positions (was 0.7)
self.exit_confidence_threshold = 0.15 # Very low threshold for closing positions (was 0.3)
self.uncertainty_threshold = 0.1 # When to stay neutral
def load_best_checkpoint(self):
"""Load the best checkpoint for this DQN agent"""
try:
if not self.enable_checkpoints:
return
result = load_best_checkpoint(self.model_name)
if result:
file_path, metadata = result
checkpoint = torch.load(file_path, map_location=self.device, weights_only=False)
# Load model states
if 'policy_net_state_dict' in checkpoint:
self.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
if 'target_net_state_dict' in checkpoint:
self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
if 'optimizer_state_dict' in checkpoint:
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Load training state
if 'episode_count' in checkpoint:
self.episode_count = checkpoint['episode_count']
if 'epsilon' in checkpoint:
self.epsilon = checkpoint['epsilon']
if 'best_reward' in checkpoint:
self.best_reward = checkpoint['best_reward']
logger.info(f"Loaded DQN checkpoint: {metadata.checkpoint_id}")
logger.info(f"Episode: {self.episode_count}, Best reward: {self.best_reward:.4f}")
except Exception as e:
logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
def save_checkpoint(self, episode_reward: float, force_save: bool = False):
"""Save checkpoint if performance improved or forced"""
try:
if not self.enable_checkpoints:
return False
self.episode_count += 1
self.reward_history.append(episode_reward)
# Calculate average reward over recent episodes
avg_reward = sum(self.reward_history) / len(self.reward_history)
# Update best reward
if episode_reward > self.best_reward:
self.best_reward = episode_reward
# Save checkpoint every N episodes or if forced
should_save = (
force_save or
self.episode_count % self.checkpoint_frequency == 0 or
episode_reward > self.best_reward * 0.95 # Within 5% of best
)
if should_save and self.training_integration:
return self.training_integration.save_rl_checkpoint(
rl_agent=self,
model_name=self.model_name,
episode=self.episode_count,
avg_reward=avg_reward,
best_reward=self.best_reward,
epsilon=self.epsilon,
total_pnl=0.0 # Default to 0, can be set by calling code
)
return False
except Exception as e:
logger.error(f"Error saving DQN checkpoint: {e}")
return False
def move_models_to_device(self, device=None):
"""Move models to the specified device (GPU/CPU)"""
if device is not None:
self.device = device
try:
self.policy_net = self.policy_net.to(self.device)
self.target_net = self.target_net.to(self.device)
logger.info(f"Moved models to {self.device}")
return True
except Exception as e:
logger.error(f"Failed to move models to {self.device}: {str(e)}")
return False
def to(self, device):
"""PyTorch-style device movement method"""
self.device = device
self.policy_net = self.policy_net.to(device)
self.target_net = self.target_net.to(device)
return self
def remember(self, state: np.ndarray, action: int, reward: float,
next_state: np.ndarray, done: bool, is_extrema: bool = False):
"""
Store experience in memory with prioritization
Args:
state: Current state
action: Action taken
reward: Reward received
next_state: Next state
done: Whether episode is done
is_extrema: Whether this is a local extrema sample (for specialized learning)
"""
# Validate states before storing experience
if state is None or next_state is None:
logger.debug("Skipping experience storage: None state provided")
return
if isinstance(state, dict) and not state:
logger.debug("Skipping experience storage: empty state dictionary")
return
if isinstance(next_state, dict) and not next_state:
logger.debug("Skipping experience storage: empty next_state dictionary")
return
# Check if states are all zeros (invalid)
if hasattr(state, '__iter__') and all(f == 0 for f in np.array(state).flatten()):
logger.debug("Skipping experience storage: state is all zeros")
return
experience = (state, action, reward, next_state, done)
# Always add to main memory
self.memory.append(experience)
# Try to extract price change to analyze the experience
try:
# Extract price feature from sequence data (if available)
if len(state.shape) > 1: # 2D state [timeframes, features]
current_price = state[-1, -1] # Last timeframe, last feature
next_price = next_state[-1, -1]
else: # 1D state
current_price = state[-1] # Last feature
next_price = next_state[-1]
# Calculate price change - avoid division by zero
if np.isscalar(current_price) and current_price != 0:
price_change = (next_price - current_price) / current_price
elif isinstance(current_price, np.ndarray):
# Handle array case - protect against division by zero
with np.errstate(divide='ignore', invalid='ignore'):
price_change = (next_price - current_price) / current_price
# Replace infinities and NaNs with zeros
if isinstance(price_change, np.ndarray):
price_change = np.nan_to_num(price_change, nan=0.0, posinf=0.0, neginf=0.0)
else:
price_change = 0.0 if np.isnan(price_change) or np.isinf(price_change) else price_change
else:
price_change = 0.0
# Check if this is a significant price movement
if abs(price_change) > 0.002: # Significant price change
# Store in price movement memory
self.price_movement_memory.append(experience)
# Log significant price movements
direction = "UP" if price_change > 0 else "DOWN"
logger.info(f"Stored significant {direction} price movement: {price_change:.4f}")
# For clear price movements, also duplicate in main memory to learn more
if abs(price_change) > 0.005: # Very significant movement
for _ in range(2): # Add 2 extra copies
self.memory.append(experience)
except Exception as e:
# Skip price movement analysis if it fails
pass
# Check if this is an extrema point based on our extrema detection head
if hasattr(self, 'last_extrema_pred') and self.last_extrema_pred['class'] != 2:
# Class 0 = bottom, 1 = top, 2 = neither
# Only consider high confidence predictions
if self.last_extrema_pred['confidence'] > 0.7:
self.extrema_memory.append(experience)
# Log this special experience
extrema_type = "BOTTOM" if self.last_extrema_pred['class'] == 0 else "TOP"
logger.info(f"Stored {extrema_type} experience with reward {reward:.4f}")
# For tops and bottoms, also duplicate the experience in memory to learn more from it
for _ in range(2): # Add 2 extra copies
self.memory.append(experience)
# Explicitly marked extrema points also go to extrema memory
elif is_extrema:
self.extrema_memory.append(experience)
# Store positive experiences separately for prioritized replay
if reward > 0:
self.positive_memory.append(experience)
# For very good rewards, duplicate to learn more from them
if reward > 0.1:
for _ in range(min(int(reward * 10), 5)): # Cap at 5 extra copies for very high rewards
self.positive_memory.append(experience)
# Keep memory size under control
if len(self.memory) > self.buffer_size:
# Keep more recent experiences
self.memory = self.memory[-self.buffer_size:]
# Keep specialized memories under control too
if len(self.positive_memory) > self.buffer_size // 4:
self.positive_memory = self.positive_memory[-(self.buffer_size // 4):]
if len(self.extrema_memory) > self.buffer_size // 4:
self.extrema_memory = self.extrema_memory[-(self.buffer_size // 4):]
if len(self.price_movement_memory) > self.buffer_size // 4:
self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):]
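    # Sketch of how an experience is stored (values illustrative): calling
    #   agent.remember(state=prev_obs, action=0, reward=0.004, next_state=obs, done=False)
    # also copies positive-reward samples into positive_memory, and price moves larger
    # than 0.2% of the tracked price feature into price_movement_memory.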
def act(self, state: np.ndarray, explore=True, current_price=None, market_context=None) -> int:
"""
        Choose an action for the current state.
        Args:
            state: Current market state
            explore: Whether to apply epsilon-greedy exploration
            current_price: Current market price (recorded for position tracking)
            market_context: Additional market context (currently unused here)
        Returns:
            int: Action index (0=BUY, 1=SELL; 2=HOLD when n_actions=3)
"""
try:
# Validate state first - return early if empty/invalid/None
if state is None:
logger.warning("None state provided to act(), returning SELL action")
return 1 # SELL action (safe default)
if isinstance(state, dict) and not state:
logger.warning("Empty state dictionary provided to act(), returning SELL action")
return 1 # SELL action (safe default)
            # Use the DQNNetwork's act method for consistent behavior
            action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)
            # Process price direction predictions from the network's forward pass
            # (forward expects a batched tensor, so convert the state first)
            state_tensor = torch.FloatTensor(state) if isinstance(state, np.ndarray) else state
            state_tensor = state_tensor.to(self.device)
            if state_tensor.dim() == 1:
                state_tensor = state_tensor.unsqueeze(0)
            with torch.no_grad():
                q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state_tensor)
            if price_direction_pred is not None:
                self.process_price_direction_predictions(price_direction_pred)
# Apply epsilon-greedy exploration if requested
if explore and np.random.random() <= self.epsilon:
action_idx = np.random.choice(self.n_actions)
# Update tracking
if current_price:
self.recent_prices.append(current_price)
self.recent_actions.append(action_idx)
return action_idx
except Exception as e:
logger.error(f"Error in act method: {e}")
            # Return SELL (1) as a safe default
return 1
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
"""Choose action with confidence score adapted to market regime"""
try:
# Validate state first - return early if empty/invalid/None
if state is None:
logger.warning("None state provided to act_with_confidence(), returning safe defaults")
return 1, 0.1, [0.0, 0.9, 0.1] # SELL action with low confidence
if isinstance(state, dict) and not state:
logger.warning("Empty state dictionary provided to act_with_confidence(), returning safe defaults")
                return 1, 0.0, [0.0, 1.0, 0.0]  # SELL action with zero confidence
# Convert state to tensor if needed
if isinstance(state, np.ndarray):
state_tensor = torch.FloatTensor(state)
device = next(self.policy_net.parameters()).device
state_tensor = state_tensor.to(device)
# Ensure proper shape
if state_tensor.dim() == 1:
state_tensor = state_tensor.unsqueeze(0)
else:
state_tensor = state
# Get network outputs
with torch.no_grad():
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state_tensor)
# Process price direction predictions
if price_direction_pred is not None:
self.process_price_direction_predictions(price_direction_pred)
# Get action probabilities using softmax
action_probs = F.softmax(q_values, dim=1)
# Select action (greedy for inference)
action_idx = torch.argmax(q_values, dim=1).item()
# Calculate confidence as max probability
base_confidence = float(action_probs[0, action_idx].item())
# Adapt confidence based on market regime
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
adapted_confidence = min(base_confidence * regime_weight, 1.0)
# Convert probabilities to list
probs_list = action_probs.squeeze(0).cpu().numpy().tolist()
# Return action, confidence, and probabilities (for orchestrator compatibility)
return int(action_idx), float(adapted_confidence), probs_list
except Exception as e:
logger.error(f"Error in act_with_confidence: {e}")
# Return default action with low confidence
            return 1, 0.1, [0.45, 0.55]  # Default to SELL action
def process_price_direction_predictions(self, price_direction_pred: torch.Tensor) -> Dict[str, float]:
"""
Process price direction predictions and convert to standardized format
Args:
price_direction_pred: Tensor of shape (batch_size, 2) containing [direction, confidence]
Returns:
Dict with direction (-1 to 1) and confidence (0 to 1)
"""
try:
if price_direction_pred is None or price_direction_pred.numel() == 0:
return self.last_price_direction
# Extract direction and confidence values
direction_value = float(price_direction_pred[0, 0].item()) # -1 to 1
confidence_value = float(price_direction_pred[0, 1].item()) # 0 to 1
# Update last price direction
self.last_price_direction = {
'direction': direction_value,
'confidence': confidence_value
}
return self.last_price_direction
except Exception as e:
logger.error(f"Error processing price direction predictions: {e}")
return self.last_price_direction
def get_price_direction_vector(self) -> Dict[str, float]:
"""
Get the current price direction and confidence
Returns:
Dict with direction (-1 to 1) and confidence (0 to 1)
"""
return self.last_price_direction
def get_price_direction_summary(self) -> Dict[str, Any]:
"""
Get a summary of price direction prediction
Returns:
Dict containing direction and confidence information
"""
try:
direction_value = self.last_price_direction['direction']
confidence_value = self.last_price_direction['confidence']
# Convert to discrete direction
if direction_value > 0.1:
direction_label = "UP"
discrete_direction = 1
elif direction_value < -0.1:
direction_label = "DOWN"
discrete_direction = -1
else:
direction_label = "SIDEWAYS"
discrete_direction = 0
return {
'direction_value': float(direction_value),
'confidence_value': float(confidence_value),
'direction_label': direction_label,
'discrete_direction': discrete_direction,
'strength': abs(float(direction_value)),
'weighted_strength': abs(float(direction_value)) * float(confidence_value)
}
except Exception as e:
logger.error(f"Error calculating price direction summary: {e}")
return {
'direction_value': 0.0,
'confidence_value': 0.0,
'direction_label': "SIDEWAYS",
'discrete_direction': 0,
'strength': 0.0,
'weighted_strength': 0.0
}
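    # Example summary (illustrative numbers): direction_value 0.45 with confidence 0.8
    # maps to direction_label "UP", discrete_direction 1, strength 0.45 and
    # weighted_strength 0.45 * 0.8 = 0.36.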
def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
"""
Determine action based on current position and confidence thresholds
This implements the intelligent position management where:
- When neutral: Need high confidence to enter position
- When in position: Need lower confidence to exit
- Different thresholds for entry vs exit
"""
# Apply epsilon-greedy exploration
if explore and np.random.random() <= self.epsilon:
return np.random.choice([0, 1])
# Get the dominant signal - FIXED ACTION MAPPING: 0=BUY, 1=SELL
dominant_action = 0 if buy_conf > sell_conf else 1
dominant_confidence = max(buy_conf, sell_conf)
# Decision logic based on current position
if self.current_position == 0: # No position - need high confidence to enter
if dominant_confidence >= self.entry_confidence_threshold:
# Strong enough signal to enter position
if dominant_action == 0: # BUY signal (action 0)
self.current_position = 1.0
self.position_entry_price = current_price
self.position_entry_time = time.time()
logger.info(f"ENTERING LONG position at {current_price:.4f} with confidence {dominant_confidence:.4f}")
return 0 # Return BUY action (0)
else: # SELL signal (action 1)
self.current_position = -1.0
self.position_entry_price = current_price
self.position_entry_time = time.time()
logger.info(f"ENTERING SHORT position at {current_price:.4f} with confidence {dominant_confidence:.4f}")
return 1 # Return SELL action (1)
else:
# Not confident enough to enter position
return None
        elif self.current_position > 0:  # Long position
            # Check the stronger (entry-threshold) SELL signal first so that
            # position flips are reachable; weaker signals just close the position
            if dominant_action == 1 and dominant_confidence >= self.entry_confidence_threshold:
                # Very strong SELL signal - close long and enter short
                pnl = (current_price - self.position_entry_price) / self.position_entry_price if current_price and self.position_entry_price else 0
                logger.info(f"FLIPPING from LONG to SHORT at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
                self.current_position = -1.0
                self.position_entry_price = current_price
                self.position_entry_time = time.time()
                return 1  # Return SELL action (1)
            elif dominant_action == 1 and dominant_confidence >= self.exit_confidence_threshold:
                # SELL signal (action 1) with enough confidence to close the long position
                pnl = (current_price - self.position_entry_price) / self.position_entry_price if current_price and self.position_entry_price else 0
                logger.info(f"CLOSING LONG position at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
                self.current_position = 0.0
                self.position_entry_price = 0.0
                self.position_entry_time = None
                return 1  # Return SELL action (1)
            else:
                # Hold the long position
                return None
        elif self.current_position < 0:  # Short position
            # Same ordering on the short side: flip check before plain exit
            if dominant_action == 0 and dominant_confidence >= self.entry_confidence_threshold:
                # Very strong BUY signal - close short and enter long
                pnl = (self.position_entry_price - current_price) / self.position_entry_price if current_price and self.position_entry_price else 0
                logger.info(f"FLIPPING from SHORT to LONG at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
                self.current_position = 1.0
                self.position_entry_price = current_price
                self.position_entry_time = time.time()
                return 0  # Return BUY action (0)
            elif dominant_action == 0 and dominant_confidence >= self.exit_confidence_threshold:
                # BUY signal (action 0) with enough confidence to close the short position
                pnl = (self.position_entry_price - current_price) / self.position_entry_price if current_price and self.position_entry_price else 0
                logger.info(f"CLOSING SHORT position at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
                self.current_position = 0.0
                self.position_entry_price = 0.0
                self.position_entry_time = None
                return 0  # Return BUY action (0)
            else:
                # Hold the short position
                return None
return None
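    # Worked example with the default thresholds (entry 0.35, exit 0.15), confidences
    # illustrative: flat + buy_conf 0.40 -> enter LONG (returns 0); LONG + sell_conf 0.20
    # -> close LONG (returns 1); LONG + sell_conf >= 0.35 -> flip to SHORT (returns 1);
    # anything below the relevant threshold -> None (hold / no trade).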
def _safe_cnn_forward(self, network, states):
"""Safely call CNN forward method ensuring we always get 5 return values"""
try:
result = network(states)
if isinstance(result, tuple) and len(result) == 5:
return result
elif isinstance(result, tuple) and len(result) == 1:
# Handle case where only q_values are returned (like in empty tensor case)
q_values = result[0]
batch_size = q_values.size(0)
device = q_values.device
default_extrema = torch.zeros(batch_size, 3, device=device)
default_price = torch.zeros(batch_size, 1, device=device)
default_features = torch.zeros(batch_size, 1024, device=device)
default_advanced = torch.zeros(batch_size, 1, device=device)
return q_values, default_extrema, default_price, default_features, default_advanced
else:
# Fallback: create all default tensors
batch_size = states.size(0)
device = states.device
default_q_values = torch.zeros(batch_size, self.n_actions, device=device)
default_extrema = torch.zeros(batch_size, 3, device=device)
default_price = torch.zeros(batch_size, 1, device=device)
default_features = torch.zeros(batch_size, 1024, device=device)
default_advanced = torch.zeros(batch_size, 1, device=device)
return default_q_values, default_extrema, default_price, default_features, default_advanced
except Exception as e:
logger.error(f"Error in CNN forward pass: {e}")
# Fallback: create all default tensors
batch_size = states.size(0)
device = states.device
default_q_values = torch.zeros(batch_size, self.n_actions, device=device)
default_extrema = torch.zeros(batch_size, 3, device=device)
default_price = torch.zeros(batch_size, 1, device=device)
default_features = torch.zeros(batch_size, 1024, device=device)
default_advanced = torch.zeros(batch_size, 1, device=device)
return default_q_values, default_extrema, default_price, default_features, default_advanced
def replay(self, experiences=None):
"""Train the model using experiences from memory"""
# Don't train if not in training mode
if not self.training:
return 0.0
# If no experiences provided, sample from memory
if experiences is None:
# Skip if memory is too small (allow early training for GPU warmup)
min_required = min(getattr(self, 'batch_size', 32), 16)
if len(self.memory) < min_required:
return 0.0
# Sample random mini-batch from memory
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
experiences = [self.memory[i] for i in indices]
# Validate experiences before processing
if not experiences or len(experiences) == 0:
logger.warning("No experiences provided for training")
return 0.0
# Sanitize and validate experiences
valid_experiences = []
for i, exp in enumerate(experiences):
try:
if len(exp) != 5:
logger.debug(f"Invalid experience format at index {i}: expected 5 elements, got {len(exp)}")
continue
state, action, reward, next_state, done = exp
# Validate state
state = self._validate_and_fix_state(state)
next_state = self._validate_and_fix_state(next_state)
if state is None or next_state is None:
continue
# Validate action
if isinstance(action, dict):
action = action.get('action', action.get('value', 0))
action = int(action) if action is not None else 0
action = max(0, min(action, self.n_actions - 1)) # Clamp to valid range
# Validate reward
if isinstance(reward, dict):
reward = reward.get('reward', reward.get('value', 0.0))
reward = float(reward) if reward is not None else 0.0
# Validate done flag
done = bool(done) if done is not None else False
valid_experiences.append((state, action, reward, next_state, done))
except Exception as e:
logger.debug(f"Error processing experience {i}: {e}")
continue
if len(valid_experiences) == 0:
logger.warning("No valid experiences after sanitization")
return 0.0
# Use validated experiences for training
experiences = valid_experiences
# Extract components
states, actions, rewards, next_states, dones = zip(*experiences)
# Convert to tensors with proper validation
try:
# Ensure all data is on CPU first, then move to device
states_array = np.array(states, dtype=np.float32)
actions_array = np.array(actions, dtype=np.int64)
rewards_array = np.array(rewards, dtype=np.float32)
next_states_array = np.array(next_states, dtype=np.float32)
dones_array = np.array(dones, dtype=np.float32)
# Convert to tensors and move to device
states = torch.from_numpy(states_array).to(self.device)
actions = torch.from_numpy(actions_array).to(self.device)
rewards = torch.from_numpy(rewards_array).to(self.device)
next_states = torch.from_numpy(next_states_array).to(self.device)
dones = torch.from_numpy(dones_array).to(self.device)
# Final validation of tensor shapes
if states.shape[0] == 0 or actions.shape[0] == 0:
logger.warning("Empty tensors after conversion")
return 0.0
# Ensure all tensors have the same batch size
batch_size = states.shape[0]
if not all(tensor.shape[0] == batch_size for tensor in [actions, rewards, next_states, dones]):
logger.warning("Inconsistent batch sizes across tensors")
return 0.0
except Exception as e:
logger.error(f"Error converting experiences to tensors: {e}")
return 0.0
# Always use standard training to fix gradient issues
loss = self._replay_standard(states, actions, rewards, next_states, dones)
# Update epsilon
if self.epsilon > self.epsilon_min:
self.epsilon *= self.epsilon_decay
# Update statistics
self.losses.append(loss)
if len(self.losses) > 1000:
self.losses = self.losses[-500:] # Keep only recent losses
return loss
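    # Typical training-loop usage (sketch; the surrounding environment objects are
    # hypothetical, only the agent calls below exist in this class):
    #   agent.remember(state, action, reward, next_state, done)
    #   loss = agent.replay()                      # samples its own batch from memory
    #   agent.save_checkpoint(episode_reward)      # optional checkpoint gating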
def _validate_and_fix_state(self, state):
"""Validate and fix state to ensure it has correct dimensions and no empty data"""
try:
# Convert to numpy if needed
if isinstance(state, torch.Tensor):
state = state.detach().cpu().numpy()
elif not isinstance(state, np.ndarray):
# Check if state is a dict or complex object
if isinstance(state, dict):
logger.error(f"State is a dict: {state}")
# Handle empty dictionary case
if not state:
logger.error("Empty state dictionary received, using default state")
expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
# Extract numerical values from dict if possible
if 'features' in state:
state = state['features']
elif 'state' in state:
state = state['state']
else:
# Try to extract all numerical values using the helper method
numerical_values = self._extract_numeric_from_dict(state)
if numerical_values:
state = np.array(numerical_values, dtype=np.float32)
else:
logger.error("No numerical values found in state dict, using default state")
expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
else:
try:
state = np.array(state, dtype=np.float32)
except (ValueError, TypeError) as e:
logger.error(f"Cannot convert state to numpy array: {type(state)}, {e}")
expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
# Flatten if multi-dimensional
if state.ndim > 1:
state = state.flatten()
# Check for empty or invalid state
if state.size == 0:
logger.warning("Empty state detected, using default")
expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
# Check for NaN or infinite values
if np.any(np.isnan(state)) or np.any(np.isinf(state)):
logger.warning("NaN or infinite values in state, replacing with zeros")
state = np.nan_to_num(state, nan=0.0, posinf=1.0, neginf=-1.0)
# Ensure correct dimensions
expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
expected_size = int(expected_size)
if len(state) != expected_size:
if len(state) < expected_size:
# Pad with zeros
padded_state = np.zeros(expected_size, dtype=np.float32)
padded_state[:len(state)] = state
state = padded_state
else:
# Truncate
state = state[:expected_size]
return state.astype(np.float32)
except Exception as e:
logger.error(f"Error validating state: {e}")
# Return default state as fallback
expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
def _extract_numeric_from_dict(self, data_dict):
"""Recursively extract numerical values from nested dictionaries"""
numerical_values = []
try:
for key, value in data_dict.items():
if isinstance(value, (int, float)):
numerical_values.append(float(value))
elif isinstance(value, (list, np.ndarray)):
try:
flattened = np.array(value).flatten()
for x in flattened:
if isinstance(x, (int, float)):
numerical_values.append(float(x))
elif hasattr(x, 'item'): # numpy scalar
numerical_values.append(float(x.item()))
except (ValueError, TypeError):
continue
elif isinstance(value, dict):
# Recursively extract from nested dicts
nested_values = self._extract_numeric_from_dict(value)
numerical_values.extend(nested_values)
except Exception as e:
logger.debug(f"Error extracting numeric values from dict: {e}")
return numerical_values
def _replay_standard(self, states, actions, rewards, next_states, dones):
"""Standard training step without mixed precision"""
try:
# Validate input tensors
if states.shape[0] == 0:
logger.warning("Empty batch in _replay_standard")
return 0.0
# Ensure model is in training mode for gradients
self.policy_net.train()
# Get current Q values - use the updated forward method
q_values_output = self.policy_net(states)
if isinstance(q_values_output, tuple):
current_q_values_all = q_values_output[0] # Extract Q-values from tuple
else:
current_q_values_all = q_values_output
current_q_values = current_q_values_all.gather(1, actions.unsqueeze(1)).squeeze(1)
# Enhanced Double DQN implementation
with torch.no_grad():
if self.use_double_dqn:
# Double DQN: Use policy network to select actions, target network to evaluate
policy_output = self.policy_net(next_states)
policy_q_values = policy_output[0] if isinstance(policy_output, tuple) else policy_output
next_actions = policy_q_values.argmax(1)
target_output = self.target_net(next_states)
target_q_values_all = target_output[0] if isinstance(target_output, tuple) else target_output
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
else:
# Standard DQN: Use target network for both selection and evaluation
target_output = self.target_net(next_states)
target_q_values = target_output[0] if isinstance(target_output, tuple) else target_output
next_q_values = target_q_values.max(1)[0]
# Ensure tensor shapes are consistent
batch_size = states.shape[0]
if rewards.shape[0] != batch_size or next_q_values.shape[0] != batch_size:
logger.warning(f"Shape mismatch in replay: batch_size={batch_size}, rewards={rewards.shape}, next_q_values={next_q_values.shape}")
min_size = min(batch_size, rewards.shape[0], next_q_values.shape[0])
rewards = rewards[:min_size]
dones = dones[:min_size]
next_q_values = next_q_values[:min_size]
current_q_values = current_q_values[:min_size]
# Calculate target Q values
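            # Double DQN target: y = r + gamma * (1 - done) * Q_target(s', argmax_a Q_policy(s', a));
            # plain DQN replaces the inner term with max_a Q_target(s', a)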
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
# Compute loss for Q value - ensure tensors require gradients
if not current_q_values.requires_grad:
logger.warning("Current Q values do not require gradients")
# Force training mode
self.policy_net.train()
return 0.0
q_loss = self.criterion(current_q_values, target_q_values.detach())
# Calculate auxiliary losses and add to Q-loss
total_loss = q_loss
# Add auxiliary losses if available
try:
# Get additional predictions from forward pass
if isinstance(q_values_output, tuple) and len(q_values_output) >= 5:
current_regime_pred = q_values_output[1]
current_price_pred = q_values_output[2]
current_volatility_pred = q_values_output[3]
current_extrema_pred = current_regime_pred # Use regime as extrema proxy for now
# Price direction loss
if current_price_pred is not None and current_price_pred.shape[0] > 0:
price_direction_loss = self._calculate_price_direction_loss(current_price_pred, rewards, actions)
if price_direction_loss is not None:
total_loss = total_loss + 0.2 * price_direction_loss
# Extrema loss
if current_extrema_pred is not None and current_extrema_pred.shape[0] > 0:
extrema_loss = self._calculate_extrema_loss(current_extrema_pred, rewards, actions)
if extrema_loss is not None:
total_loss = total_loss + 0.1 * extrema_loss
except Exception as e:
logger.debug(f"Could not add auxiliary loss in standard training: {e}")
# Reset gradients
self.optimizer.zero_grad()
# Ensure total loss requires gradients
if not total_loss.requires_grad:
logger.warning("Total loss does not require gradients - policy network may not be in training mode")
self.policy_net.train() # Ensure training mode
return 0.0
# Backward pass
total_loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
# Check if gradients are valid
has_valid_gradients = False
for param in self.policy_net.parameters():
if param.grad is not None and torch.any(torch.isfinite(param.grad)):
has_valid_gradients = True
break
if not has_valid_gradients:
logger.warning("No valid gradients found, skipping optimizer step")
return 0.0
# Update weights
self.optimizer.step()
# Update target network periodically
self.training_steps += 1
if self.training_steps % self.target_update_freq == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
logger.debug(f"Target network updated at step {self.training_steps}")
return total_loss.item()
except Exception as e:
logger.error(f"Error in standard replay: {e}")
return 0.0
def _replay_mixed_precision(self, states, actions, rewards, next_states, dones):
"""Mixed precision training step"""
if not self.use_mixed_precision:
logger.warning("Mixed precision not available, falling back to standard replay")
return self._replay_standard(states, actions, rewards, next_states, dones)
try:
# Validate input tensors
if states.shape[0] == 0:
logger.warning("Empty batch in _replay_mixed_precision")
return 0.0
# Zero gradients
self.optimizer.zero_grad()
# Forward pass with amp autocasting
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
with torch.cuda.amp.autocast():
# Get current Q values and predictions
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Get next Q values from target network
with torch.no_grad():
if self.use_double_dqn:
# Double DQN
policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
next_actions = policy_q_values.argmax(1)
target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
else:
# Standard DQN
next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
next_q_values = next_q_values.max(1)[0]
# Ensure consistent shapes
batch_size = states.shape[0]
if rewards.shape[0] != batch_size or next_q_values.shape[0] != batch_size:
logger.warning(f"Shape mismatch in mixed precision replay")
min_size = min(batch_size, rewards.shape[0], next_q_values.shape[0])
rewards = rewards[:min_size]
dones = dones[:min_size]
next_q_values = next_q_values[:min_size]
current_q_values = current_q_values[:min_size]
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
# Compute Q-value loss (primary task)
q_loss = nn.MSELoss()(current_q_values, target_q_values.detach())
# Initialize loss with q_loss
loss = q_loss
# Add auxiliary losses if available
try:
# Price direction loss
if current_price_pred is not None and current_price_pred.shape[0] > 0:
price_direction_loss = self._calculate_price_direction_loss(current_price_pred, rewards, actions)
if price_direction_loss is not None:
loss = loss + 0.2 * price_direction_loss
# Extrema loss
if current_extrema_pred is not None and current_extrema_pred.shape[0] > 0:
extrema_loss = self._calculate_extrema_loss(current_extrema_pred, rewards, actions)
if extrema_loss is not None:
loss = loss + 0.1 * extrema_loss
except Exception as e:
logger.debug(f"Could not add auxiliary loss in mixed precision: {e}")
# Check if loss requires gradients
if not loss.requires_grad:
logger.warning("Loss does not require gradients in mixed precision training")
return 0.0
# Scale and backward pass
self.scaler.scale(loss).backward()
# Unscale gradients and clip
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
# Check for valid gradients
has_valid_gradients = False
for param in self.policy_net.parameters():
if param.grad is not None and torch.any(torch.isfinite(param.grad)):
has_valid_gradients = True
break
if not has_valid_gradients:
logger.warning("No valid gradients in mixed precision training")
self.scaler.update() # Still update scaler
return 0.0
# Optimizer step with scaler
self.scaler.step(self.optimizer)
self.scaler.update()
# Update target network
self.training_steps += 1
if self.training_steps % self.target_update_freq == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
logger.debug(f"Target network updated at step {self.training_steps}")
return loss.item()
except Exception as e:
logger.error(f"Error in mixed precision replay: {e}")
return 0.0
def train_on_extrema(self, states, actions, rewards, next_states, dones):
"""
Special training function specifically for extrema points
Args:
states: Batch of states at extrema points
actions: Batch of actions
rewards: Batch of rewards
next_states: Batch of next states
dones: Batch of done flags
Returns:
float: Training loss
"""
# Convert to numpy arrays if not already
if not isinstance(states, np.ndarray):
states = np.array(states)
if not isinstance(actions, np.ndarray):
actions = np.array(actions)
if not isinstance(rewards, np.ndarray):
rewards = np.array(rewards)
if not isinstance(next_states, np.ndarray):
next_states = np.array(next_states)
if not isinstance(dones, np.ndarray):
dones = np.array(dones, dtype=np.float32)
# Normalize states
states = np.vstack([self._normalize_state(s) for s in states])
next_states = np.vstack([self._normalize_state(s) for s in next_states])
# Convert to torch tensors and move to device
states_tensor = torch.FloatTensor(states).to(self.device)
actions_tensor = torch.LongTensor(actions).to(self.device)
rewards_tensor = torch.FloatTensor(rewards).to(self.device)
next_states_tensor = torch.FloatTensor(next_states).to(self.device)
dones_tensor = torch.FloatTensor(dones).to(self.device)
# Choose training method based on precision mode
if self.use_mixed_precision:
return self._replay_mixed_precision(
states_tensor, actions_tensor, rewards_tensor,
next_states_tensor, dones_tensor
)
else:
return self._replay_standard(
states_tensor, actions_tensor, rewards_tensor,
next_states_tensor, dones_tensor
)
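    # Sketch of extrema-focused fine-tuning (the sampling strategy here is an
    # assumption, not something this class prescribes):
    #   batch = random.sample(agent.extrema_memory, k=min(32, len(agent.extrema_memory)))
    #   loss = agent.train_on_extrema(*zip(*batch))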
def _normalize_state(self, state: np.ndarray) -> np.ndarray:
"""Normalize the state data to prevent numerical issues"""
# Handle NaN and infinite values
state = np.nan_to_num(state, nan=0.0, posinf=1.0, neginf=-1.0)
# Check if state is 1D array (happens in some environments)
if len(state.shape) == 1:
# If 1D, we need to normalize the whole array
normalized_state = state.copy()
# Convert any timestamp or non-numeric data to float
for i in range(len(normalized_state)):
# Check for timestamp-like objects
if hasattr(normalized_state[i], 'timestamp') and callable(getattr(normalized_state[i], 'timestamp')):
# Convert timestamp to float (seconds since epoch)
normalized_state[i] = float(normalized_state[i].timestamp())
elif not isinstance(normalized_state[i], (int, float, np.number)):
# Set non-numeric data to 0
normalized_state[i] = 0.0
# Ensure all values are float
normalized_state = normalized_state.astype(np.float32)
# Simple min-max normalization for 1D state
state_min = np.min(normalized_state)
state_max = np.max(normalized_state)
if state_max > state_min:
normalized_state = (normalized_state - state_min) / (state_max - state_min)
return normalized_state
# Handle 2D arrays
normalized_state = np.zeros_like(state, dtype=np.float32)
# Convert any timestamp or non-numeric data to float
for i in range(state.shape[0]):
for j in range(state.shape[1]):
if hasattr(state[i, j], 'timestamp') and callable(getattr(state[i, j], 'timestamp')):
# Convert timestamp to float (seconds since epoch)
normalized_state[i, j] = float(state[i, j].timestamp())
elif isinstance(state[i, j], (int, float, np.number)):
normalized_state[i, j] = state[i, j]
else:
# Set non-numeric data to 0
normalized_state[i, j] = 0.0
# Loop through each timeframe's features in the combined state
feature_count = state.shape[1] // len(self.timeframes)
for tf_idx in range(len(self.timeframes)):
start_idx = tf_idx * feature_count
end_idx = start_idx + feature_count
# Extract this timeframe's features
tf_features = normalized_state[:, start_idx:end_idx]
            # Normalize OHLC data by the mean close price in the window
            # so price movements are relative rather than absolute
price_idx = 3 # Assuming close price is at index 3
if price_idx < tf_features.shape[1]:
reference_price = np.mean(tf_features[:, price_idx])
if reference_price != 0:
# Normalize price-related columns (OHLC)
for i in range(4): # First 4 columns are OHLC
if i < tf_features.shape[1]:
normalized_state[:, start_idx + i] = tf_features[:, i] / reference_price
# Normalize volume using mean and std
vol_idx = 4 # Assuming volume is at index 4
if vol_idx < tf_features.shape[1]:
vol_mean = np.mean(tf_features[:, vol_idx])
vol_std = np.std(tf_features[:, vol_idx])
if vol_std > 0:
normalized_state[:, start_idx + vol_idx] = (tf_features[:, vol_idx] - vol_mean) / vol_std
else:
normalized_state[:, start_idx + vol_idx] = 0
# Other features (technical indicators) - normalize with min-max scaling
for i in range(5, feature_count):
if i < tf_features.shape[1]:
feature_min = np.min(tf_features[:, i])
feature_max = np.max(tf_features[:, i])
if feature_max > feature_min:
normalized_state[:, start_idx + i] = (tf_features[:, i] - feature_min) / (feature_max - feature_min)
else:
normalized_state[:, start_idx + i] = 0
return normalized_state
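    # Example (illustrative sketch, assuming the conventional OHLCV layout used above:
    # close at column 3, volume at column 4, and e.g. 2 timeframes of 5 features each):
    # OHLC columns become ratios to the window's mean close, volume is z-scored, and the
    # remaining indicator columns are min-max scaled per timeframe.
    #
    #     raw = np.abs(np.random.randn(20, 10)) + 100.0
    #     norm = agent._normalize_state(raw)   # OHLC columns now hover around 1.0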
def update_realtime_tick_features(self, tick_features):
"""Update with real-time tick features from tick processor"""
try:
if tick_features is not None:
self.realtime_tick_features = tick_features
# Log high-confidence tick features
if tick_features.get('confidence', 0) > 0.8:
logger.debug(f"High-confidence tick features updated: confidence={tick_features['confidence']:.3f}")
except Exception as e:
logger.error(f"Error updating real-time tick features: {e}")
def _enhance_state_with_tick_features(self, state: np.ndarray) -> np.ndarray:
"""Enhance state with real-time tick features if available"""
try:
if self.realtime_tick_features is None:
return state
# Extract neural features from tick processor
neural_features = self.realtime_tick_features.get('neural_features', np.array([]))
volume_features = self.realtime_tick_features.get('volume_features', np.array([]))
microstructure_features = self.realtime_tick_features.get('microstructure_features', np.array([]))
confidence = self.realtime_tick_features.get('confidence', 0.0)
# Combine tick features - make them compact to match state dimensions
tick_features = np.concatenate([
neural_features[:3] if len(neural_features) >= 3 else np.zeros(3), # Take first 3 neural features
volume_features[:1] if len(volume_features) >= 1 else np.zeros(1), # Take first volume feature
microstructure_features[:1] if len(microstructure_features) >= 1 else np.zeros(1), # Take first microstructure feature
])
# Weight the tick features
weighted_tick_features = tick_features * self.tick_feature_weight
# Enhance the state by adding tick features to each timeframe
if len(state.shape) == 1:
# 1D state - append tick features
enhanced_state = np.concatenate([state, weighted_tick_features])
else:
# 2D state - add tick features to each timeframe row
num_timeframes, num_features = state.shape
# Ensure tick features match the number of original features
if len(weighted_tick_features) != num_features:
# Pad or truncate tick features to match state feature dimension
if len(weighted_tick_features) < num_features:
# Pad with zeros
padded_features = np.zeros(num_features)
padded_features[:len(weighted_tick_features)] = weighted_tick_features
weighted_tick_features = padded_features
else:
# Truncate to match
weighted_tick_features = weighted_tick_features[:num_features]
# Add tick features to the last row (most recent timeframe)
enhanced_state = state.copy()
enhanced_state[-1, :] += weighted_tick_features # Add to last timeframe
return enhanced_state
except Exception as e:
logger.error(f"Error enhancing state with tick features: {e}")
return state
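    # Illustrative note: for a 2D state, the five weighted tick features (3 neural, 1 volume,
    # 1 microstructure) are padded or truncated to the state's feature width and added onto
    # the last row only, so just the most recent timeframe is perturbed. Rough sketch,
    # assuming `tick_feature_weight` has already been configured on the agent:
    #
    #     agent.update_realtime_tick_features({'neural_features': np.ones(8),
    #                                          'volume_features': np.ones(2),
    #                                          'microstructure_features': np.ones(2),
    #                                          'confidence': 0.9})
    #     enhanced = agent._enhance_state_with_tick_features(state)   # same shape as `state`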
def update_learning_metrics(self, episode_reward, best_reward_threshold=0.01):
"""Update learning metrics and perform learning rate adjustments if needed"""
# Update average reward with exponential moving average
if self.avg_reward == 0:
self.avg_reward = episode_reward
else:
self.avg_reward = 0.95 * self.avg_reward + 0.05 * episode_reward
# Check if we're making sufficient progress
if episode_reward > (1 + best_reward_threshold) * self.best_reward:
self.best_reward = episode_reward
self.no_improvement_count = 0
return True # Improved
else:
self.no_improvement_count += 1
# If no improvement for a while, adjust learning rate
if self.no_improvement_count >= 10:
current_lr = self.optimizer.param_groups[0]['lr']
new_lr = current_lr * 0.5
if new_lr >= 1e-6: # Don't reduce below minimum threshold
for param_group in self.optimizer.param_groups:
param_group['lr'] = new_lr
logger.info(f"Reducing learning rate from {current_lr} to {new_lr}")
self.no_improvement_count = 0
return False # No improvement
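    # Worked example (illustrative): the running average is an exponential moving average,
    # avg_reward <- 0.95 * avg_reward + 0.05 * episode_reward, and after 10 episodes with no
    # new best reward the learning rate is halved (but never reduced below 1e-6).
    #
    #     improved = agent.update_learning_metrics(episode_reward=0.25)
    #     # with best_reward = 0.20 and threshold 0.01: 0.25 > 1.01 * 0.20, so improved is True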
def save(self, path: str = None):
"""Save model and agent state using unified registry"""
try:
from NN.training.model_manager import save_model
# Use unified registry if no path or if it's a models/ path
if path is None or path.startswith('models/'):
model_name = "dqn_agent"
if path:
model_name = path.split('/')[-1].replace('_agent_state', '').replace('.pt', '')
# Prepare full agent state
agent_state = {
'epsilon': self.epsilon,
'update_count': self.update_count,
'losses': self.losses,
'optimizer_state': self.optimizer.state_dict(),
'best_reward': self.best_reward,
'avg_reward': self.avg_reward,
'policy_net_state': self.policy_net.state_dict(),
'target_net_state': self.target_net.state_dict()
}
success = save_model(
model=self.policy_net, # Save policy net as main model
model_name=model_name,
model_type='dqn',
metadata={'full_agent_state': agent_state}
)
if success:
logger.info(f"DQN agent saved to unified registry: {model_name}")
return
else:
# Legacy direct file save
                os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
# Save policy network
self.policy_net.save(f"{path}_policy")
# Save target network
self.target_net.save(f"{path}_target")
# Save agent state
state = {
'epsilon': self.epsilon,
'update_count': self.update_count,
'losses': self.losses,
'optimizer_state': self.optimizer.state_dict(),
'best_reward': self.best_reward,
'avg_reward': self.avg_reward
}
torch.save(state, f"{path}_agent_state.pt")
logger.info(f"Agent state saved to {path}_agent_state.pt (legacy mode)")
except Exception as e:
logger.error(f"Failed to save DQN agent: {e}")
def load(self, path: str = None):
"""Load model and agent state from unified registry or file"""
try:
from NN.training.model_manager import load_model
# Use unified registry if no path or if it's a models/ path
if path is None or path.startswith('models/'):
model_name = "dqn_agent"
if path:
model_name = path.split('/')[-1].replace('_agent_state', '').replace('.pt', '')
model = load_model(model_name, 'dqn')
if model is None:
logger.warning(f"Could not load DQN agent {model_name} from unified registry")
return
# Load full agent state from metadata
                # NOTE: get_model_registry() is assumed to be available in this module's scope;
                # it is not imported in this method, so a missing import would raise NameError here
                registry = get_model_registry()
if model_name in registry.metadata['models']:
model_data = registry.metadata['models'][model_name]
if 'full_agent_state' in model_data:
agent_state = model_data['full_agent_state']
# Restore agent state
self.epsilon = agent_state['epsilon']
self.update_count = agent_state['update_count']
self.losses = agent_state['losses']
self.optimizer.load_state_dict(agent_state['optimizer_state'])
# Load additional metrics if they exist
if 'best_reward' in agent_state:
self.best_reward = agent_state['best_reward']
if 'avg_reward' in agent_state:
self.avg_reward = agent_state['avg_reward']
# Load network states
if 'policy_net_state' in agent_state:
self.policy_net.load_state_dict(agent_state['policy_net_state'])
if 'target_net_state' in agent_state:
self.target_net.load_state_dict(agent_state['target_net_state'])
logger.info(f"DQN agent loaded from unified registry: {model_name}")
return
else:
# Legacy direct file load
# Load policy network
self.policy_net.load(f"{path}_policy")
# Load target network
self.target_net.load(f"{path}_target")
# Load agent state
try:
agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False)
self.epsilon = agent_state['epsilon']
self.update_count = agent_state['update_count']
self.losses = agent_state['losses']
self.optimizer.load_state_dict(agent_state['optimizer_state'])
# Load additional metrics if they exist
if 'best_reward' in agent_state:
self.best_reward = agent_state['best_reward']
if 'avg_reward' in agent_state:
self.avg_reward = agent_state['avg_reward']
logger.info(f"Agent state loaded from {path}_agent_state.pt (legacy mode)")
except FileNotFoundError:
logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
except Exception as e:
logger.error(f"Failed to load DQN agent: {e}")
def get_position_info(self):
"""Get current position information"""
return {
'position': self.current_position,
'entry_price': self.position_entry_price,
'entry_time': self.position_entry_time,
'entry_threshold': self.entry_confidence_threshold,
'exit_threshold': self.exit_confidence_threshold
}
def _calculate_price_direction_loss(self, price_direction_pred: torch.Tensor, rewards: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
"""
Calculate loss for price direction predictions
Args:
price_direction_pred: Tensor of shape [batch, 2] containing [direction, confidence]
rewards: Tensor of shape [batch] containing rewards
actions: Tensor of shape [batch] containing actions
Returns:
Price direction loss tensor
"""
try:
if price_direction_pred.size(1) != 2:
return None
batch_size = price_direction_pred.size(0)
# Extract direction and confidence predictions
direction_pred = price_direction_pred[:, 0] # -1 to 1
confidence_pred = price_direction_pred[:, 1] # 0 to 1
# Create targets based on rewards and actions
with torch.no_grad():
                # Direction targets: +1 if reward exceeds the 0.01 threshold and the action is BUY,
                # -1 if it exceeds the threshold and the action is SELL, 0 (sideways) otherwise
direction_targets = torch.zeros(batch_size, device=price_direction_pred.device)
for i in range(batch_size):
if rewards[i] > 0.01: # Positive reward threshold
if actions[i] == 0: # BUY action
direction_targets[i] = 1.0 # UP
elif actions[i] == 1: # SELL action
direction_targets[i] = -1.0 # DOWN
# else: targets remain 0 (sideways)
# Confidence targets: based on reward magnitude (higher reward = higher confidence)
confidence_targets = torch.abs(rewards).clamp(0, 1)
# Calculate losses for each component
direction_loss = F.mse_loss(direction_pred, direction_targets)
confidence_loss = F.mse_loss(confidence_pred, confidence_targets)
# Combined loss (direction is more important than confidence)
total_loss = direction_loss + 0.3 * confidence_loss
return total_loss
except Exception as e:
logger.debug(f"Error calculating price direction loss: {e}")
return None
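    # Worked example (illustrative): with reward 0.5 on a BUY (action 0), the direction
    # target is +1.0 and the confidence target is clamp(|0.5|, 0, 1) = 0.5, giving
    # loss = MSE(direction_pred, 1.0) + 0.3 * MSE(confidence_pred, 0.5). Rewards at or
    # below the 0.01 threshold leave the direction target at 0 (sideways).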
def _calculate_extrema_loss(self, extrema_pred: torch.Tensor, rewards: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
"""
Calculate loss for extrema predictions
Args:
extrema_pred: Extrema predictions
rewards: Tensor containing rewards
actions: Tensor containing actions
Returns:
Extrema loss tensor
"""
try:
batch_size = extrema_pred.size(0)
# Create targets based on reward patterns
with torch.no_grad():
                extrema_targets = torch.full((batch_size,), 2, dtype=torch.long, device=extrema_pred.device)  # Default to class 2 = "neither"
for i in range(batch_size):
# High positive reward suggests we're at a good entry point (potential bottom for BUY, top for SELL)
if rewards[i] > 0.05:
if actions[i] == 0: # BUY action
extrema_targets[i] = 0 # Bottom
elif actions[i] == 1: # SELL action
extrema_targets[i] = 1 # Top
# Calculate cross-entropy loss
if extrema_pred.size(1) >= 3:
extrema_loss = F.cross_entropy(extrema_pred[:, :3], extrema_targets)
else:
extrema_loss = F.cross_entropy(extrema_pred, extrema_targets)
return extrema_loss
except Exception as e:
logger.debug(f"Error calculating extrema loss: {e}")
return None
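    # Illustrative mapping: rewards above 0.05 label the transition as an extrema example,
    # class 0 ("bottom") for BUY and class 1 ("top") for SELL; everything else keeps the
    # default class 2 ("neither"), and cross-entropy is taken over the first three logits
    # of extrema_pred when at least three are available.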
def get_enhanced_training_stats(self):
"""Get enhanced RL training statistics with detailed metrics (from EnhancedDQNAgent)"""
return {
'buffer_size': len(self.memory),
'epsilon': self.epsilon,
'avg_reward': self.avg_reward,
'best_reward': self.best_reward,
'recent_rewards': list(self.recent_rewards) if hasattr(self, 'recent_rewards') else [],
'no_improvement_count': self.no_improvement_count,
# Enhanced statistics from EnhancedDQNAgent
'training_steps': self.training_steps,
'avg_td_error': np.mean(self.td_errors[-100:]) if self.td_errors else 0.0,
'recent_losses': self.losses[-10:] if self.losses else [],
'epsilon_trend': self.epsilon_history[-20:] if self.epsilon_history else [],
'specialized_buffers': {
'extrema_memory': len(self.extrema_memory),
'positive_memory': len(self.positive_memory),
'price_movement_memory': len(self.price_movement_memory)
},
'market_regime_weights': self.market_regime_weights,
'use_double_dqn': self.use_double_dqn,
'use_prioritized_replay': self.use_prioritized_replay,
'gradient_clip_norm': self.gradient_clip_norm,
'target_update_frequency': self.target_update_freq
}
    def get_params_count(self):
        """Get total number of parameters in the DQN policy network"""
        return sum(param.numel() for param in self.policy_net.parameters())
def _sanitize_state_data(self, state):
"""Sanitize state data to ensure it's a proper numeric array"""
try:
# If state is already a numpy array, return it
if isinstance(state, np.ndarray):
# Check for empty array
if state.size == 0:
logger.warning("Received empty numpy array state. Using fallback dimensions.")
expected_size = getattr(self, 'state_size', getattr(self, 'state_dim', 403))
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
# Check for non-numeric data and handle it
if state.dtype == object:
# Convert object array to float array
sanitized = np.zeros_like(state, dtype=np.float32)
for i in range(state.shape[0]):
if len(state.shape) > 1:
for j in range(state.shape[1]):
sanitized[i, j] = self._extract_numeric_value(state[i, j])
else:
sanitized[i] = self._extract_numeric_value(state[i])
return sanitized
else:
return state.astype(np.float32)
# If state is a list or tuple, convert to array
elif isinstance(state, (list, tuple)):
# Check for empty list/tuple
if len(state) == 0:
logger.warning("Received empty list/tuple state. Using fallback dimensions.")
expected_size = getattr(self, 'state_size', getattr(self, 'state_dim', 403))
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
# Recursively sanitize each element
sanitized = []
for item in state:
if isinstance(item, (list, tuple)):
sanitized_row = []
for sub_item in item:
sanitized_row.append(self._extract_numeric_value(sub_item))
sanitized.append(sanitized_row)
else:
sanitized.append(self._extract_numeric_value(item))
result = np.array(sanitized, dtype=np.float32)
# Check if result is empty and provide fallback
if result.size == 0:
logger.warning("Sanitized state resulted in empty array. Using fallback dimensions.")
expected_size = getattr(self, 'state_size', getattr(self, 'state_dim', 403))
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
return result
# If state is a dict, try to extract values
elif isinstance(state, dict):
# Try to extract meaningful values from dict
values = []
for key in sorted(state.keys()): # Sort for consistency
values.append(self._extract_numeric_value(state[key]))
return np.array(values, dtype=np.float32)
# If state is a single value, make it an array
else:
return np.array([self._extract_numeric_value(state)], dtype=np.float32)
except Exception as e:
logger.warning(f"Error sanitizing state data: {e}. Using zero array with expected dimensions.")
# Return a zero array as fallback with the expected state dimension
# Use the state_dim from initialization, fallback to 403 if not available
expected_size = getattr(self, 'state_size', getattr(self, 'state_dim', 403))
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
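    # Example (illustrative sketch): heterogeneous inputs are coerced to a float32 array;
    # non-numeric entries fall back to 0.0 and dict entries are reduced to a single
    # representative value (common keys such as 'price' or 'close' are tried first).
    #
    #     agent._sanitize_state_data([1, "2.5", None, {"price": 101.3}])
    #     # -> array([  1. ,   2.5,   0. , 101.3], dtype=float32)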
def _extract_numeric_value(self, value):
"""Extract a numeric value from various data types"""
try:
# Handle None values
if value is None:
return 0.0
# Handle numeric types
if isinstance(value, (int, float, np.number)):
return float(value)
# Handle dict values
elif isinstance(value, dict):
# Try common keys for numeric data
for key in ['value', 'price', 'close', 'last', 'amount', 'quantity']:
if key in value:
return self._extract_numeric_value(value[key])
# If no common keys, try to get first numeric value
for v in value.values():
if isinstance(v, (int, float, np.number)):
return float(v)
return 0.0
# Handle string values that might be numeric
elif isinstance(value, str):
try:
return float(value)
                except (ValueError, TypeError):
return 0.0
# Handle datetime objects
elif hasattr(value, 'timestamp'):
return float(value.timestamp())
            # Handle boolean values (note: bool is a subclass of int, so this branch is
            # normally already covered by the numeric check above; kept for explicitness)
            elif isinstance(value, bool):
                return float(value)
# Handle list/tuple - take first numeric value
elif isinstance(value, (list, tuple)) and len(value) > 0:
return self._extract_numeric_value(value[0])
else:
return 0.0
        except Exception:
return 0.0
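    # Illustrative behaviour: _extract_numeric_value("3.14") -> 3.14,
    # _extract_numeric_value(None) -> 0.0, _extract_numeric_value({'close': 42.0}) -> 42.0,
    # and datetime-like objects map to their POSIX timestamp; anything unrecognised
    # falls back to 0.0.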
def _extract_numeric_from_dict(self, data_dict):
"""Recursively extract all numeric values from a dictionary"""
numeric_values = []
try:
for key, value in data_dict.items():
if isinstance(value, (int, float)):
numeric_values.append(float(value))
elif isinstance(value, (list, np.ndarray)):
try:
flattened = np.array(value).flatten()
for x in flattened:
if isinstance(x, (int, float)):
numeric_values.append(float(x))
elif hasattr(x, 'item'): # numpy scalar
numeric_values.append(float(x.item()))
except (ValueError, TypeError):
continue
elif isinstance(value, dict):
# Recursively extract from nested dicts
nested_values = self._extract_numeric_from_dict(value)
numeric_values.extend(nested_values)
elif isinstance(value, torch.Tensor):
try:
numeric_values.append(float(value.item()))
except Exception:
continue
except Exception as e:
logger.debug(f"Error extracting numeric values from dict: {e}")
return numeric_values
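    # Example (illustrative sketch): nested structures are flattened into a single list of
    # floats, skipping anything non-numeric:
    #
    #     agent._extract_numeric_from_dict({'a': 1, 'b': {'c': [2.0, 3.0]}, 'd': 'text'})
    #     # -> [1.0, 2.0, 3.0]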