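"""
DQN agent for trading built on the EnhancedCNN policy/target networks.

The agent combines standard Q-learning with auxiliary signals already produced by the
network (extrema detection and multi-horizon price-direction heads), keeps separate
replay buffers for positive, extrema and large-price-move experiences, supports
mixed-precision training on CUDA, and can merge real-time tick features into the state.
"""
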
import os
import sys
import random
import logging
from collections import deque
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Configure logger
logger = logging.getLogger(__name__)

class DQNAgent:
    """
    Deep Q-Network agent for trading.

    Uses the Enhanced CNN model as the base network, with GPU support for improved performance.
    """

    def __init__(self,
                 state_shape: Tuple[int, ...],
                 n_actions: int,
                 learning_rate: float = 0.0005,  # Reduced learning rate for more stability
                 gamma: float = 0.97,            # Slightly reduced discount factor
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.05,      # Increased minimum epsilon for more exploration
                 epsilon_decay: float = 0.9975,  # Slower decay rate
                 buffer_size: int = 20000,       # Increased memory size
                 batch_size: int = 128,          # Larger batch size
                 target_update: int = 5,         # More frequent target updates
                 device=None):                   # Device for computations

        # Extract state dimensions
        if isinstance(state_shape, tuple) and len(state_shape) > 1:
            # Multi-dimensional state (like image or sequence)
            self.state_dim = state_shape
        else:
            # 1D state
            if isinstance(state_shape, tuple):
                self.state_dim = state_shape[0]
            else:
                self.state_dim = state_shape

        # Store parameters
        self.n_actions = n_actions
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.epsilon_start = epsilon  # Store initial epsilon value for resets/bumps
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.target_update = target_update

        # Set device for computation (default to GPU if available)
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # Initialize models with the Enhanced CNN architecture for better performance
        from NN.models.enhanced_cnn import EnhancedCNN

        # Use Enhanced CNN for both policy and target networks
        self.policy_net = EnhancedCNN(self.state_dim, self.n_actions)
        self.target_net = EnhancedCNN(self.state_dim, self.n_actions)

        # Initialize the target network with the same weights as the policy network
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Set the target network to eval mode (important for batch norm, dropout)
        self.target_net.eval()

        # Optimization components
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

        # Experience replay memory
        self.memory = []
        self.positive_memory = []  # Special memory for storing good experiences
        self.update_count = 0

        # Extrema detection tracking
        self.last_extrema_pred = {
            'class': 2,  # Default to "neither" (not extrema)
            'confidence': 0.0,
            'raw': None
        }
        self.extrema_memory = []  # Special memory for storing extrema points

        # Price prediction tracking
        self.last_price_pred = {
            'immediate': {
                'direction': 1,  # Default to "sideways"
                'confidence': 0.0,
                'change': 0.0
            },
            'midterm': {
                'direction': 1,  # Default to "sideways"
                'confidence': 0.0,
                'change': 0.0
            },
            'longterm': {
                'direction': 1,  # Default to "sideways"
                'confidence': 0.0,
                'change': 0.0
            }
        }

        # Store separate memory for price direction examples
        self.price_movement_memory = []  # For storing examples of clear price movements

        # Performance tracking
        self.losses = []
        self.avg_reward = 0.0
        self.best_reward = -float('inf')
        self.no_improvement_count = 0

        # Confidence tracking
        self.confidence_history = []
        self.avg_confidence = 0.0
        self.max_confidence = 0.0
        self.min_confidence = 1.0

        # Trade action fee and confidence thresholds
        self.trade_action_fee = 0.0005  # Small fee to discourage unnecessary trading
        self.minimum_action_confidence = 0.3  # Minimum confidence to consider trading (lowered from 0.5)
        self.recent_actions = []  # Track recent actions to avoid oscillations

        # Violent move detection
        self.price_history = []
        self.volatility_window = 20  # Window size for volatility calculation
        self.volatility_threshold = 0.0015  # Threshold for considering a move "violent"
        self.post_violent_move = False  # Flag for recent violent move
        self.violent_move_cooldown = 0  # Cooldown after violent move

        # Feature integration
        self.last_hidden_features = None  # Store last extracted features
        self.feature_history = []  # Store history of features for analysis

        # Real-time tick features integration
        self.realtime_tick_features = None  # Latest tick features from tick processor
        self.tick_feature_weight = 0.3  # Weight for tick features in decision making

        # Check if mixed precision training should be used
        self.use_mixed_precision = False
        if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
            self.use_mixed_precision = True
            self.scaler = torch.cuda.amp.GradScaler()
            logger.info("Mixed precision training enabled")
        else:
            logger.info("Mixed precision training disabled")

        # Track if we're in training mode
        self.training = True

        # For compatibility with old code
        self.state_size = np.prod(state_shape)
        self.action_size = n_actions
        self.memory_size = buffer_size
        self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3]  # Default timeframes

        logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
        logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
        logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")

        # Log model parameters
        total_params = sum(p.numel() for p in self.policy_net.parameters())
        logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")

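    # Replay-memory layout (as enforced in remember()):
    #   self.memory                 - main buffer, capped at buffer_size (20,000 by default)
    #   self.positive_memory        - experiences with reward > 0, capped at buffer_size // 4
    #   self.extrema_memory         - detected or flagged tops/bottoms, capped at buffer_size // 4
    #   self.price_movement_memory  - samples with |price change| > 0.2%, capped at buffer_size // 4
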
    def move_models_to_device(self, device=None):
        """Move models to the specified device (GPU/CPU)"""
        if device is not None:
            self.device = device

        try:
            self.policy_net = self.policy_net.to(self.device)
            self.target_net = self.target_net.to(self.device)
            logger.info(f"Moved models to {self.device}")
            return True
        except Exception as e:
            logger.error(f"Failed to move models to {self.device}: {str(e)}")
            return False

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool, is_extrema: bool = False):
        """
        Store experience in memory with prioritization

        Args:
            state: Current state
            action: Action taken
            reward: Reward received
            next_state: Next state
            done: Whether episode is done
            is_extrema: Whether this is a local extrema sample (for specialized learning)
        """
        experience = (state, action, reward, next_state, done)

        # Always add to main memory
        self.memory.append(experience)

        # Try to extract price change to analyze the experience
        try:
            # Extract price feature from sequence data (if available)
            if len(state.shape) > 1:  # 2D state [timeframes, features]
                current_price = state[-1, -1]  # Last timeframe, last feature
                next_price = next_state[-1, -1]
            else:  # 1D state
                current_price = state[-1]  # Last feature
                next_price = next_state[-1]

            # Calculate price change - avoid division by zero
            if np.isscalar(current_price) and current_price != 0:
                price_change = (next_price - current_price) / current_price
            elif isinstance(current_price, np.ndarray):
                # Handle array case - protect against division by zero
                with np.errstate(divide='ignore', invalid='ignore'):
                    price_change = (next_price - current_price) / current_price
                # Replace infinities and NaNs with zeros
                if isinstance(price_change, np.ndarray):
                    price_change = np.nan_to_num(price_change, nan=0.0, posinf=0.0, neginf=0.0)
                else:
                    price_change = 0.0 if np.isnan(price_change) or np.isinf(price_change) else price_change
            else:
                price_change = 0.0

            # Check if this is a significant price movement
            if abs(price_change) > 0.002:  # Significant price change
                # Store in price movement memory
                self.price_movement_memory.append(experience)

                # Log significant price movements
                direction = "UP" if price_change > 0 else "DOWN"
                logger.info(f"Stored significant {direction} price movement: {price_change:.4f}")

                # For clear price movements, also duplicate in main memory to learn more
                if abs(price_change) > 0.005:  # Very significant movement
                    for _ in range(2):  # Add 2 extra copies
                        self.memory.append(experience)
        except Exception:
            # Skip price movement analysis if it fails
            pass

        # Check if this is an extrema point based on our extrema detection head
        if hasattr(self, 'last_extrema_pred') and self.last_extrema_pred['class'] != 2:
            # Class 0 = bottom, 1 = top, 2 = neither
            # Only consider high confidence predictions
            if self.last_extrema_pred['confidence'] > 0.7:
                self.extrema_memory.append(experience)

                # Log this special experience
                extrema_type = "BOTTOM" if self.last_extrema_pred['class'] == 0 else "TOP"
                logger.info(f"Stored {extrema_type} experience with reward {reward:.4f}")

                # For tops and bottoms, also duplicate the experience in memory to learn more from it
                for _ in range(2):  # Add 2 extra copies
                    self.memory.append(experience)

        # Explicitly marked extrema points also go to extrema memory
        elif is_extrema:
            self.extrema_memory.append(experience)

        # Store positive experiences separately for prioritized replay
        if reward > 0:
            self.positive_memory.append(experience)

            # For very good rewards, duplicate to learn more from them
            if reward > 0.1:
                for _ in range(min(int(reward * 10), 5)):  # Cap at 5 extra copies for very high rewards
                    self.positive_memory.append(experience)

        # Keep memory size under control
        if len(self.memory) > self.buffer_size:
            # Keep more recent experiences
            self.memory = self.memory[-self.buffer_size:]

        # Keep specialized memories under control too
        if len(self.positive_memory) > self.buffer_size // 4:
            self.positive_memory = self.positive_memory[-(self.buffer_size // 4):]

        if len(self.extrema_memory) > self.buffer_size // 4:
            self.extrema_memory = self.extrema_memory[-(self.buffer_size // 4):]

        if len(self.price_movement_memory) > self.buffer_size // 4:
            self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):]

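    # Illustrative sketch (hypothetical helper, not referenced elsewhere in this class):
    # one way the specialized buffers filled by remember() could be combined into a single
    # prioritized batch. The 60/20/10/10 split is an assumed weighting, not a tuned value.
    def _sample_mixed_batch_example(self, batch_size: int) -> List[tuple]:
        """Sketch of prioritized sampling across the four replay buffers."""
        batch = []
        if self.memory:
            # Draw the bulk of the batch from the main buffer
            batch.extend(random.choices(self.memory, k=int(batch_size * 0.6)))
        for source, frac in ((self.positive_memory, 0.2),
                             (self.extrema_memory, 0.1),
                             (self.price_movement_memory, 0.1)):
            if source:
                # Top up with specialized experiences when those buffers are non-empty
                batch.extend(random.choices(source, k=max(1, int(batch_size * frac))))
        random.shuffle(batch)
        return batch[:batch_size]
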
    def act(self, state: np.ndarray, explore=True) -> int:
        """Choose action using epsilon-greedy policy with explore flag"""
        if explore and random.random() < self.epsilon:
            return random.randrange(self.n_actions)

        with torch.no_grad():
            # Enhance state with real-time tick features
            enhanced_state = self._enhance_state_with_tick_features(state)

            # Ensure state is normalized before inference
            state_tensor = self._normalize_state(enhanced_state)
            state_tensor = torch.FloatTensor(state_tensor).unsqueeze(0).to(self.device)

            # Get predictions using the policy network
            self.policy_net.eval()  # Set to evaluation mode for inference
            action_probs, extrema_pred, price_predictions, hidden_features, advanced_predictions = self.policy_net(state_tensor)
            self.policy_net.train()  # Back to training mode

            # Store hidden features for integration
            self.last_hidden_features = hidden_features.cpu().numpy()

            # Track feature history (limited size)
            self.feature_history.append(hidden_features.cpu().numpy())
            if len(self.feature_history) > 100:
                self.feature_history = self.feature_history[-100:]

            # Get the predicted extrema class (0=bottom, 1=top, 2=neither)
            extrema_class = extrema_pred.argmax(dim=1).item()
            extrema_confidence = torch.softmax(extrema_pred, dim=1)[0, extrema_class].item()

            # Log extrema prediction for significant signals
            if extrema_confidence > 0.7 and extrema_class != 2:  # Only log strong top/bottom signals
                extrema_type = "BOTTOM" if extrema_class == 0 else "TOP" if extrema_class == 1 else "NEITHER"
                logger.info(f"High confidence {extrema_type} detected! Confidence: {extrema_confidence:.4f}")

            # Process price predictions
            price_immediate = torch.softmax(price_predictions['immediate'], dim=1)
            price_midterm = torch.softmax(price_predictions['midterm'], dim=1)
            price_longterm = torch.softmax(price_predictions['longterm'], dim=1)
            price_values = price_predictions['values']

            # Get predicted direction for each timeframe (0=down, 1=sideways, 2=up)
            immediate_direction = price_immediate.argmax(dim=1).item()
            midterm_direction = price_midterm.argmax(dim=1).item()
            longterm_direction = price_longterm.argmax(dim=1).item()

            # Get confidence levels
            immediate_conf = price_immediate[0, immediate_direction].item()
            midterm_conf = price_midterm[0, midterm_direction].item()
            longterm_conf = price_longterm[0, longterm_direction].item()

            # Get predicted price change percentages
            price_changes = price_values[0].tolist()

            # Log significant price movement predictions
            timeframes = ["1s/1m", "1h", "1d", "1w"]
            directions = ["DOWN", "SIDEWAYS", "UP"]

            for i, (direction, conf) in enumerate([
                (immediate_direction, immediate_conf),
                (midterm_direction, midterm_conf),
                (longterm_direction, longterm_conf)
            ]):
                if conf > 0.7 and direction != 1:  # Only log high confidence non-sideways predictions
                    logger.info(f"Price prediction: {timeframes[i]} -> {directions[direction]}, "
                                f"Confidence: {conf:.4f}, Expected change: {price_changes[i]:.2f}%")

            # Store predictions for environment to use
            self.last_extrema_pred = {
                'class': extrema_class,
                'confidence': extrema_confidence,
                'raw': extrema_pred.cpu().numpy()
            }

            self.last_price_pred = {
                'immediate': {
                    'direction': immediate_direction,
                    'confidence': immediate_conf,
                    'change': price_changes[0]
                },
                'midterm': {
                    'direction': midterm_direction,
                    'confidence': midterm_conf,
                    'change': price_changes[1]
                },
                'longterm': {
                    'direction': longterm_direction,
                    'confidence': longterm_conf,
                    'change': price_changes[2]
                }
            }

            # Get the action with highest Q-value
            action = action_probs.argmax().item()

            # Calculate overall confidence in the action
            q_values_softmax = F.softmax(action_probs, dim=1)[0]
            action_confidence = q_values_softmax[action].item()

            # Track confidence metrics
            self.confidence_history.append(action_confidence)
            if len(self.confidence_history) > 100:
                self.confidence_history = self.confidence_history[-100:]

            # Update confidence metrics
            self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
            self.max_confidence = max(self.max_confidence, action_confidence)
            self.min_confidence = min(self.min_confidence, action_confidence)

            # Log average confidence occasionally
            if random.random() < 0.01:  # 1% of the time
                logger.info(f"Confidence metrics - Current: {action_confidence:.4f}, Avg: {self.avg_confidence:.4f}, "
                            f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")

            # Track price for violent move detection
            try:
                # Extract current price from state (assuming it's in the last position)
                if len(state.shape) > 1:  # For 2D state
                    current_price = state[-1, -1]
                else:  # For 1D state
                    current_price = state[-1]

                self.price_history.append(current_price)
                if len(self.price_history) > self.volatility_window:
                    self.price_history = self.price_history[-self.volatility_window:]

                # Detect violent price moves if we have enough price history
                if len(self.price_history) >= 5:
                    # Calculate short-term volatility
                    recent_prices = self.price_history[-5:]

                    # Make sure we're working with scalar values, not arrays
                    if isinstance(recent_prices[0], np.ndarray):
                        # If prices are arrays, extract the last value (current price)
                        recent_prices = [p[-1] if isinstance(p, np.ndarray) and p.size > 0 else p for p in recent_prices]

                    # Calculate price changes with protection against division by zero
                    price_changes = []
                    for i in range(1, len(recent_prices)):
                        if recent_prices[i-1] != 0 and not np.isnan(recent_prices[i-1]) and not np.isnan(recent_prices[i]):
                            change = (recent_prices[i] - recent_prices[i-1]) / recent_prices[i-1]
                            price_changes.append(change)
                        else:
                            price_changes.append(0.0)

                    # Calculate volatility as the sum of absolute price changes
                    volatility = sum(abs(change) for change in price_changes)

                    # Check if we've had a violent move
                    if volatility > self.volatility_threshold:
                        logger.info(f"Violent price move detected! Volatility: {volatility:.6f}")
                        self.post_violent_move = True
                        self.violent_move_cooldown = 10  # Set cooldown period

                # Handle post-violent move period
                if self.post_violent_move:
                    if self.violent_move_cooldown > 0:
                        self.violent_move_cooldown -= 1
                        # Increase confidence threshold temporarily after violent moves
                        effective_threshold = self.minimum_action_confidence * 1.1
                        logger.info(f"Post-violent move period: {self.violent_move_cooldown} steps remaining. "
                                    f"Using higher confidence threshold: {effective_threshold:.4f}")
                    else:
                        self.post_violent_move = False
                        logger.info("Post-violent move period ended")
            except Exception as e:
                logger.warning(f"Error in violent move detection: {str(e)}")

            # Apply trade action fee to buy/sell actions but not to hold
            # This creates a threshold that must be exceeded to justify a trade
            action_values = action_probs.clone()

            # If BUY or SELL, apply fee by reducing the Q-value
            if action == 0 or action == 1:  # BUY or SELL
                # Check if confidence is above minimum threshold
                effective_threshold = self.minimum_action_confidence
                if self.post_violent_move:
                    effective_threshold *= 1.1  # Higher threshold after violent moves

                if action_confidence < effective_threshold:
                    # If confidence is below threshold, force HOLD action
                    logger.info(f"Action {action} confidence {action_confidence:.4f} below threshold {effective_threshold}, forcing HOLD")
                    action = 2  # HOLD
                else:
                    # Apply trade action fee to ensure we only trade when there's clear benefit
                    fee_adjusted_action_values = action_values.clone()
                    fee_adjusted_action_values[0, 0] -= self.trade_action_fee  # Reduce BUY value
                    fee_adjusted_action_values[0, 1] -= self.trade_action_fee  # Reduce SELL value
                    # HOLD value remains unchanged

                    # Re-determine the action based on fee-adjusted values
                    fee_adjusted_action = fee_adjusted_action_values.argmax().item()

                    # If the fee changes our decision, log this
                    if fee_adjusted_action != action:
                        logger.info(f"Trade action fee changed decision from {action} to {fee_adjusted_action}")
                        action = fee_adjusted_action

            # Adjust action based on extrema and price predictions
            # Prioritize short-term movement for trading decisions
            if immediate_conf > 0.8:  # Only adjust for strong signals
                if immediate_direction == 2:  # UP prediction
                    # Bias toward BUY for strong up predictions
                    if action != 0 and action != 2 and random.random() < 0.3 * immediate_conf:
                        logger.info("Adjusting action to BUY based on immediate UP prediction")
                        action = 0  # BUY
                elif immediate_direction == 0:  # DOWN prediction
                    # Bias toward SELL for strong down predictions
                    if action != 1 and action != 2 and random.random() < 0.3 * immediate_conf:
                        logger.info("Adjusting action to SELL based on immediate DOWN prediction")
                        action = 1  # SELL

            # Also consider extrema detection for action adjustment
            if extrema_confidence > 0.8:  # Only adjust for strong signals
                if extrema_class == 0:  # Bottom detected
                    # Bias toward BUY at bottoms
                    if action != 0 and action != 2 and random.random() < 0.3 * extrema_confidence:
                        logger.info("Adjusting action to BUY based on bottom detection")
                        action = 0  # BUY
                elif extrema_class == 1:  # Top detected
                    # Bias toward SELL at tops
                    if action != 1 and action != 2 and random.random() < 0.3 * extrema_confidence:
                        logger.info("Adjusting action to SELL based on top detection")
                        action = 1  # SELL

            # Finally, avoid action oscillation by checking recent history
            if len(self.recent_actions) >= 2:
                last_action = self.recent_actions[-1]
                if action != last_action and action != 2 and last_action != 2:
                    # We're switching between BUY and SELL too quickly
                    # Only allow this if we have very high confidence
                    if action_confidence < 0.85:
                        logger.info(f"Preventing oscillation from {last_action} to {action}, forcing HOLD")
                        action = 2  # HOLD

            # Update recent actions list
            self.recent_actions.append(action)
            if len(self.recent_actions) > 5:
                self.recent_actions = self.recent_actions[-5:]

            return action

    def replay(self, experiences=None):
        """Train the model using experiences from memory"""

        # Don't train if not in training mode
        if not self.training:
            return 0.0

        # If no experiences provided, sample from memory
        if experiences is None:
            # Skip if memory is too small
            if len(self.memory) < self.batch_size:
                return 0.0

            # Sample random mini-batch from memory
            indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
            experiences = [self.memory[i] for i in indices]

        # Choose appropriate replay method
        if self.use_mixed_precision:
            # Convert experiences to tensors for mixed precision
            states = torch.FloatTensor(np.array([e[0] for e in experiences])).to(self.device)
            actions = torch.LongTensor(np.array([e[1] for e in experiences])).to(self.device)
            rewards = torch.FloatTensor(np.array([e[2] for e in experiences])).to(self.device)
            next_states = torch.FloatTensor(np.array([e[3] for e in experiences])).to(self.device)
            dones = torch.FloatTensor(np.array([e[4] for e in experiences])).to(self.device)

            # Use mixed precision replay
            loss = self._replay_mixed_precision(states, actions, rewards, next_states, dones)
        else:
            # Pass experiences directly to the standard replay method
            loss = self._replay_standard(experiences)

        # Store loss for monitoring
        self.losses.append(loss)

        # Track and decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        # Randomly decide if we should train on extrema points from special memory
        if random.random() < 0.3 and len(self.extrema_memory) >= self.batch_size:
            # Train specifically on extrema memory examples
            extrema_indices = np.random.choice(len(self.extrema_memory), size=min(self.batch_size, len(self.extrema_memory)), replace=False)
            extrema_batch = [self.extrema_memory[i] for i in extrema_indices]

            # Extract tensors from extrema batch
            extrema_states = torch.FloatTensor(np.array([e[0] for e in extrema_batch])).to(self.device)
            extrema_actions = torch.LongTensor(np.array([e[1] for e in extrema_batch])).to(self.device)
            extrema_rewards = torch.FloatTensor(np.array([e[2] for e in extrema_batch])).to(self.device)
            extrema_next_states = torch.FloatTensor(np.array([e[3] for e in extrema_batch])).to(self.device)
            extrema_dones = torch.FloatTensor(np.array([e[4] for e in extrema_batch])).to(self.device)

            # Use a slightly reduced learning rate for extrema training
            old_lr = self.optimizer.param_groups[0]['lr']
            self.optimizer.param_groups[0]['lr'] = old_lr * 0.8

            # Train on extrema memory
            if self.use_mixed_precision:
                extrema_loss = self._replay_mixed_precision(extrema_states, extrema_actions, extrema_rewards, extrema_next_states, extrema_dones)
            else:
                extrema_loss = self._replay_standard(extrema_batch)

            # Reset learning rate
            self.optimizer.param_groups[0]['lr'] = old_lr

            # Log extrema loss
            logger.info(f"Extra training on extrema points, loss: {extrema_loss:.4f}")

        # Randomly train on price movement examples (similar to extrema)
        if random.random() < 0.3 and len(self.price_movement_memory) >= self.batch_size:
            # Train specifically on price movement memory examples
            price_indices = np.random.choice(len(self.price_movement_memory), size=min(self.batch_size, len(self.price_movement_memory)), replace=False)
            price_batch = [self.price_movement_memory[i] for i in price_indices]

            # Extract tensors from price movement batch
            price_states = torch.FloatTensor(np.array([e[0] for e in price_batch])).to(self.device)
            price_actions = torch.LongTensor(np.array([e[1] for e in price_batch])).to(self.device)
            price_rewards = torch.FloatTensor(np.array([e[2] for e in price_batch])).to(self.device)
            price_next_states = torch.FloatTensor(np.array([e[3] for e in price_batch])).to(self.device)
            price_dones = torch.FloatTensor(np.array([e[4] for e in price_batch])).to(self.device)

            # Use a slightly reduced learning rate for price movement training
            old_lr = self.optimizer.param_groups[0]['lr']
            self.optimizer.param_groups[0]['lr'] = old_lr * 0.75

            # Train on price movement memory
            if self.use_mixed_precision:
                price_loss = self._replay_mixed_precision(price_states, price_actions, price_rewards, price_next_states, price_dones)
            else:
                price_loss = self._replay_standard(price_batch)

            # Reset learning rate
            self.optimizer.param_groups[0]['lr'] = old_lr

            # Log price movement loss
            logger.info(f"Extra training on price movement examples, loss: {price_loss:.4f}")

        return loss

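    # Both replay paths below regress the online network toward the standard DQN target
    #
    #     target = r + gamma * (1 - done) * max_a' Q_target(s', a')
    #
    # with gamma = 0.97 by default. Epsilon follows a multiplicative schedule,
    # epsilon_t = max(epsilon_min, epsilon_0 * 0.9975**t), so exploration roughly halves
    # every ~277 decay steps (ln 0.5 / ln 0.9975 ≈ 277) until it reaches the 0.05 floor.
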
    def _replay_standard(self, experiences=None):
        """Standard training step without mixed precision"""
        try:
            # Use experiences if provided, otherwise sample from memory
            if experiences is None:
                # If memory is too small, skip training
                if len(self.memory) < self.batch_size:
                    return 0.0

                # Sample random mini-batch from memory
                indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
                experiences = [self.memory[i] for i in indices]

            # Unpack experiences
            states, actions, rewards, next_states, dones = zip(*experiences)

            # Convert to PyTorch tensors
            states = torch.FloatTensor(np.array(states)).to(self.device)
            actions = torch.LongTensor(np.array(actions)).to(self.device)
            rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
            next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
            dones = torch.FloatTensor(np.array(dones)).to(self.device)

            # Get current Q values
            current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
            current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

            # Get next Q values with target network
            with torch.no_grad():
                next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
                next_q_values = next_q_values.max(1)[0]

            # Check for dimension mismatch between rewards and next_q_values
            if rewards.shape[0] != next_q_values.shape[0]:
                logger.warning(f"Shape mismatch detected in standard replay: rewards {rewards.shape}, next_q_values {next_q_values.shape}")
                # Use the smaller size to prevent index errors
                min_size = min(rewards.shape[0], next_q_values.shape[0])
                rewards = rewards[:min_size]
                dones = dones[:min_size]
                next_q_values = next_q_values[:min_size]
                current_q_values = current_q_values[:min_size]

            # Calculate target Q values
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

            # Compute loss for Q value
            q_loss = self.criterion(current_q_values, target_q_values)

            # Try to compute extrema loss if possible
            try:
                # Use the network's own argmax as the target class. Note that this is a
                # self-consistency (confidence-sharpening) term rather than supervision
                # against true extrema labels.
                extrema_targets = torch.argmax(current_extrema_pred, dim=1).long()

                # Compute extrema loss using cross-entropy - this is an auxiliary task
                extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)

                # Combined loss with emphasis on Q-learning
                total_loss = q_loss + 0.1 * extrema_loss
            except Exception as e:
                logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
                total_loss = q_loss

            # Reset gradients
            self.optimizer.zero_grad()

            # Backward pass
            total_loss.backward()

            # Clip gradients to avoid exploding gradients
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)

            # Update weights
            self.optimizer.step()

            # Update target network if needed
            self.update_count += 1
            if self.update_count % self.target_update == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())

            # Return loss
            return total_loss.item()
        except Exception as e:
            logger.error(f"Error in replay standard: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            return 0.0

    def _replay_mixed_precision(self, states, actions, rewards, next_states, dones):
        """Mixed precision training step for better GPU performance"""
        # Check if mixed precision should be explicitly disabled
        if 'DISABLE_MIXED_PRECISION' in os.environ:
            logger.info("Mixed precision explicitly disabled by environment variable")
            # _replay_standard expects a list of experience tuples, so convert the tensors back
            experiences = list(zip(states.cpu().numpy(), actions.cpu().numpy(), rewards.cpu().numpy(),
                                   next_states.cpu().numpy(), dones.cpu().numpy()))
            return self._replay_standard(experiences)

        try:
            # Zero gradients
            self.optimizer.zero_grad()

            # Forward pass with amp autocasting
            with torch.cuda.amp.autocast():
                # Get current Q values and extrema predictions
                current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
                current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

                # Get next Q values from target network
                with torch.no_grad():
                    next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
                    next_q_values = next_q_values.max(1)[0]

                # Check for dimension mismatch and fix it
                if rewards.shape[0] != next_q_values.shape[0]:
                    # Log the shape mismatch for debugging
                    logger.warning(f"Shape mismatch detected: rewards {rewards.shape}, next_q_values {next_q_values.shape}")
                    # Use the smaller size to prevent index errors
                    min_size = min(rewards.shape[0], next_q_values.shape[0])
                    rewards = rewards[:min_size]
                    dones = dones[:min_size]
                    next_q_values = next_q_values[:min_size]
                    current_q_values = current_q_values[:min_size]

                target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

                # Compute Q-value loss (primary task)
                q_loss = nn.MSELoss()(current_q_values, target_q_values)

                # Initialize loss with q_loss
                loss = q_loss

                # Try to extract price from current and next states
                try:
                    # Extract price feature from sequence data (if available)
                    if len(states.shape) == 3:  # [batch, seq, features]
                        current_prices = states[:, -1, -1]  # Last timestep, last feature
                        next_prices = next_states[:, -1, -1]
                    else:  # [batch, features]
                        current_prices = states[:, -1]  # Last feature
                        next_prices = next_states[:, -1]

                    # Calculate price change for different timeframes
                    immediate_changes = (next_prices - current_prices) / current_prices

                    # Get the actual batch size for this calculation
                    actual_batch_size = states.shape[0]

                    # Create price direction labels - simplified for training
                    # 0 = down, 1 = sideways, 2 = up
                    immediate_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device)  # Default: sideways
                    midterm_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device)
                    longterm_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device)

                    # Immediate term direction (1s, 1m)
                    immediate_up = (immediate_changes > 0.0005)
                    immediate_down = (immediate_changes < -0.0005)
                    immediate_labels[immediate_up] = 2  # Up
                    immediate_labels[immediate_down] = 0  # Down

                    # For mid and long term, we can only approximate during training.
                    # In a real system, we'd need historical data to validate these.
                    # Here we use the immediate change with increasing thresholds as an approximation.

                    # Mid-term (1h) - use a slightly higher threshold
                    midterm_up = (immediate_changes > 0.001)
                    midterm_down = (immediate_changes < -0.001)
                    midterm_labels[midterm_up] = 2  # Up
                    midterm_labels[midterm_down] = 0  # Down

                    # Long-term (1d) - use an even higher threshold
                    longterm_up = (immediate_changes > 0.002)
                    longterm_down = (immediate_changes < -0.002)
                    longterm_labels[longterm_up] = 2  # Up
                    longterm_labels[longterm_down] = 0  # Down

                    # Generate target values for price change regression
                    # For simplicity, use the immediate change and scaled versions for longer timeframes
                    price_value_targets = torch.zeros((actual_batch_size, 4), device=self.device)
                    price_value_targets[:, 0] = immediate_changes
                    price_value_targets[:, 1] = immediate_changes * 2.0  # Approximate 1h change
                    price_value_targets[:, 2] = immediate_changes * 4.0  # Approximate 1d change
                    price_value_targets[:, 3] = immediate_changes * 6.0  # Approximate 1w change

                    # Calculate loss for price direction prediction (classification)
                    if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= actual_batch_size:
                        # Slice predictions to match the adjusted batch size
                        immediate_pred = current_price_pred['immediate'][:actual_batch_size]
                        midterm_pred = current_price_pred['midterm'][:actual_batch_size]
                        longterm_pred = current_price_pred['longterm'][:actual_batch_size]
                        price_values_pred = current_price_pred['values'][:actual_batch_size]

                        # Compute losses for each task
                        immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels)
                        midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels)
                        longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels)

                        # MSE loss for price value regression
                        price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets)

                        # Combine all price prediction losses
                        price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss

                        # Create extrema labels (same as before)
                        extrema_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 2  # Default: neither

                        # Identify potential bottoms (significant negative change)
                        bottoms = (immediate_changes < -0.003)
                        extrema_labels[bottoms] = 0

                        # Identify potential tops (significant positive change)
                        tops = (immediate_changes > 0.003)
                        extrema_labels[tops] = 1

                        # Calculate extrema prediction loss
                        if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= actual_batch_size:
                            current_extrema_pred = current_extrema_pred[:actual_batch_size]
                            extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)

                            # Combined loss with all components
                            # Primary task: Q-value learning (RL objective)
                            # Secondary tasks: extrema detection and price prediction (supervised objectives)
                            loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss

                            # Log loss components occasionally
                            if random.random() < 0.01:  # Log 1% of the time
                                logger.info(
                                    f"Mixed precision losses: Q-loss={q_loss.item():.4f}, "
                                    f"Extrema-loss={extrema_loss.item():.4f}, "
                                    f"Price-loss={price_loss.item():.4f}"
                                )
                except Exception as e:
                    # Fallback if price extraction fails
                    logger.warning(f"Failed to calculate price prediction loss: {str(e)}. Using only Q-value loss.")
                    # Just use Q-value loss
                    loss = q_loss

            # Backward pass with scaled gradients
            self.scaler.scale(loss).backward()

            # Gradient clipping on scaled gradients
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)

            # Update with scaler
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Update target network if needed
            self.update_count += 1
            if self.update_count % self.target_update == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())

            # Track and decay epsilon (note: replay() also decays epsilon after calling this method)
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

            return loss.item()

        except Exception as e:
            logger.error(f"Error in mixed precision training: {str(e)}")
            logger.warning("Falling back to standard precision training")
            # Fall back to standard training; _replay_standard expects experience tuples
            experiences = list(zip(states.cpu().numpy(), actions.cpu().numpy(), rewards.cpu().numpy(),
                                   next_states.cpu().numpy(), dones.cpu().numpy()))
            return self._replay_standard(experiences)

    def train_on_extrema(self, states, actions, rewards, next_states, dones):
        """
        Special training function specifically for extrema points

        Args:
            states: Batch of states at extrema points
            actions: Batch of actions
            rewards: Batch of rewards
            next_states: Batch of next states
            dones: Batch of done flags

        Returns:
            float: Training loss
        """
        # Convert to numpy arrays if not already
        if not isinstance(states, np.ndarray):
            states = np.array(states)
        if not isinstance(actions, np.ndarray):
            actions = np.array(actions)
        if not isinstance(rewards, np.ndarray):
            rewards = np.array(rewards)
        if not isinstance(next_states, np.ndarray):
            next_states = np.array(next_states)
        if not isinstance(dones, np.ndarray):
            dones = np.array(dones, dtype=np.float32)

        # Normalize states (np.stack preserves the per-sample shape for both 1D and 2D states,
        # whereas vstack would flatten a batch of 2D states into a single 2D array)
        states = np.stack([self._normalize_state(s) for s in states])
        next_states = np.stack([self._normalize_state(s) for s in next_states])

        # Choose training method based on precision mode
        if self.use_mixed_precision:
            # Convert to torch tensors and move to device
            states_tensor = torch.FloatTensor(states).to(self.device)
            actions_tensor = torch.LongTensor(actions).to(self.device)
            rewards_tensor = torch.FloatTensor(rewards).to(self.device)
            next_states_tensor = torch.FloatTensor(next_states).to(self.device)
            dones_tensor = torch.FloatTensor(dones).to(self.device)

            return self._replay_mixed_precision(
                states_tensor, actions_tensor, rewards_tensor,
                next_states_tensor, dones_tensor
            )
        else:
            # _replay_standard expects a list of (state, action, reward, next_state, done)
            # tuples rather than pre-built tensors
            experiences = list(zip(states, actions, rewards, next_states, dones))
            return self._replay_standard(experiences)

    def _normalize_state(self, state: np.ndarray) -> np.ndarray:
        """Normalize the state data to prevent numerical issues"""
        # Handle NaN and infinite values
        state = np.nan_to_num(state, nan=0.0, posinf=1.0, neginf=-1.0)

        # Check if state is a 1D array (happens in some environments)
        if len(state.shape) == 1:
            # If 1D, we need to normalize the whole array
            normalized_state = state.copy()

            # Convert any timestamp or non-numeric data to float
            for i in range(len(normalized_state)):
                # Check for timestamp-like objects
                if hasattr(normalized_state[i], 'timestamp') and callable(getattr(normalized_state[i], 'timestamp')):
                    # Convert timestamp to float (seconds since epoch)
                    normalized_state[i] = float(normalized_state[i].timestamp())
                elif not isinstance(normalized_state[i], (int, float, np.number)):
                    # Set non-numeric data to 0
                    normalized_state[i] = 0.0

            # Ensure all values are float
            normalized_state = normalized_state.astype(np.float32)

            # Simple min-max normalization for 1D state
            state_min = np.min(normalized_state)
            state_max = np.max(normalized_state)
            if state_max > state_min:
                normalized_state = (normalized_state - state_min) / (state_max - state_min)
            return normalized_state

        # Handle 2D arrays
        normalized_state = np.zeros_like(state, dtype=np.float32)

        # Convert any timestamp or non-numeric data to float
        for i in range(state.shape[0]):
            for j in range(state.shape[1]):
                if hasattr(state[i, j], 'timestamp') and callable(getattr(state[i, j], 'timestamp')):
                    # Convert timestamp to float (seconds since epoch)
                    normalized_state[i, j] = float(state[i, j].timestamp())
                elif isinstance(state[i, j], (int, float, np.number)):
                    normalized_state[i, j] = state[i, j]
                else:
                    # Set non-numeric data to 0
                    normalized_state[i, j] = 0.0

        # Loop through each timeframe's features in the combined state
        feature_count = state.shape[1] // len(self.timeframes)

        for tf_idx in range(len(self.timeframes)):
            start_idx = tf_idx * feature_count
            end_idx = start_idx + feature_count

            # Extract this timeframe's features
            tf_features = normalized_state[:, start_idx:end_idx]

            # Normalize OHLC data by the mean close price in the window
            # This makes price movements relative rather than absolute
            price_idx = 3  # Assuming close price is at index 3
            if price_idx < tf_features.shape[1]:
                reference_price = np.mean(tf_features[:, price_idx])
                if reference_price != 0:
                    # Normalize price-related columns (OHLC)
                    for i in range(4):  # First 4 columns are OHLC
                        if i < tf_features.shape[1]:
                            normalized_state[:, start_idx + i] = tf_features[:, i] / reference_price

            # Normalize volume using mean and std
            vol_idx = 4  # Assuming volume is at index 4
            if vol_idx < tf_features.shape[1]:
                vol_mean = np.mean(tf_features[:, vol_idx])
                vol_std = np.std(tf_features[:, vol_idx])
                if vol_std > 0:
                    normalized_state[:, start_idx + vol_idx] = (tf_features[:, vol_idx] - vol_mean) / vol_std
                else:
                    normalized_state[:, start_idx + vol_idx] = 0

            # Other features (technical indicators) - normalize with min-max scaling
            for i in range(5, feature_count):
                if i < tf_features.shape[1]:
                    feature_min = np.min(tf_features[:, i])
                    feature_max = np.max(tf_features[:, i])
                    if feature_max > feature_min:
                        normalized_state[:, start_idx + i] = (tf_features[:, i] - feature_min) / (feature_max - feature_min)
                    else:
                        normalized_state[:, start_idx + i] = 0

        return normalized_state

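    # The real-time tick integration below expects realtime_tick_features to be a dict of
    # the form (other keys are ignored):
    #   {
    #       'neural_features':         np.ndarray,  # first 3 values are used
    #       'volume_features':         np.ndarray,  # first value is used
    #       'microstructure_features': np.ndarray,  # first value is used
    #       'confidence':              float,       # 0.0-1.0, used for logging
    #   }
    # The selected values are scaled by tick_feature_weight before being merged into the state.
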
    def update_realtime_tick_features(self, tick_features):
        """Update with real-time tick features from the tick processor"""
        try:
            if tick_features is not None:
                self.realtime_tick_features = tick_features

                # Log high-confidence tick features
                if tick_features.get('confidence', 0) > 0.8:
                    logger.debug(f"High-confidence tick features updated: confidence={tick_features['confidence']:.3f}")

        except Exception as e:
            logger.error(f"Error updating real-time tick features: {e}")

    def _enhance_state_with_tick_features(self, state: np.ndarray) -> np.ndarray:
        """Enhance state with real-time tick features if available"""
        try:
            if self.realtime_tick_features is None:
                return state

            # Extract neural features from tick processor
            neural_features = self.realtime_tick_features.get('neural_features', np.array([]))
            volume_features = self.realtime_tick_features.get('volume_features', np.array([]))
            microstructure_features = self.realtime_tick_features.get('microstructure_features', np.array([]))
            confidence = self.realtime_tick_features.get('confidence', 0.0)

            # Combine tick features - keep them compact to match state dimensions
            tick_features = np.concatenate([
                neural_features[:3] if len(neural_features) >= 3 else np.zeros(3),  # Take first 3 neural features
                volume_features[:1] if len(volume_features) >= 1 else np.zeros(1),  # Take first volume feature
                microstructure_features[:1] if len(microstructure_features) >= 1 else np.zeros(1),  # Take first microstructure feature
            ])

            # Weight the tick features
            weighted_tick_features = tick_features * self.tick_feature_weight

            # Enhance the state by merging in the weighted tick features
            if len(state.shape) == 1:
                # 1D state - append tick features
                enhanced_state = np.concatenate([state, weighted_tick_features])
            else:
                # 2D state - add tick features to the most recent timeframe row
                num_timeframes, num_features = state.shape

                # Ensure tick features match the number of original features
                if len(weighted_tick_features) != num_features:
                    # Pad or truncate tick features to match the state feature dimension
                    if len(weighted_tick_features) < num_features:
                        # Pad with zeros
                        padded_features = np.zeros(num_features)
                        padded_features[:len(weighted_tick_features)] = weighted_tick_features
                        weighted_tick_features = padded_features
                    else:
                        # Truncate to match
                        weighted_tick_features = weighted_tick_features[:num_features]

                # Add tick features to the last row (most recent timeframe)
                enhanced_state = state.copy()
                enhanced_state[-1, :] += weighted_tick_features  # Add to last timeframe

            return enhanced_state

        except Exception as e:
            logger.error(f"Error enhancing state with tick features: {e}")
            return state

    def update_learning_metrics(self, episode_reward, best_reward_threshold=0.01):
        """Update learning metrics and perform learning rate adjustments if needed"""
        # Update average reward with exponential moving average
        if self.avg_reward == 0:
            self.avg_reward = episode_reward
        else:
            self.avg_reward = 0.95 * self.avg_reward + 0.05 * episode_reward

        # Check if we're making sufficient progress
        if episode_reward > (1 + best_reward_threshold) * self.best_reward:
            self.best_reward = episode_reward
            self.no_improvement_count = 0
            return True  # Improved
        else:
            self.no_improvement_count += 1

            # If no improvement for a while, adjust learning rate
            if self.no_improvement_count >= 10:
                current_lr = self.optimizer.param_groups[0]['lr']
                new_lr = current_lr * 0.5
                if new_lr >= 1e-6:  # Don't reduce below minimum threshold
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = new_lr
                    logger.info(f"Reducing learning rate from {current_lr} to {new_lr}")
                    self.no_improvement_count = 0

            return False  # No improvement

    def save(self, path: str):
        """Save model and agent state"""
        # Guard against a bare filename with no directory component
        save_dir = os.path.dirname(path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # Save policy network
        self.policy_net.save(f"{path}_policy")

        # Save target network
        self.target_net.save(f"{path}_target")

        # Save agent state
        state = {
            'epsilon': self.epsilon,
            'update_count': self.update_count,
            'losses': self.losses,
            'optimizer_state': self.optimizer.state_dict(),
            'best_reward': self.best_reward,
            'avg_reward': self.avg_reward
        }

        torch.save(state, f"{path}_agent_state.pt")
        logger.info(f"Agent state saved to {path}_agent_state.pt")

    def load(self, path: str):
        """Load model and agent state"""
        # Load policy network
        self.policy_net.load(f"{path}_policy")

        # Load target network
        self.target_net.load(f"{path}_target")

        # Load agent state
        try:
            agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
            self.epsilon = agent_state['epsilon']
            self.update_count = agent_state['update_count']
            self.losses = agent_state['losses']
            self.optimizer.load_state_dict(agent_state['optimizer_state'])

            # Load additional metrics if they exist
            if 'best_reward' in agent_state:
                self.best_reward = agent_state['best_reward']
            if 'avg_reward' in agent_state:
                self.avg_reward = agent_state['avg_reward']

            logger.info(f"Agent state loaded from {path}_agent_state.pt")
        except FileNotFoundError:
            logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")