import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import numpy as np
|
|
from collections import deque
|
|
import random
|
|
from typing import Tuple, List
|
|
import os
|
|
import sys
|
|
import logging
|
|
import torch.nn.functional as F
|
|
import time
|
|
|
|
# Add parent directory to path
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
|
|
|
# Import checkpoint management
|
|
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
|
from utils.training_integration import get_training_integration
|
|
|
|
# Configure logger
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class DQNNetwork(nn.Module):
|
|
"""
|
|
Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
|
|
Handles 7850 input features from multi-timeframe, multi-asset data
|
|
"""
|
|
def __init__(self, input_dim: int, n_actions: int):
|
|
super(DQNNetwork, self).__init__()
|
|
|
|
# Handle different input dimension formats
|
|
if isinstance(input_dim, (tuple, list)):
|
|
if len(input_dim) == 1:
|
|
self.input_size = input_dim[0]
|
|
else:
|
|
self.input_size = int(np.prod(input_dim))  # Flatten multi-dimensional input
|
|
else:
|
|
self.input_size = input_dim
|
|
|
|
self.n_actions = n_actions
|
|
|
|
# Deep network architecture optimized for trading features
|
|
self.network = nn.Sequential(
|
|
# Input layer
|
|
nn.Linear(self.input_size, 2048),
|
|
nn.ReLU(),
|
|
nn.Dropout(0.3),
|
|
|
|
# Hidden layers with residual-like connections
|
|
nn.Linear(2048, 1024),
|
|
nn.ReLU(),
|
|
nn.Dropout(0.3),
|
|
|
|
nn.Linear(1024, 512),
|
|
nn.ReLU(),
|
|
nn.Dropout(0.3),
|
|
|
|
nn.Linear(512, 256),
|
|
nn.ReLU(),
|
|
nn.Dropout(0.2),
|
|
|
|
nn.Linear(256, 128),
|
|
nn.ReLU(),
|
|
nn.Dropout(0.2),
|
|
|
|
# Output layer for Q-values
|
|
nn.Linear(128, n_actions)
|
|
)
|
|
|
|
# Initialize weights
|
|
self._initialize_weights()
|
|
|
|
def _initialize_weights(self):
|
|
"""Initialize network weights using Xavier initialization"""
|
|
for module in self.modules():
|
|
if isinstance(module, nn.Linear):
|
|
nn.init.xavier_uniform_(module.weight)
|
|
if module.bias is not None:
|
|
nn.init.constant_(module.bias, 0)
|
|
|
|
def forward(self, x):
|
|
"""Forward pass through the network"""
|
|
# Ensure input is properly shaped
|
|
if x.dim() > 2:
|
|
x = x.view(x.size(0), -1) # Flatten if needed
|
|
elif x.dim() == 1:
|
|
x = x.unsqueeze(0) # Add batch dimension if needed
|
|
|
|
return self.network(x)
|
|
|
|
def act(self, state, explore=True):
|
|
"""
|
|
Select action using epsilon-greedy policy
|
|
|
|
Args:
|
|
state: Current state (numpy array or tensor)
|
|
explore: Whether to use epsilon-greedy exploration
|
|
|
|
Returns:
|
|
action_idx: Selected action index
|
|
confidence: Confidence score
|
|
action_probs: Action probabilities
|
|
"""
|
|
# Convert state to tensor if needed
|
|
if isinstance(state, np.ndarray):
|
|
state = torch.FloatTensor(state).to(next(self.parameters()).device)
|
|
|
|
# Ensure proper shape
|
|
if state.dim() == 1:
|
|
state = state.unsqueeze(0)
|
|
|
|
with torch.no_grad():
|
|
q_values = self.forward(state)
|
|
|
|
# Get action probabilities using softmax
|
|
action_probs = F.softmax(q_values, dim=1)
|
|
|
|
# Select action (greedy for inference)
|
|
action_idx = torch.argmax(q_values, dim=1).item()
|
|
|
|
# Calculate confidence as max probability
|
|
confidence = float(action_probs[0, action_idx].item())
|
|
|
|
# Convert probabilities to list
|
|
probs_list = action_probs.squeeze(0).cpu().numpy().tolist()
|
|
|
|
return action_idx, confidence, probs_list
|
|
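# --- Illustrative usage sketch (added for documentation, not part of the original module) ---
# Shows how the standalone DQNNetwork above can be queried for a greedy action.
# The 7850-feature count comes from the class docstring; the random state and
# this helper's name are placeholders for illustration only.
def _demo_dqn_network_inference():
    """Hypothetical sketch: greedy inference with DQNNetwork on a fake state."""
    net = DQNNetwork(input_dim=7850, n_actions=2)
    net.eval()  # inference only; dropout disabled
    fake_state = np.random.randn(7850).astype(np.float32)
    action_idx, confidence, action_probs = net.act(fake_state, explore=False)
    logger.info(f"Demo action={action_idx}, confidence={confidence:.3f}")
    return action_idx, confidence, action_probs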
|
|
class DQNAgent:
|
|
"""
|
|
Deep Q-Network agent for trading
|
|
Uses Enhanced CNN model as the base network with GPU support for improved performance
|
|
"""
|
|
def __init__(self,
|
|
state_shape: Tuple[int, ...],
|
|
n_actions: int = 2,
|
|
learning_rate: float = 0.001,
|
|
epsilon: float = 1.0,
|
|
epsilon_min: float = 0.01,
|
|
epsilon_decay: float = 0.995,
|
|
buffer_size: int = 10000,
|
|
batch_size: int = 32,
|
|
target_update: int = 100,
|
|
priority_memory: bool = True,
|
|
device=None,
|
|
model_name: str = "dqn_agent",
|
|
enable_checkpoints: bool = True):
|
|
|
|
# Checkpoint management
|
|
self.model_name = model_name
|
|
self.enable_checkpoints = enable_checkpoints
|
|
self.training_integration = get_training_integration() if enable_checkpoints else None
|
|
self.episode_count = 0
|
|
self.best_reward = float('-inf')
|
|
self.reward_history = deque(maxlen=100)
|
|
self.checkpoint_frequency = 100 # Save checkpoint every 100 episodes
|
|
|
|
# Extract state dimensions
|
|
if isinstance(state_shape, tuple) and len(state_shape) > 1:
|
|
# Multi-dimensional state (like image or sequence)
|
|
self.state_dim = state_shape
|
|
else:
|
|
# 1D state
|
|
if isinstance(state_shape, tuple):
|
|
if len(state_shape) == 0:
|
|
self.state_dim = 1 # Safe default for empty tuple
|
|
else:
|
|
self.state_dim = state_shape[0]
|
|
else:
|
|
self.state_dim = state_shape
|
|
|
|
# Store parameters
|
|
self.n_actions = n_actions
|
|
self.learning_rate = learning_rate
|
|
self.epsilon = epsilon
|
|
self.epsilon_min = epsilon_min
|
|
self.epsilon_decay = epsilon_decay
|
|
self.buffer_size = buffer_size
|
|
self.batch_size = batch_size
|
|
self.target_update = target_update
|
|
|
|
# Set device for computation (default to GPU if available)
|
|
if device is None:
|
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
else:
|
|
self.device = device
|
|
|
|
# Initialize models with RL-specific network architecture
|
|
self.policy_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
|
|
self.target_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
|
|
|
|
# Initialize the target network with the same weights as the policy network
|
|
self.target_net.load_state_dict(self.policy_net.state_dict())
|
|
|
|
# Set models to eval mode (important for batch norm, dropout)
|
|
self.target_net.eval()
|
|
|
|
# Optimization components
|
|
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
|
|
self.criterion = nn.MSELoss()
|
|
|
|
# Experience replay memory
|
|
self.memory = []
|
|
self.positive_memory = [] # Special memory for storing good experiences
|
|
self.update_count = 0
|
|
|
|
# Extrema detection tracking
|
|
self.last_extrema_pred = {
|
|
'class': 2, # Default to "neither" (not extrema)
|
|
'confidence': 0.0,
|
|
'raw': None
|
|
}
|
|
self.extrema_memory = []
|
|
|
|
# DQN hyperparameters
|
|
self.gamma = 0.99 # Discount factor
|
|
|
|
# Initialize avg_reward for dashboard compatibility
|
|
self.avg_reward = 0.0 # Average reward tracking for dashboard
|
|
|
|
# Market regime adaptation weights are set below in the
# "Enhanced features from EnhancedDQNAgent" section
|
|
|
|
# Load best checkpoint if available
|
|
if self.enable_checkpoints:
|
|
self.load_best_checkpoint()
|
|
|
|
logger.info(f"DQN Agent initialized with checkpoint management: {enable_checkpoints}")
|
|
if enable_checkpoints:
|
|
logger.info(f"Model name: {model_name}, Checkpoint frequency: {self.checkpoint_frequency}")
|
|
|
|
# Recent activity tracking (actions, prices, rewards)
|
|
self.recent_actions = deque(maxlen=10)
|
|
self.recent_prices = deque(maxlen=20)
|
|
self.recent_rewards = deque(maxlen=100)
|
|
|
|
# Price prediction tracking
|
|
self.last_price_pred = {
|
|
'immediate': {
|
|
'direction': 1, # Default to "sideways"
|
|
'confidence': 0.0,
|
|
'change': 0.0
|
|
},
|
|
'midterm': {
|
|
'direction': 1, # Default to "sideways"
|
|
'confidence': 0.0,
|
|
'change': 0.0
|
|
},
|
|
'longterm': {
|
|
'direction': 1, # Default to "sideways"
|
|
'confidence': 0.0,
|
|
'change': 0.0
|
|
}
|
|
}
|
|
|
|
# Store separate memory for price direction examples
|
|
self.price_movement_memory = [] # For storing examples of clear price movements
|
|
|
|
# Performance tracking
|
|
self.losses = []
|
|
self.no_improvement_count = 0
|
|
|
|
# Confidence tracking
|
|
self.confidence_history = []
|
|
self.avg_confidence = 0.0
|
|
self.max_confidence = 0.0
|
|
self.min_confidence = 1.0
|
|
|
|
# Enhanced features from EnhancedDQNAgent
|
|
# Market adaptation capabilities
|
|
self.market_regime_weights = {
|
|
'trending': 1.2, # Higher confidence in trending markets
|
|
'ranging': 0.8, # Lower confidence in ranging markets
|
|
'volatile': 0.6 # Much lower confidence in volatile markets
|
|
}
|
|
|
|
# Dueling network support (requires enhanced network architecture)
|
|
self.use_dueling = True
|
|
|
|
# Prioritized experience replay parameters
|
|
self.use_prioritized_replay = priority_memory
|
|
self.alpha = 0.6 # Priority exponent
|
|
self.beta = 0.4 # Importance sampling exponent
|
|
self.beta_increment = 0.001
|
|
|
|
# Double DQN support
|
|
self.use_double_dqn = True
|
|
|
|
# Enhanced training features from EnhancedDQNAgent
|
|
self.target_update_freq = target_update # More descriptive name
|
|
self.training_steps = 0
|
|
self.gradient_clip_norm = 1.0 # Gradient clipping
|
|
|
|
# Enhanced statistics tracking
|
|
self.epsilon_history = []
|
|
self.td_errors = [] # Track TD errors for analysis
|
|
|
|
# Trade action fee and confidence thresholds
|
|
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
|
|
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
|
|
|
|
# Violent move detection
|
|
self.price_history = []
|
|
self.volatility_window = 20 # Window size for volatility calculation
|
|
self.volatility_threshold = 0.0015 # Threshold for considering a move "violent"
|
|
self.post_violent_move = False # Flag for recent violent move
|
|
self.violent_move_cooldown = 0 # Cooldown after violent move
|
|
|
|
# Feature integration
|
|
self.last_hidden_features = None # Store last extracted features
|
|
self.feature_history = [] # Store history of features for analysis
|
|
|
|
# Real-time tick features integration
|
|
self.realtime_tick_features = None # Latest tick features from tick processor
|
|
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
|
|
|
|
# Check if mixed precision training should be used
|
|
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
|
|
self.use_mixed_precision = True
|
|
self.scaler = torch.cuda.amp.GradScaler()
|
|
logger.info("Mixed precision training enabled")
|
|
else:
|
|
self.use_mixed_precision = False
|
|
logger.info("Mixed precision training disabled")
|
|
|
|
# Track if we're in training mode
|
|
self.training = True
|
|
|
|
# For compatibility with old code
|
|
self.state_size = np.prod(state_shape)
|
|
self.action_size = n_actions
|
|
self.memory_size = buffer_size
|
|
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] # Default timeframes
|
|
|
|
logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
|
|
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
|
|
logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")
|
|
|
|
# Log model parameters
|
|
total_params = sum(p.numel() for p in self.policy_net.parameters())
|
|
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
|
|
|
|
# Position management for 2-action system
|
|
self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long)
|
|
self.position_entry_price = 0.0
|
|
self.position_entry_time = None
|
|
|
|
# Different thresholds for entry vs exit decisions - AGGRESSIVE for more training data
|
|
self.entry_confidence_threshold = 0.35 # Lower threshold for new positions (was 0.7)
|
|
self.exit_confidence_threshold = 0.15 # Very low threshold for closing positions (was 0.3)
|
|
self.uncertainty_threshold = 0.1 # When to stay neutral
|
|
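# --- Illustrative construction sketch (added for documentation) ---
# Shows one plausible way to build this agent for the flat 7850-feature state
# and the 2-action (0=BUY, 1=SELL) setup described above. The feature count and
# hyperparameters are placeholders, and checkpointing is disabled so the demo
# does not touch checkpoint storage.
def _example_agent_construction():
    """Hypothetical usage sketch: build a DQNAgent and query one action."""
    agent = DQNAgent(
        state_shape=(7850,),       # flat BaseDataInput feature vector (assumed size)
        n_actions=2,               # 0=BUY, 1=SELL
        buffer_size=10000,
        batch_size=32,
        enable_checkpoints=False,  # keep the demo self-contained
    )
    state = np.zeros(7850, dtype=np.float32)
    action = agent.act(state, explore=False)
    return agent, action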
|
|
def load_best_checkpoint(self):
|
|
"""Load the best checkpoint for this DQN agent"""
|
|
try:
|
|
if not self.enable_checkpoints:
|
|
return
|
|
|
|
result = load_best_checkpoint(self.model_name)
|
|
if result:
|
|
file_path, metadata = result
|
|
checkpoint = torch.load(file_path, map_location=self.device, weights_only=False)
|
|
|
|
# Load model states
|
|
if 'policy_net_state_dict' in checkpoint:
|
|
self.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
|
|
if 'target_net_state_dict' in checkpoint:
|
|
self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
|
|
if 'optimizer_state_dict' in checkpoint:
|
|
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
|
|
# Load training state
|
|
if 'episode_count' in checkpoint:
|
|
self.episode_count = checkpoint['episode_count']
|
|
if 'epsilon' in checkpoint:
|
|
self.epsilon = checkpoint['epsilon']
|
|
if 'best_reward' in checkpoint:
|
|
self.best_reward = checkpoint['best_reward']
|
|
|
|
logger.info(f"Loaded DQN checkpoint: {metadata.checkpoint_id}")
|
|
logger.info(f"Episode: {self.episode_count}, Best reward: {self.best_reward:.4f}")
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
|
|
|
|
def save_checkpoint(self, episode_reward: float, force_save: bool = False):
|
|
"""Save checkpoint if performance improved or forced"""
|
|
try:
|
|
if not self.enable_checkpoints:
|
|
return False
|
|
|
|
self.episode_count += 1
|
|
self.reward_history.append(episode_reward)
|
|
|
|
# Calculate average reward over recent episodes
|
|
avg_reward = sum(self.reward_history) / len(self.reward_history)
|
|
|
|
# Update best reward
|
|
if episode_reward > self.best_reward:
|
|
self.best_reward = episode_reward
|
|
|
|
# Save checkpoint every N episodes or if forced
|
|
should_save = (
|
|
force_save or
|
|
self.episode_count % self.checkpoint_frequency == 0 or
|
|
episode_reward > self.best_reward * 0.95  # Within 5% of best (note: best_reward was already updated above)
|
|
)
|
|
|
|
if should_save and self.training_integration:
|
|
return self.training_integration.save_rl_checkpoint(
|
|
rl_agent=self,
|
|
model_name=self.model_name,
|
|
episode=self.episode_count,
|
|
avg_reward=avg_reward,
|
|
best_reward=self.best_reward,
|
|
epsilon=self.epsilon,
|
|
total_pnl=0.0 # Default to 0, can be set by calling code
|
|
)
|
|
|
|
return False
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error saving DQN checkpoint: {e}")
|
|
return False
|
|
|
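# --- Illustrative episode-loop sketch (added for documentation) ---
# Shows how an external training loop is expected to feed episode rewards into
# save_checkpoint(). The reward values and this helper's name are made up for
# illustration; with checkpoints disabled the call is a no-op that returns False.
def _example_checkpoint_cycle(self, episode_rewards=(0.5, -0.2, 1.3)):
    """Hypothetical sketch: report a few episode rewards to the checkpoint logic."""
    saved_any = False
    for reward in episode_rewards:
        # save_checkpoint updates episode_count, reward_history and best_reward,
        # and persists a checkpoint when the frequency/improvement criteria hit
        saved = self.save_checkpoint(episode_reward=reward)
        saved_any = saved_any or bool(saved)
    return saved_any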
|
def move_models_to_device(self, device=None):
|
|
"""Move models to the specified device (GPU/CPU)"""
|
|
if device is not None:
|
|
self.device = device
|
|
|
|
try:
|
|
self.policy_net = self.policy_net.to(self.device)
|
|
self.target_net = self.target_net.to(self.device)
|
|
logger.info(f"Moved models to {self.device}")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Failed to move models to {self.device}: {str(e)}")
|
|
return False
|
|
|
|
def to(self, device):
|
|
"""PyTorch-style device movement method"""
|
|
self.device = device
|
|
self.policy_net = self.policy_net.to(device)
|
|
self.target_net = self.target_net.to(device)
|
|
return self
|
|
|
|
def remember(self, state: np.ndarray, action: int, reward: float,
|
|
next_state: np.ndarray, done: bool, is_extrema: bool = False):
|
|
"""
|
|
Store experience in memory with prioritization
|
|
|
|
Args:
|
|
state: Current state
|
|
action: Action taken
|
|
reward: Reward received
|
|
next_state: Next state
|
|
done: Whether episode is done
|
|
is_extrema: Whether this is a local extrema sample (for specialized learning)
|
|
"""
|
|
experience = (state, action, reward, next_state, done)
|
|
|
|
# Always add to main memory
|
|
self.memory.append(experience)
|
|
|
|
# Try to extract price change to analyze the experience
|
|
try:
|
|
# Extract price feature from sequence data (if available)
|
|
if len(state.shape) > 1: # 2D state [timeframes, features]
|
|
current_price = state[-1, -1] # Last timeframe, last feature
|
|
next_price = next_state[-1, -1]
|
|
else: # 1D state
|
|
current_price = state[-1] # Last feature
|
|
next_price = next_state[-1]
|
|
|
|
# Calculate price change - avoid division by zero
|
|
if np.isscalar(current_price) and current_price != 0:
|
|
price_change = (next_price - current_price) / current_price
|
|
elif isinstance(current_price, np.ndarray):
|
|
# Handle array case - protect against division by zero
|
|
with np.errstate(divide='ignore', invalid='ignore'):
|
|
price_change = (next_price - current_price) / current_price
|
|
# Replace infinities and NaNs with zeros
|
|
if isinstance(price_change, np.ndarray):
|
|
price_change = np.nan_to_num(price_change, nan=0.0, posinf=0.0, neginf=0.0)
|
|
else:
|
|
price_change = 0.0 if np.isnan(price_change) or np.isinf(price_change) else price_change
|
|
else:
|
|
price_change = 0.0
|
|
|
|
# Check if this is a significant price movement
|
|
if abs(price_change) > 0.002: # Significant price change
|
|
# Store in price movement memory
|
|
self.price_movement_memory.append(experience)
|
|
|
|
# Log significant price movements
|
|
direction = "UP" if price_change > 0 else "DOWN"
|
|
logger.info(f"Stored significant {direction} price movement: {price_change:.4f}")
|
|
|
|
# For clear price movements, also duplicate in main memory to learn more
|
|
if abs(price_change) > 0.005: # Very significant movement
|
|
for _ in range(2): # Add 2 extra copies
|
|
self.memory.append(experience)
|
|
except Exception:
    # Price-movement analysis is best-effort; skip it when the state layout
    # does not expose a usable price feature.
    pass
|
|
|
|
# Check if this is an extrema point based on our extrema detection head
|
|
if hasattr(self, 'last_extrema_pred') and self.last_extrema_pred['class'] != 2:
|
|
# Class 0 = bottom, 1 = top, 2 = neither
|
|
# Only consider high confidence predictions
|
|
if self.last_extrema_pred['confidence'] > 0.7:
|
|
self.extrema_memory.append(experience)
|
|
|
|
# Log this special experience
|
|
extrema_type = "BOTTOM" if self.last_extrema_pred['class'] == 0 else "TOP"
|
|
logger.info(f"Stored {extrema_type} experience with reward {reward:.4f}")
|
|
|
|
# For tops and bottoms, also duplicate the experience in memory to learn more from it
|
|
for _ in range(2): # Add 2 extra copies
|
|
self.memory.append(experience)
|
|
|
|
# Explicitly marked extrema points also go to extrema memory
|
|
elif is_extrema:
|
|
self.extrema_memory.append(experience)
|
|
|
|
# Store positive experiences separately for prioritized replay
|
|
if reward > 0:
|
|
self.positive_memory.append(experience)
|
|
|
|
# For very good rewards, duplicate to learn more from them
|
|
if reward > 0.1:
|
|
for _ in range(min(int(reward * 10), 5)): # Cap at 5 extra copies for very high rewards
|
|
self.positive_memory.append(experience)
|
|
|
|
# Keep memory size under control
|
|
if len(self.memory) > self.buffer_size:
|
|
# Keep more recent experiences
|
|
self.memory = self.memory[-self.buffer_size:]
|
|
|
|
# Keep specialized memories under control too
|
|
if len(self.positive_memory) > self.buffer_size // 4:
|
|
self.positive_memory = self.positive_memory[-(self.buffer_size // 4):]
|
|
|
|
if len(self.extrema_memory) > self.buffer_size // 4:
|
|
self.extrema_memory = self.extrema_memory[-(self.buffer_size // 4):]
|
|
|
|
if len(self.price_movement_memory) > self.buffer_size // 4:
|
|
self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):]
|
|
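# --- Illustrative experience-storage sketch (added for documentation) ---
# Demonstrates the remember() contract with a fabricated pair of 1D states whose
# last feature acts as the price proxy, so the >0.5% move is duplicated in memory
# as described above. All numbers and this helper's name are placeholders.
def _example_store_experience(self):
    """Hypothetical sketch: store one significant-price-move experience."""
    state = np.zeros(int(self.state_size), dtype=np.float32)
    next_state = state.copy()
    state[-1] = 100.0       # treat the last feature as a price proxy
    next_state[-1] = 101.0  # +1% move -> flagged as a significant movement
    before = len(self.memory)
    self.remember(state, action=0, reward=0.05, next_state=next_state, done=False)
    return len(self.memory) - before  # > 1 when the experience was duplicated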
|
|
def act(self, state: np.ndarray, explore=True, current_price=None, market_context=None) -> int:
|
|
"""
|
|
Choose an action for the 2-action system (0=BUY, 1=SELL) using the policy network with epsilon-greedy exploration. Position-aware entry/exit logic lives in _determine_action_with_position_management().
|
|
|
|
Args:
|
|
state: Current market state
|
|
explore: Whether to use epsilon-greedy exploration
|
|
current_price: Current market price for position management
|
|
market_context: Additional market context for decision making
|
|
|
|
Returns:
|
|
int: Action (0=BUY, 1=SELL)
|
|
"""
|
|
try:
|
|
# Use the DQNNetwork's act method for consistent behavior
|
|
action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)
|
|
|
|
# Apply epsilon-greedy exploration if requested
|
|
if explore and np.random.random() <= self.epsilon:
|
|
action_idx = np.random.choice(self.n_actions)
|
|
|
|
# Update tracking
|
|
if current_price:
|
|
self.recent_prices.append(current_price)
|
|
|
|
self.recent_actions.append(action_idx)
|
|
return action_idx
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in act method: {e}")
|
|
# Return default action (SELL, action 1) on error
return 1
|
|
|
|
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
|
|
"""Choose action with confidence score adapted to market regime"""
|
|
try:
|
|
# Use the DQNNetwork's act method which handles the state properly
|
|
action_idx, base_confidence, action_probs = self.policy_net.act(state, explore=False)
|
|
|
|
# Adapt confidence based on market regime
|
|
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
|
|
adapted_confidence = min(base_confidence * regime_weight, 1.0)
|
|
|
|
# Return action, confidence, and probabilities (for orchestrator compatibility)
|
|
return int(action_idx), float(adapted_confidence), action_probs
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in act_with_confidence: {e}")
|
|
# Return default action with low confidence
|
|
return 1, 0.1, [0.45, 0.55]  # Default to SELL (action 1) with low confidence
|
|
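# --- Illustrative regime-confidence sketch (added for documentation) ---
# Shows how the same state yields different confidences per market regime via
# the market_regime_weights multipliers defined in __init__. The zero state and
# this helper's name are placeholders for a real feature vector.
def _example_regime_confidence(self):
    """Hypothetical sketch: compare confidence scaling across regimes."""
    state = np.zeros(int(self.state_size), dtype=np.float32)
    results = {}
    for regime in self.market_regime_weights:
        action, confidence, _ = self.act_with_confidence(state, market_regime=regime)
        results[regime] = (action, confidence)
    return results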
|
|
def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
|
|
"""
|
|
Determine action based on current position and confidence thresholds
|
|
|
|
This implements the intelligent position management where:
|
|
- When neutral: Need high confidence to enter position
|
|
- When in position: Need lower confidence to exit
|
|
- Different thresholds for entry vs exit
|
|
"""
|
|
|
|
# Apply epsilon-greedy exploration
|
|
if explore and np.random.random() <= self.epsilon:
|
|
return np.random.choice([0, 1])
|
|
|
|
# Get the dominant signal - FIXED ACTION MAPPING: 0=BUY, 1=SELL
|
|
dominant_action = 0 if buy_conf > sell_conf else 1
|
|
dominant_confidence = max(buy_conf, sell_conf)
|
|
|
|
# Decision logic based on current position
|
|
if self.current_position == 0: # No position - need high confidence to enter
|
|
if dominant_confidence >= self.entry_confidence_threshold:
|
|
# Strong enough signal to enter position
|
|
if dominant_action == 0: # BUY signal (action 0)
|
|
self.current_position = 1.0
|
|
self.position_entry_price = current_price
|
|
self.position_entry_time = time.time()
|
|
logger.info(f"ENTERING LONG position at {current_price:.4f} with confidence {dominant_confidence:.4f}")
|
|
return 0 # Return BUY action (0)
|
|
else: # SELL signal (action 1)
|
|
self.current_position = -1.0
|
|
self.position_entry_price = current_price
|
|
self.position_entry_time = time.time()
|
|
logger.info(f"ENTERING SHORT position at {current_price:.4f} with confidence {dominant_confidence:.4f}")
|
|
return 1 # Return SELL action (1)
|
|
else:
|
|
# Not confident enough to enter position
|
|
return None
|
|
|
|
elif self.current_position > 0:  # Long position
    if dominant_action == 1 and dominant_confidence >= self.entry_confidence_threshold:
        # Very strong SELL signal - close long and enter short.
        # Checked before the plain exit so this branch is reachable,
        # since entry_confidence_threshold > exit_confidence_threshold.
        pnl = (current_price - self.position_entry_price) / self.position_entry_price if current_price and self.position_entry_price else 0
        logger.info(f"FLIPPING from LONG to SHORT at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
        self.current_position = -1.0
        self.position_entry_price = current_price
        self.position_entry_time = time.time()
        return 1  # Return SELL action (1)
    elif dominant_action == 1 and dominant_confidence >= self.exit_confidence_threshold:
        # SELL signal (action 1) with enough confidence to close the long position
        pnl = (current_price - self.position_entry_price) / self.position_entry_price if current_price and self.position_entry_price else 0
        logger.info(f"CLOSING LONG position at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
        self.current_position = 0.0
        self.position_entry_price = 0.0
        self.position_entry_time = None
        return 1  # Return SELL action (1)
    else:
        # Hold the long position
        return None

elif self.current_position < 0:  # Short position
    if dominant_action == 0 and dominant_confidence >= self.entry_confidence_threshold:
        # Very strong BUY signal - close short and enter long
        pnl = (self.position_entry_price - current_price) / self.position_entry_price if current_price and self.position_entry_price else 0
        logger.info(f"FLIPPING from SHORT to LONG at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
        self.current_position = 1.0
        self.position_entry_price = current_price
        self.position_entry_time = time.time()
        return 0  # Return BUY action (0)
    elif dominant_action == 0 and dominant_confidence >= self.exit_confidence_threshold:
        # BUY signal (action 0) with enough confidence to close the short position
        pnl = (self.position_entry_price - current_price) / self.position_entry_price if current_price and self.position_entry_price else 0
        logger.info(f"CLOSING SHORT position at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
        self.current_position = 0.0
        self.position_entry_price = 0.0
        self.position_entry_time = None
        return 0  # Return BUY action (0)
    else:
        # Hold the short position
        return None
|
|
|
|
return None
|
|
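# --- Illustrative position-management sketch (added for documentation) ---
# Walks the flat -> long -> flat transitions of the helper above using made-up
# confidences and prices; the position state is saved and restored so the sketch
# has no lasting side effects. The helper name and numbers are placeholders.
def _example_position_transitions(self):
    """Hypothetical sketch: drive the position-management state machine."""
    saved = (self.current_position, self.position_entry_price, self.position_entry_time)
    try:
        self.current_position = 0.0
        # Strong BUY while flat -> enter long (returns 0)
        enter = self._determine_action_with_position_management(
            sell_conf=0.1, buy_conf=0.6, current_price=100.0,
            market_context=None, explore=False)
        # Moderate SELL while long -> close the long (returns 1)
        close = self._determine_action_with_position_management(
            sell_conf=0.2, buy_conf=0.05, current_price=101.0,
            market_context=None, explore=False)
        return enter, close
    finally:
        self.current_position, self.position_entry_price, self.position_entry_time = saved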
|
|
def _safe_cnn_forward(self, network, states):
|
|
"""Safely call CNN forward method ensuring we always get 5 return values"""
|
|
try:
|
|
result = network(states)
|
|
if isinstance(result, tuple) and len(result) == 5:
    return result
elif isinstance(result, torch.Tensor) or (isinstance(result, tuple) and len(result) == 1):
    # Handle networks (like DQNNetwork above) that return only Q-values,
    # either as a bare tensor or a 1-tuple; fill the auxiliary heads with
    # zero placeholders so callers can always unpack five values.
    q_values = result[0] if isinstance(result, tuple) else result
    batch_size = q_values.size(0)
    device = q_values.device
    default_extrema = torch.zeros(batch_size, 3, device=device)
    default_price = torch.zeros(batch_size, 1, device=device)
    default_features = torch.zeros(batch_size, 1024, device=device)
    default_advanced = torch.zeros(batch_size, 1, device=device)
    return q_values, default_extrema, default_price, default_features, default_advanced
|
|
else:
|
|
# Fallback: create all default tensors
|
|
batch_size = states.size(0)
|
|
device = states.device
|
|
default_q_values = torch.zeros(batch_size, self.n_actions, device=device)
|
|
default_extrema = torch.zeros(batch_size, 3, device=device)
|
|
default_price = torch.zeros(batch_size, 1, device=device)
|
|
default_features = torch.zeros(batch_size, 1024, device=device)
|
|
default_advanced = torch.zeros(batch_size, 1, device=device)
|
|
return default_q_values, default_extrema, default_price, default_features, default_advanced
|
|
except Exception as e:
|
|
logger.error(f"Error in CNN forward pass: {e}")
|
|
# Fallback: create all default tensors
|
|
batch_size = states.size(0)
|
|
device = states.device
|
|
default_q_values = torch.zeros(batch_size, self.n_actions, device=device)
|
|
default_extrema = torch.zeros(batch_size, 3, device=device)
|
|
default_price = torch.zeros(batch_size, 1, device=device)
|
|
default_features = torch.zeros(batch_size, 1024, device=device)
|
|
default_advanced = torch.zeros(batch_size, 1, device=device)
|
|
return default_q_values, default_extrema, default_price, default_features, default_advanced
|
|
|
|
def replay(self, experiences=None):
|
|
"""Train the model using experiences from memory"""
|
|
|
|
# Don't train if not in training mode
|
|
if not self.training:
|
|
return 0.0
|
|
|
|
# If no experiences provided, sample from memory
|
|
if experiences is None:
|
|
# Skip if memory is too small
|
|
if len(self.memory) < self.batch_size:
|
|
return 0.0
|
|
|
|
# Sample random mini-batch from memory
|
|
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
|
|
experiences = [self.memory[i] for i in indices]
|
|
|
|
# Validate experiences before processing
|
|
if not experiences or len(experiences) == 0:
|
|
logger.warning("No experiences provided for training")
|
|
return 0.0
|
|
|
|
# Sanitize and validate experiences
|
|
valid_experiences = []
|
|
for i, exp in enumerate(experiences):
|
|
try:
|
|
if len(exp) != 5:
|
|
logger.debug(f"Invalid experience format at index {i}: expected 5 elements, got {len(exp)}")
|
|
continue
|
|
|
|
state, action, reward, next_state, done = exp
|
|
|
|
# Validate state
|
|
state = self._validate_and_fix_state(state)
|
|
next_state = self._validate_and_fix_state(next_state)
|
|
|
|
if state is None or next_state is None:
|
|
continue
|
|
|
|
# Validate action
|
|
if isinstance(action, dict):
|
|
action = action.get('action', action.get('value', 0))
|
|
action = int(action) if action is not None else 0
|
|
action = max(0, min(action, self.n_actions - 1)) # Clamp to valid range
|
|
|
|
# Validate reward
|
|
if isinstance(reward, dict):
|
|
reward = reward.get('reward', reward.get('value', 0.0))
|
|
reward = float(reward) if reward is not None else 0.0
|
|
|
|
# Validate done flag
|
|
done = bool(done) if done is not None else False
|
|
|
|
valid_experiences.append((state, action, reward, next_state, done))
|
|
|
|
except Exception as e:
|
|
logger.debug(f"Error processing experience {i}: {e}")
|
|
continue
|
|
|
|
if len(valid_experiences) == 0:
|
|
logger.warning("No valid experiences after sanitization")
|
|
return 0.0
|
|
|
|
# Use validated experiences for training
|
|
experiences = valid_experiences
|
|
|
|
# Extract components
|
|
states, actions, rewards, next_states, dones = zip(*experiences)
|
|
|
|
# Convert to tensors with proper validation
|
|
try:
|
|
states = torch.FloatTensor(np.array(states)).to(self.device)
|
|
actions = torch.LongTensor(np.array(actions)).to(self.device)
|
|
rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
|
|
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
|
|
dones = torch.FloatTensor(np.array(dones)).to(self.device)
|
|
|
|
# Final validation of tensor shapes
|
|
if states.shape[0] == 0 or actions.shape[0] == 0:
|
|
logger.warning("Empty tensors after conversion")
|
|
return 0.0
|
|
|
|
# Ensure all tensors have the same batch size
|
|
batch_size = states.shape[0]
|
|
if not all(tensor.shape[0] == batch_size for tensor in [actions, rewards, next_states, dones]):
|
|
logger.warning("Inconsistent batch sizes across tensors")
|
|
return 0.0
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error converting experiences to tensors: {e}")
|
|
return 0.0
|
|
|
|
# Choose training method based on precision mode
|
|
if self.use_mixed_precision:
|
|
loss = self._replay_mixed_precision(states, actions, rewards, next_states, dones)
|
|
else:
|
|
loss = self._replay_standard(states, actions, rewards, next_states, dones)
|
|
|
|
# Update epsilon
|
|
if self.epsilon > self.epsilon_min:
|
|
self.epsilon *= self.epsilon_decay
|
|
|
|
# Update statistics
|
|
self.losses.append(loss)
|
|
if len(self.losses) > 1000:
|
|
self.losses = self.losses[-500:] # Keep only recent losses
|
|
|
|
return loss
|
|
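# --- Illustrative replay-training sketch (added for documentation) ---
# Builds a small throwaway agent, fills its buffer with random transitions and
# runs one replay() step. The 64-feature state size, batch size, seed and random
# data are placeholders chosen only to keep the sketch lightweight; it assumes
# _safe_cnn_forward handles the plain Q-value tensor returned by DQNNetwork.
def _example_replay_cycle():
    """Hypothetical sketch: one end-to-end experience-replay training step."""
    agent = DQNAgent(state_shape=(64,), n_actions=2, batch_size=8,
                     buffer_size=256, enable_checkpoints=False)
    rng = np.random.default_rng(0)
    for _ in range(16):
        state = rng.normal(size=64).astype(np.float32)
        next_state = rng.normal(size=64).astype(np.float32)
        action = int(rng.integers(0, 2))
        reward = float(rng.normal(scale=0.01))
        agent.remember(state, action, reward, next_state, done=False)
    loss = agent.replay()  # samples a mini-batch and runs one gradient step
    return loss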
|
|
def _validate_and_fix_state(self, state):
|
|
"""Validate and fix state to ensure it has correct dimensions and no empty data"""
|
|
try:
|
|
# Convert to numpy if needed
|
|
if isinstance(state, torch.Tensor):
|
|
state = state.detach().cpu().numpy()
|
|
elif not isinstance(state, np.ndarray):
|
|
state = np.array(state, dtype=np.float32)
|
|
|
|
# Flatten if multi-dimensional
|
|
if state.ndim > 1:
|
|
state = state.flatten()
|
|
|
|
# Check for empty or invalid state
|
|
if state.size == 0:
|
|
logger.warning("Empty state detected, using default")
|
|
expected_size = getattr(self, 'state_size', 403)
|
|
if isinstance(expected_size, tuple):
|
|
expected_size = np.prod(expected_size)
|
|
return np.zeros(int(expected_size), dtype=np.float32)
|
|
|
|
# Check for NaN or infinite values
|
|
if np.any(np.isnan(state)) or np.any(np.isinf(state)):
|
|
logger.warning("NaN or infinite values in state, replacing with zeros")
|
|
state = np.nan_to_num(state, nan=0.0, posinf=1.0, neginf=-1.0)
|
|
|
|
# Ensure correct dimensions
|
|
expected_size = getattr(self, 'state_size', 403)
|
|
if isinstance(expected_size, tuple):
|
|
expected_size = np.prod(expected_size)
|
|
expected_size = int(expected_size)
|
|
|
|
if len(state) != expected_size:
|
|
if len(state) < expected_size:
|
|
# Pad with zeros
|
|
padded_state = np.zeros(expected_size, dtype=np.float32)
|
|
padded_state[:len(state)] = state
|
|
state = padded_state
|
|
else:
|
|
# Truncate
|
|
state = state[:expected_size]
|
|
|
|
return state.astype(np.float32)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error validating state: {e}")
|
|
# Return default state as fallback
|
|
expected_size = getattr(self, 'state_size', 403)
|
|
if isinstance(expected_size, tuple):
|
|
expected_size = np.prod(expected_size)
|
|
return np.zeros(int(expected_size), dtype=np.float32)
|
|
|
|
def _replay_standard(self, states, actions, rewards, next_states, dones):
|
|
"""Standard training step without mixed precision"""
|
|
try:
|
|
# Validate input tensors
|
|
if states.shape[0] == 0:
|
|
logger.warning("Empty batch in _replay_standard")
|
|
return 0.0
|
|
|
|
# Get current Q values using safe wrapper
|
|
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
|
|
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
|
|
|
# Enhanced Double DQN implementation
|
|
with torch.no_grad():
|
|
if self.use_double_dqn:
|
|
# Double DQN: Use policy network to select actions, target network to evaluate
|
|
policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
|
|
next_actions = policy_q_values.argmax(1)
|
|
target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
|
|
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
|
|
else:
|
|
# Standard DQN: Use target network for both selection and evaluation
|
|
next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
|
|
next_q_values = next_q_values.max(1)[0]
|
|
|
|
# Ensure tensor shapes are consistent
|
|
batch_size = states.shape[0]
|
|
if rewards.shape[0] != batch_size or next_q_values.shape[0] != batch_size:
|
|
logger.warning(f"Shape mismatch in replay: batch_size={batch_size}, rewards={rewards.shape}, next_q_values={next_q_values.shape}")
|
|
min_size = min(batch_size, rewards.shape[0], next_q_values.shape[0])
|
|
rewards = rewards[:min_size]
|
|
dones = dones[:min_size]
|
|
next_q_values = next_q_values[:min_size]
|
|
current_q_values = current_q_values[:min_size]
|
|
|
|
# Calculate target Q values
|
|
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
|
|
|
|
# Compute loss for Q value - ensure tensors require gradients
|
|
if not current_q_values.requires_grad:
|
|
logger.warning("Current Q values do not require gradients")
|
|
return 0.0
|
|
|
|
q_loss = self.criterion(current_q_values, target_q_values.detach())
|
|
|
|
# Initialize total loss with Q loss
|
|
total_loss = q_loss
|
|
|
|
# Add auxiliary losses if available and valid
|
|
try:
|
|
if current_extrema_pred is not None and current_extrema_pred.shape[0] > 0:
|
|
# Create simple extrema targets based on Q-values
|
|
with torch.no_grad():
|
|
extrema_targets = torch.ones(current_extrema_pred.shape[0], dtype=torch.long, device=current_extrema_pred.device) * 2 # Default to "neither"
|
|
|
|
extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)
|
|
total_loss = total_loss + 0.1 * extrema_loss
|
|
|
|
except Exception as e:
|
|
logger.debug(f"Could not calculate auxiliary loss: {e}")
|
|
|
|
# Reset gradients
|
|
self.optimizer.zero_grad()
|
|
|
|
# Ensure total loss requires gradients
|
|
if not total_loss.requires_grad:
|
|
logger.warning("Total loss does not require gradients - policy network may not be in training mode")
|
|
self.policy_net.train() # Ensure training mode
|
|
return 0.0
|
|
|
|
# Backward pass
|
|
total_loss.backward()
|
|
|
|
# Gradient clipping
|
|
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=self.gradient_clip_norm)
|
|
|
|
# Check if gradients are valid
|
|
has_valid_gradients = False
|
|
for param in self.policy_net.parameters():
|
|
if param.grad is not None and torch.any(torch.isfinite(param.grad)):
|
|
has_valid_gradients = True
|
|
break
|
|
|
|
if not has_valid_gradients:
|
|
logger.warning("No valid gradients found, skipping optimizer step")
|
|
return 0.0
|
|
|
|
# Update weights
|
|
self.optimizer.step()
|
|
|
|
# Update target network periodically
|
|
self.training_steps += 1
|
|
if self.training_steps % self.target_update_freq == 0:
|
|
self.target_net.load_state_dict(self.policy_net.state_dict())
|
|
logger.debug(f"Target network updated at step {self.training_steps}")
|
|
|
|
return total_loss.item()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in standard replay: {e}")
|
|
return 0.0
|
|
|
|
def _replay_mixed_precision(self, states, actions, rewards, next_states, dones):
|
|
"""Mixed precision training step"""
|
|
if not self.use_mixed_precision:
|
|
logger.warning("Mixed precision not available, falling back to standard replay")
|
|
return self._replay_standard(states, actions, rewards, next_states, dones)
|
|
|
|
try:
|
|
# Validate input tensors
|
|
if states.shape[0] == 0:
|
|
logger.warning("Empty batch in _replay_mixed_precision")
|
|
return 0.0
|
|
|
|
# Zero gradients
|
|
self.optimizer.zero_grad()
|
|
|
|
# Forward pass with amp autocasting
|
|
import warnings
|
|
with warnings.catch_warnings():
|
|
warnings.simplefilter("ignore", FutureWarning)
|
|
with torch.cuda.amp.autocast():
|
|
# Get current Q values and predictions
|
|
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
|
|
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
|
|
|
# Get next Q values from target network
|
|
with torch.no_grad():
|
|
if self.use_double_dqn:
|
|
# Double DQN
|
|
policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
|
|
next_actions = policy_q_values.argmax(1)
|
|
target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
|
|
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
|
|
else:
|
|
# Standard DQN
|
|
next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
|
|
next_q_values = next_q_values.max(1)[0]
|
|
|
|
# Ensure consistent shapes
|
|
batch_size = states.shape[0]
|
|
if rewards.shape[0] != batch_size or next_q_values.shape[0] != batch_size:
|
|
logger.warning(f"Shape mismatch in mixed precision replay")
|
|
min_size = min(batch_size, rewards.shape[0], next_q_values.shape[0])
|
|
rewards = rewards[:min_size]
|
|
dones = dones[:min_size]
|
|
next_q_values = next_q_values[:min_size]
|
|
current_q_values = current_q_values[:min_size]
|
|
|
|
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
|
|
|
|
# Compute Q-value loss (primary task)
|
|
q_loss = nn.MSELoss()(current_q_values, target_q_values.detach())
|
|
|
|
# Initialize loss with q_loss
|
|
loss = q_loss
|
|
|
|
# Add auxiliary losses if available
|
|
try:
|
|
if current_extrema_pred is not None and current_extrema_pred.shape[0] > 0:
|
|
# Simple extrema targets
|
|
with torch.no_grad():
|
|
extrema_targets = torch.ones(current_extrema_pred.shape[0], dtype=torch.long, device=current_extrema_pred.device) * 2
|
|
|
|
extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)
|
|
loss = loss + 0.1 * extrema_loss
|
|
|
|
except Exception as e:
|
|
logger.debug(f"Could not add auxiliary loss in mixed precision: {e}")
|
|
|
|
# Check if loss requires gradients
|
|
if not loss.requires_grad:
|
|
logger.warning("Loss does not require gradients in mixed precision training")
|
|
return 0.0
|
|
|
|
# Scale and backward pass
|
|
self.scaler.scale(loss).backward()
|
|
|
|
# Unscale gradients and clip
|
|
self.scaler.unscale_(self.optimizer)
|
|
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=self.gradient_clip_norm)
|
|
|
|
# Check for valid gradients
|
|
has_valid_gradients = False
|
|
for param in self.policy_net.parameters():
|
|
if param.grad is not None and torch.any(torch.isfinite(param.grad)):
|
|
has_valid_gradients = True
|
|
break
|
|
|
|
if not has_valid_gradients:
|
|
logger.warning("No valid gradients in mixed precision training")
|
|
self.scaler.update() # Still update scaler
|
|
return 0.0
|
|
|
|
# Optimizer step with scaler
|
|
self.scaler.step(self.optimizer)
|
|
self.scaler.update()
|
|
|
|
# Update target network
|
|
self.training_steps += 1
|
|
if self.training_steps % self.target_update_freq == 0:
|
|
self.target_net.load_state_dict(self.policy_net.state_dict())
|
|
logger.debug(f"Target network updated at step {self.training_steps}")
|
|
|
|
return loss.item()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in mixed precision replay: {e}")
|
|
return 0.0
|
|
|
|
def train_on_extrema(self, states, actions, rewards, next_states, dones):
|
|
"""
|
|
Special training function specifically for extrema points
|
|
|
|
Args:
|
|
states: Batch of states at extrema points
|
|
actions: Batch of actions
|
|
rewards: Batch of rewards
|
|
next_states: Batch of next states
|
|
dones: Batch of done flags
|
|
|
|
Returns:
|
|
float: Training loss
|
|
"""
|
|
# Convert to numpy arrays if not already
|
|
if not isinstance(states, np.ndarray):
|
|
states = np.array(states)
|
|
if not isinstance(actions, np.ndarray):
|
|
actions = np.array(actions)
|
|
if not isinstance(rewards, np.ndarray):
|
|
rewards = np.array(rewards)
|
|
if not isinstance(next_states, np.ndarray):
|
|
next_states = np.array(next_states)
|
|
if not isinstance(dones, np.ndarray):
|
|
dones = np.array(dones, dtype=np.float32)
|
|
|
|
# Normalize states
|
|
states = np.vstack([self._normalize_state(s) for s in states])
|
|
next_states = np.vstack([self._normalize_state(s) for s in next_states])
|
|
|
|
# Convert to torch tensors and move to device
|
|
states_tensor = torch.FloatTensor(states).to(self.device)
|
|
actions_tensor = torch.LongTensor(actions).to(self.device)
|
|
rewards_tensor = torch.FloatTensor(rewards).to(self.device)
|
|
next_states_tensor = torch.FloatTensor(next_states).to(self.device)
|
|
dones_tensor = torch.FloatTensor(dones).to(self.device)
|
|
|
|
# Choose training method based on precision mode
|
|
if self.use_mixed_precision:
|
|
return self._replay_mixed_precision(
|
|
states_tensor, actions_tensor, rewards_tensor,
|
|
next_states_tensor, dones_tensor
|
|
)
|
|
else:
|
|
return self._replay_standard(
|
|
states_tensor, actions_tensor, rewards_tensor,
|
|
next_states_tensor, dones_tensor
|
|
)
|
|
|
|
def _normalize_state(self, state: np.ndarray) -> np.ndarray:
|
|
"""Normalize the state data to prevent numerical issues"""
|
|
# Handle NaN and infinite values
|
|
state = np.nan_to_num(state, nan=0.0, posinf=1.0, neginf=-1.0)
|
|
|
|
# Check if state is 1D array (happens in some environments)
|
|
if len(state.shape) == 1:
|
|
# If 1D, we need to normalize the whole array
|
|
normalized_state = state.copy()
|
|
|
|
# Convert any timestamp or non-numeric data to float
|
|
for i in range(len(normalized_state)):
|
|
# Check for timestamp-like objects
|
|
if hasattr(normalized_state[i], 'timestamp') and callable(getattr(normalized_state[i], 'timestamp')):
|
|
# Convert timestamp to float (seconds since epoch)
|
|
normalized_state[i] = float(normalized_state[i].timestamp())
|
|
elif not isinstance(normalized_state[i], (int, float, np.number)):
|
|
# Set non-numeric data to 0
|
|
normalized_state[i] = 0.0
|
|
|
|
# Ensure all values are float
|
|
normalized_state = normalized_state.astype(np.float32)
|
|
|
|
# Simple min-max normalization for 1D state
|
|
state_min = np.min(normalized_state)
|
|
state_max = np.max(normalized_state)
|
|
if state_max > state_min:
|
|
normalized_state = (normalized_state - state_min) / (state_max - state_min)
|
|
return normalized_state
|
|
|
|
# Handle 2D arrays
|
|
normalized_state = np.zeros_like(state, dtype=np.float32)
|
|
|
|
# Convert any timestamp or non-numeric data to float
|
|
for i in range(state.shape[0]):
|
|
for j in range(state.shape[1]):
|
|
if hasattr(state[i, j], 'timestamp') and callable(getattr(state[i, j], 'timestamp')):
|
|
# Convert timestamp to float (seconds since epoch)
|
|
normalized_state[i, j] = float(state[i, j].timestamp())
|
|
elif isinstance(state[i, j], (int, float, np.number)):
|
|
normalized_state[i, j] = state[i, j]
|
|
else:
|
|
# Set non-numeric data to 0
|
|
normalized_state[i, j] = 0.0
|
|
|
|
# Loop through each timeframe's features in the combined state
|
|
feature_count = state.shape[1] // len(self.timeframes)
|
|
|
|
for tf_idx in range(len(self.timeframes)):
|
|
start_idx = tf_idx * feature_count
|
|
end_idx = start_idx + feature_count
|
|
|
|
# Extract this timeframe's features
|
|
tf_features = normalized_state[:, start_idx:end_idx]
|
|
|
|
# Normalize OHLCV data by the first close price in the window
|
|
# This makes price movements relative rather than absolute
|
|
price_idx = 3 # Assuming close price is at index 3
|
|
if price_idx < tf_features.shape[1]:
|
|
reference_price = np.mean(tf_features[:, price_idx])
|
|
if reference_price != 0:
|
|
# Normalize price-related columns (OHLC)
|
|
for i in range(4): # First 4 columns are OHLC
|
|
if i < tf_features.shape[1]:
|
|
normalized_state[:, start_idx + i] = tf_features[:, i] / reference_price
|
|
|
|
# Normalize volume using mean and std
|
|
vol_idx = 4 # Assuming volume is at index 4
|
|
if vol_idx < tf_features.shape[1]:
|
|
vol_mean = np.mean(tf_features[:, vol_idx])
|
|
vol_std = np.std(tf_features[:, vol_idx])
|
|
if vol_std > 0:
|
|
normalized_state[:, start_idx + vol_idx] = (tf_features[:, vol_idx] - vol_mean) / vol_std
|
|
else:
|
|
normalized_state[:, start_idx + vol_idx] = 0
|
|
|
|
# Other features (technical indicators) - normalize with min-max scaling
|
|
for i in range(5, feature_count):
|
|
if i < tf_features.shape[1]:
|
|
feature_min = np.min(tf_features[:, i])
|
|
feature_max = np.max(tf_features[:, i])
|
|
if feature_max > feature_min:
|
|
normalized_state[:, start_idx + i] = (tf_features[:, i] - feature_min) / (feature_max - feature_min)
|
|
else:
|
|
normalized_state[:, start_idx + i] = 0
|
|
|
|
return normalized_state
|
|
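# --- Illustrative normalization sketch (added for documentation) ---
# Feeds a fabricated [window, features] block through _normalize_state to show
# the per-timeframe OHLCV scaling described above. The 3-timeframe x 5-feature
# layout mirrors self.timeframes, but the prices/volumes are made up.
def _example_normalize_window(self):
    """Hypothetical sketch: normalize a small fake OHLCV window."""
    window, features_per_tf = 10, 5  # OHLCV per timeframe
    n_features = features_per_tf * len(self.timeframes)
    rng = np.random.default_rng(1)
    raw = 100.0 + rng.normal(scale=1.0, size=(window, n_features)).astype(np.float32)
    normalized = self._normalize_state(raw)
    return normalized.shape  # same shape, OHLC rescaled around ~1.0, volume standardized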
|
|
def update_realtime_tick_features(self, tick_features):
|
|
"""Update with real-time tick features from tick processor"""
|
|
try:
|
|
if tick_features is not None:
|
|
self.realtime_tick_features = tick_features
|
|
|
|
# Log high-confidence tick features
|
|
if tick_features.get('confidence', 0) > 0.8:
|
|
logger.debug(f"High-confidence tick features updated: confidence={tick_features['confidence']:.3f}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error updating real-time tick features: {e}")
|
|
|
|
def _enhance_state_with_tick_features(self, state: np.ndarray) -> np.ndarray:
|
|
"""Enhance state with real-time tick features if available"""
|
|
try:
|
|
if self.realtime_tick_features is None:
|
|
return state
|
|
|
|
# Extract neural features from tick processor
|
|
neural_features = self.realtime_tick_features.get('neural_features', np.array([]))
|
|
volume_features = self.realtime_tick_features.get('volume_features', np.array([]))
|
|
microstructure_features = self.realtime_tick_features.get('microstructure_features', np.array([]))
|
|
confidence = self.realtime_tick_features.get('confidence', 0.0)
|
|
|
|
# Combine tick features - make them compact to match state dimensions
|
|
tick_features = np.concatenate([
|
|
neural_features[:3] if len(neural_features) >= 3 else np.zeros(3), # Take first 3 neural features
|
|
volume_features[:1] if len(volume_features) >= 1 else np.zeros(1), # Take first volume feature
|
|
microstructure_features[:1] if len(microstructure_features) >= 1 else np.zeros(1), # Take first microstructure feature
|
|
])
|
|
|
|
# Weight the tick features
|
|
weighted_tick_features = tick_features * self.tick_feature_weight
|
|
|
|
# Enhance the state by adding tick features to each timeframe
|
|
if len(state.shape) == 1:
|
|
# 1D state - append tick features
|
|
enhanced_state = np.concatenate([state, weighted_tick_features])
|
|
else:
|
|
# 2D state - add tick features to each timeframe row
|
|
num_timeframes, num_features = state.shape
|
|
|
|
# Ensure tick features match the number of original features
|
|
if len(weighted_tick_features) != num_features:
|
|
# Pad or truncate tick features to match state feature dimension
|
|
if len(weighted_tick_features) < num_features:
|
|
# Pad with zeros
|
|
padded_features = np.zeros(num_features)
|
|
padded_features[:len(weighted_tick_features)] = weighted_tick_features
|
|
weighted_tick_features = padded_features
|
|
else:
|
|
# Truncate to match
|
|
weighted_tick_features = weighted_tick_features[:num_features]
|
|
|
|
# Add tick features to the last row (most recent timeframe)
|
|
enhanced_state = state.copy()
|
|
enhanced_state[-1, :] += weighted_tick_features # Add to last timeframe
|
|
|
|
return enhanced_state
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error enhancing state with tick features: {e}")
|
|
return state
|
|
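# --- Illustrative tick-feature sketch (added for documentation) ---
# Shows the dict layout expected from the tick processor and how a 1D state
# grows by the 5 weighted tick features. The feature values, confidence and
# helper name are placeholders; the previous tick snapshot is restored afterwards.
def _example_tick_feature_enhancement(self):
    """Hypothetical sketch: enhance a 1D state with fake tick features."""
    saved = self.realtime_tick_features
    try:
        self.update_realtime_tick_features({
            'neural_features': np.array([0.1, -0.2, 0.3], dtype=np.float32),
            'volume_features': np.array([0.5], dtype=np.float32),
            'microstructure_features': np.array([-0.1], dtype=np.float32),
            'confidence': 0.9,
        })
        state = np.zeros(int(self.state_size), dtype=np.float32)
        enhanced = self._enhance_state_with_tick_features(state)
        return enhanced.shape[0] - state.shape[0]  # 5 extra weighted features
    finally:
        self.realtime_tick_features = saved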
|
|
def update_learning_metrics(self, episode_reward, best_reward_threshold=0.01):
|
|
"""Update learning metrics and perform learning rate adjustments if needed"""
|
|
# Update average reward with exponential moving average
|
|
if self.avg_reward == 0:
|
|
self.avg_reward = episode_reward
|
|
else:
|
|
self.avg_reward = 0.95 * self.avg_reward + 0.05 * episode_reward
|
|
|
|
# Check if we're making sufficient progress
|
|
if episode_reward > (1 + best_reward_threshold) * self.best_reward:
|
|
self.best_reward = episode_reward
|
|
self.no_improvement_count = 0
|
|
return True # Improved
|
|
else:
|
|
self.no_improvement_count += 1
|
|
|
|
# If no improvement for a while, adjust learning rate
|
|
if self.no_improvement_count >= 10:
|
|
current_lr = self.optimizer.param_groups[0]['lr']
|
|
new_lr = current_lr * 0.5
|
|
if new_lr >= 1e-6: # Don't reduce below minimum threshold
|
|
for param_group in self.optimizer.param_groups:
|
|
param_group['lr'] = new_lr
|
|
logger.info(f"Reducing learning rate from {current_lr} to {new_lr}")
|
|
self.no_improvement_count = 0
|
|
|
|
return False # No improvement
|
|
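# --- Illustrative reward-EMA sketch (added for documentation) ---
# Reproduces the exponential moving average used by update_learning_metrics
# (avg <- 0.95 * avg + 0.05 * reward) on a made-up reward sequence, outside the
# agent so nothing is mutated. Helper name and rewards are placeholders.
def _example_reward_ema(rewards=(1.0, 0.0, 0.5, 2.0)):
    """Hypothetical sketch: the 0.95/0.05 reward EMA in isolation."""
    avg = 0.0
    for reward in rewards:
        avg = reward if avg == 0 else 0.95 * avg + 0.05 * reward
    return avg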
|
|
def save(self, path: str):
|
|
"""Save model and agent state"""
|
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
|
|
|
# Save policy and target network weights as plain state dicts
torch.save(self.policy_net.state_dict(), f"{path}_policy.pt")
torch.save(self.target_net.state_dict(), f"{path}_target.pt")
|
|
|
|
# Save agent state
|
|
state = {
|
|
'epsilon': self.epsilon,
|
|
'update_count': self.update_count,
|
|
'losses': self.losses,
|
|
'optimizer_state': self.optimizer.state_dict(),
|
|
'best_reward': self.best_reward,
|
|
'avg_reward': self.avg_reward
|
|
}
|
|
|
|
torch.save(state, f"{path}_agent_state.pt")
|
|
logger.info(f"Agent state saved to {path}_agent_state.pt")
|
|
|
|
def load(self, path: str):
|
|
"""Load model and agent state"""
|
|
# Load policy and target network weights saved by save()
self.policy_net.load_state_dict(torch.load(f"{path}_policy.pt", map_location=self.device))
self.target_net.load_state_dict(torch.load(f"{path}_target.pt", map_location=self.device))
|
|
|
|
# Load agent state
|
|
try:
|
|
agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False)
|
|
self.epsilon = agent_state['epsilon']
|
|
self.update_count = agent_state['update_count']
|
|
self.losses = agent_state['losses']
|
|
self.optimizer.load_state_dict(agent_state['optimizer_state'])
|
|
|
|
# Load additional metrics if they exist
|
|
if 'best_reward' in agent_state:
|
|
self.best_reward = agent_state['best_reward']
|
|
if 'avg_reward' in agent_state:
|
|
self.avg_reward = agent_state['avg_reward']
|
|
|
|
logger.info(f"Agent state loaded from {path}_agent_state.pt")
|
|
except FileNotFoundError:
|
|
logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
|
|
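# --- Illustrative persistence sketch (added for documentation) ---
# Round-trips the agent through save()/load() in a temporary directory, assuming
# the state-dict based save()/load() above. Paths and the helper name are throwaway.
def _example_save_and_reload(self):
    """Hypothetical sketch: save the agent to a temp dir and load it back."""
    import tempfile
    with tempfile.TemporaryDirectory() as tmp_dir:
        base_path = os.path.join(tmp_dir, "dqn_demo")
        self.save(base_path)   # writes *_policy.pt, *_target.pt, *_agent_state.pt
        self.load(base_path)   # restores weights, optimizer state and epsilon
    return self.epsilon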
|
|
def get_position_info(self):
|
|
"""Get current position information"""
|
|
return {
|
|
'position': self.current_position,
|
|
'entry_price': self.position_entry_price,
|
|
'entry_time': self.position_entry_time,
|
|
'entry_threshold': self.entry_confidence_threshold,
|
|
'exit_threshold': self.exit_confidence_threshold
|
|
}
|
|
|
|
def get_enhanced_training_stats(self):
|
|
"""Get enhanced RL training statistics with detailed metrics (from EnhancedDQNAgent)"""
|
|
return {
|
|
'buffer_size': len(self.memory),
|
|
'epsilon': self.epsilon,
|
|
'avg_reward': self.avg_reward,
|
|
'best_reward': self.best_reward,
|
|
'recent_rewards': list(self.recent_rewards) if hasattr(self, 'recent_rewards') else [],
|
|
'no_improvement_count': self.no_improvement_count,
|
|
# Enhanced statistics from EnhancedDQNAgent
|
|
'training_steps': self.training_steps,
|
|
'avg_td_error': np.mean(self.td_errors[-100:]) if self.td_errors else 0.0,
|
|
'recent_losses': self.losses[-10:] if self.losses else [],
|
|
'epsilon_trend': self.epsilon_history[-20:] if self.epsilon_history else [],
|
|
'specialized_buffers': {
|
|
'extrema_memory': len(self.extrema_memory),
|
|
'positive_memory': len(self.positive_memory),
|
|
'price_movement_memory': len(self.price_movement_memory)
|
|
},
|
|
'market_regime_weights': self.market_regime_weights,
|
|
'use_double_dqn': self.use_double_dqn,
|
|
'use_prioritized_replay': self.use_prioritized_replay,
|
|
'gradient_clip_norm': self.gradient_clip_norm,
|
|
'target_update_frequency': self.target_update_freq
|
|
}
|
|
|
|
def get_params_count(self):
|
|
"""Get total number of parameters in the DQN model"""
|
|
total_params = 0
|
|
for param in self.policy_net.parameters():
|
|
total_params += param.numel()
|
|
return total_params
|
|
|
|
def _sanitize_state_data(self, state):
|
|
"""Sanitize state data to ensure it's a proper numeric array"""
|
|
try:
|
|
# If state is already a numpy array, return it
|
|
if isinstance(state, np.ndarray):
|
|
# Check for empty array
|
|
if state.size == 0:
|
|
logger.warning("Received empty numpy array state. Using fallback dimensions.")
|
|
expected_size = getattr(self, 'state_size', getattr(self, 'state_dim', 403))
|
|
if isinstance(expected_size, tuple):
|
|
expected_size = np.prod(expected_size)
|
|
return np.zeros(int(expected_size), dtype=np.float32)
|
|
|
|
# Check for non-numeric data and handle it
|
|
if state.dtype == object:
|
|
# Convert object array to float array
|
|
sanitized = np.zeros_like(state, dtype=np.float32)
|
|
for i in range(state.shape[0]):
|
|
if len(state.shape) > 1:
|
|
for j in range(state.shape[1]):
|
|
sanitized[i, j] = self._extract_numeric_value(state[i, j])
|
|
else:
|
|
sanitized[i] = self._extract_numeric_value(state[i])
|
|
return sanitized
|
|
else:
|
|
return state.astype(np.float32)
|
|
|
|
# If state is a list or tuple, convert to array
|
|
elif isinstance(state, (list, tuple)):
|
|
# Check for empty list/tuple
|
|
if len(state) == 0:
|
|
logger.warning("Received empty list/tuple state. Using fallback dimensions.")
|
|
expected_size = getattr(self, 'state_size', getattr(self, 'state_dim', 403))
|
|
if isinstance(expected_size, tuple):
|
|
expected_size = np.prod(expected_size)
|
|
return np.zeros(int(expected_size), dtype=np.float32)
|
|
|
|
# Recursively sanitize each element
|
|
sanitized = []
|
|
for item in state:
|
|
if isinstance(item, (list, tuple)):
|
|
sanitized_row = []
|
|
for sub_item in item:
|
|
sanitized_row.append(self._extract_numeric_value(sub_item))
|
|
sanitized.append(sanitized_row)
|
|
else:
|
|
sanitized.append(self._extract_numeric_value(item))
|
|
|
|
result = np.array(sanitized, dtype=np.float32)
|
|
|
|
# Check if result is empty and provide fallback
|
|
if result.size == 0:
|
|
logger.warning("Sanitized state resulted in empty array. Using fallback dimensions.")
|
|
expected_size = getattr(self, 'state_size', getattr(self, 'state_dim', 403))
|
|
if isinstance(expected_size, tuple):
|
|
expected_size = np.prod(expected_size)
|
|
return np.zeros(int(expected_size), dtype=np.float32)
|
|
|
|
return result
|
|
|
|
# If state is a dict, try to extract values
|
|
elif isinstance(state, dict):
|
|
# Try to extract meaningful values from dict
|
|
values = []
|
|
for key in sorted(state.keys()): # Sort for consistency
|
|
values.append(self._extract_numeric_value(state[key]))
|
|
return np.array(values, dtype=np.float32)
|
|
|
|
# If state is a single value, make it an array
|
|
else:
|
|
return np.array([self._extract_numeric_value(state)], dtype=np.float32)
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Error sanitizing state data: {e}. Using zero array with expected dimensions.")
|
|
# Return a zero array as fallback with the expected state dimension
|
|
# Use the state_dim from initialization, fallback to 403 if not available
|
|
expected_size = getattr(self, 'state_size', getattr(self, 'state_dim', 403))
|
|
if isinstance(expected_size, tuple):
|
|
expected_size = np.prod(expected_size)
|
|
return np.zeros(int(expected_size), dtype=np.float32)
|
|
|
|
def _extract_numeric_value(self, value):
|
|
"""Extract a numeric value from various data types"""
|
|
try:
|
|
# Handle None values
|
|
if value is None:
|
|
return 0.0
|
|
|
|
# Handle numeric types
|
|
if isinstance(value, (int, float, np.number)):
|
|
return float(value)
|
|
|
|
# Handle dict values
|
|
elif isinstance(value, dict):
|
|
# Try common keys for numeric data
|
|
for key in ['value', 'price', 'close', 'last', 'amount', 'quantity']:
|
|
if key in value:
|
|
return self._extract_numeric_value(value[key])
|
|
# If no common keys, try to get first numeric value
|
|
for v in value.values():
|
|
if isinstance(v, (int, float, np.number)):
|
|
return float(v)
|
|
return 0.0
|
|
|
|
# Handle string values that might be numeric
|
|
elif isinstance(value, str):
|
|
try:
|
|
return float(value)
|
|
except (TypeError, ValueError):
|
|
return 0.0
|
|
|
|
# Handle datetime objects
|
|
elif hasattr(value, 'timestamp'):
|
|
return float(value.timestamp())
|
|
|
|
# Handle boolean values
|
|
elif isinstance(value, bool):
|
|
return float(value)
|
|
|
|
# Handle list/tuple - take first numeric value
|
|
elif isinstance(value, (list, tuple)) and len(value) > 0:
|
|
return self._extract_numeric_value(value[0])
|
|
|
|
else:
|
|
return 0.0
|
|
|
|
except Exception:
|
|
return 0.0
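# --- Illustrative sanitization sketch (added for documentation) ---
# Demonstrates how _sanitize_state_data / _extract_numeric_value cope with messy
# inputs (lists with strings and None, dicts). Built on a small throwaway agent;
# the 16-feature state size, sample values and helper name are placeholders.
def _demo_state_sanitization():
    """Hypothetical sketch: sanitize a few malformed state payloads."""
    agent = DQNAgent(state_shape=(16,), n_actions=2, enable_checkpoints=False)
    mixed_list = [1, "2.5", None, {'price': 101.25}]
    as_array = agent._sanitize_state_data(mixed_list)      # -> float32 array [1.0, 2.5, 0.0, 101.25]
    from_dict = agent._sanitize_state_data({'close': 100.0, 'volume': 3})
    single = agent._extract_numeric_value("not-a-number")  # -> 0.0
    return as_array, from_dict, single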