added leverage slider
This commit is contained in:
@ -9,6 +9,7 @@ import os
|
||||
import sys
|
||||
import logging
|
||||
import torch.nn.functional as F
|
||||
import time
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
@ -23,16 +24,16 @@ class DQNAgent:
|
||||
"""
|
||||
def __init__(self,
|
||||
state_shape: Tuple[int, ...],
|
||||
n_actions: int,
|
||||
learning_rate: float = 0.0005, # Reduced learning rate for more stability
|
||||
gamma: float = 0.97, # Slightly reduced discount factor
|
||||
n_actions: int = 2,
|
||||
learning_rate: float = 0.001,
|
||||
epsilon: float = 1.0,
|
||||
epsilon_min: float = 0.05, # Increased minimum epsilon for more exploration
|
||||
epsilon_decay: float = 0.9975, # Slower decay rate
|
||||
buffer_size: int = 20000, # Increased memory size
|
||||
batch_size: int = 128, # Larger batch size
|
||||
target_update: int = 5, # More frequent target updates
|
||||
device=None): # Device for computations
|
||||
epsilon_min: float = 0.01,
|
||||
epsilon_decay: float = 0.995,
|
||||
buffer_size: int = 10000,
|
||||
batch_size: int = 32,
|
||||
target_update: int = 100,
|
||||
priority_memory: bool = True,
|
||||
device=None):
|
||||
|
||||
# Extract state dimensions
|
||||
if isinstance(state_shape, tuple) and len(state_shape) > 1:
|
||||
@ -48,11 +49,9 @@ class DQNAgent:
|
||||
# Store parameters
|
||||
self.n_actions = n_actions
|
||||
self.learning_rate = learning_rate
|
||||
self.gamma = gamma
|
||||
self.epsilon = epsilon
|
||||
self.epsilon_min = epsilon_min
|
||||
self.epsilon_decay = epsilon_decay
|
||||
self.epsilon_start = epsilon # Store initial epsilon value for resets/bumps
|
||||
self.buffer_size = buffer_size
|
||||
self.batch_size = batch_size
|
||||
self.target_update = target_update
|
||||
@ -127,10 +126,41 @@ class DQNAgent:
|
||||
self.max_confidence = 0.0
|
||||
self.min_confidence = 1.0
|
||||
|
||||
# Enhanced features from EnhancedDQNAgent
|
||||
# Market adaptation capabilities
|
||||
self.market_regime_weights = {
|
||||
'trending': 1.2, # Higher confidence in trending markets
|
||||
'ranging': 0.8, # Lower confidence in ranging markets
|
||||
'volatile': 0.6 # Much lower confidence in volatile markets
|
||||
}
|
||||
|
||||
# Dueling network support (requires enhanced network architecture)
|
||||
self.use_dueling = True
|
||||
|
||||
# Prioritized experience replay parameters
|
||||
self.use_prioritized_replay = priority_memory
|
||||
self.alpha = 0.6 # Priority exponent
|
||||
self.beta = 0.4 # Importance sampling exponent
|
||||
self.beta_increment = 0.001
|
||||
|
||||
# Double DQN support
|
||||
self.use_double_dqn = True
|
||||
|
||||
# Enhanced training features from EnhancedDQNAgent
|
||||
self.target_update_freq = target_update # More descriptive name
|
||||
self.training_steps = 0
|
||||
self.gradient_clip_norm = 1.0 # Gradient clipping
|
||||
|
||||
# Enhanced statistics tracking
|
||||
self.epsilon_history = []
|
||||
self.td_errors = [] # Track TD errors for analysis
|
||||
|
||||
# Trade action fee and confidence thresholds
|
||||
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
|
||||
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
|
||||
self.recent_actions = [] # Track recent actions to avoid oscillations
|
||||
self.recent_actions = deque(maxlen=10)
|
||||
self.recent_prices = deque(maxlen=20)
|
||||
self.recent_rewards = deque(maxlen=100)
|
||||
|
||||
# Violent move detection
|
||||
self.price_history = []
|
||||
@ -173,6 +203,16 @@ class DQNAgent:
|
||||
total_params = sum(p.numel() for p in self.policy_net.parameters())
|
||||
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
|
||||
|
||||
# Position management for 2-action system
|
||||
self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long)
|
||||
self.position_entry_price = 0.0
|
||||
self.position_entry_time = None
|
||||
|
||||
# Different thresholds for entry vs exit decisions
|
||||
self.entry_confidence_threshold = 0.7 # High threshold for new positions
|
||||
self.exit_confidence_threshold = 0.3 # Lower threshold for closing positions
|
||||
self.uncertainty_threshold = 0.1 # When to stay neutral
|
||||
|
||||
def move_models_to_device(self, device=None):
|
||||
"""Move models to the specified device (GPU/CPU)"""
|
||||
if device is not None:
|
||||
@ -290,247 +330,148 @@ class DQNAgent:
|
||||
if len(self.price_movement_memory) > self.buffer_size // 4:
|
||||
self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):]
|
||||
|
||||
def act(self, state: np.ndarray, explore=True, current_price=None, market_context=None) -> int:
    """
    Choose action based on current state using 2-action system with intelligent position management

    Args:
        state: Current market state (NumPy array, or a tensor without a batch dimension)
        explore: Whether to use epsilon-greedy exploration
        current_price: Current market price for position management (optional)
        market_context: Additional market context for decision making (passed through unchanged)

    Returns:
        int: Action (0=SELL, 1=BUY) or None if should hold position
    """
    # Convert state to a batched tensor on the agent's device
    if isinstance(state, np.ndarray):
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
    else:
        state_tensor = state.unsqueeze(0).to(self.device)

    # Action selection is inference only, so no autograd graph is needed
    with torch.no_grad():
        net_output = self.policy_net(state_tensor)
        # replay() unpacks five outputs from the network with the Q-values first;
        # accept both that tuple form and a bare Q-value tensor.
        q_values = net_output[0] if isinstance(net_output, (tuple, list)) else net_output

        # Softmax over the two Q-values gives the SELL / BUY confidence scores
        confidences = torch.softmax(q_values, dim=1)[0]
        sell_confidence = confidences[0].item()
        buy_confidence = confidences[1].item()

    # Determine action based on current position and confidence thresholds
    action = self._determine_action_with_position_management(
        sell_confidence, buy_confidence, current_price, market_context, explore
    )

    # Update tracking (a price of 0 is still a price, so test against None)
    if current_price is not None:
        self.recent_prices.append(current_price)

    if action is None:
        # Return None to indicate HOLD (don't change position)
        return None

    self.recent_actions.append(action)
    # Keep only the five most recent actions. recent_actions is a deque, which
    # does not support slicing, so drop the oldest entries one at a time
    # (this works for a plain list as well).
    while len(self.recent_actions) > 5:
        del self.recent_actions[0]

    return action
|
||||
|
||||
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
    """Choose the greedy action and a confidence score adapted to the market regime.

    Args:
        state: Current market state without a batch dimension.
        market_regime: Key into ``self.market_regime_weights``; unknown regimes
            use a neutral weight of 1.0.

    Returns:
        Tuple of (action index with the highest Q-value, softmax probability of
        that action scaled by the regime weight and capped at 1.0).
    """
    with torch.no_grad():
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        net_output = self.policy_net(state_tensor)
        # replay() unpacks five outputs from the network with the Q-values first;
        # accept both that tuple form and a bare Q-value tensor.
        q_values = net_output[0] if isinstance(net_output, (tuple, list)) else net_output

        # Convert Q-values to probabilities
        action_probs = torch.softmax(q_values, dim=1)
        action = q_values.argmax().item()
        base_confidence = action_probs[0, action].item()

    # Adapt confidence based on market regime
    regime_weight = self.market_regime_weights.get(market_regime, 1.0)
    adapted_confidence = min(base_confidence * regime_weight, 1.0)

    return action, adapted_confidence
|
||||
|
||||
def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
    """
    Determine action based on current position and confidence thresholds

    This implements the intelligent position management where:
    - When neutral: Need high confidence to enter position
    - When in position: Need lower confidence to exit
    - Different thresholds for entry vs exit

    Args:
        sell_conf: Confidence of the SELL action (index 0).
        buy_conf: Confidence of the BUY action (index 1).
        current_price: Current market price; may be None, in which case PnL is
            reported as 0 and the price is logged as "n/a".
        market_context: Accepted for interface compatibility; not used here.
        explore: Whether epsilon-greedy exploration may override the decision.

    Returns:
        0 (SELL), 1 (BUY), or None to hold. Position bookkeeping
        (current_position, position_entry_price, position_entry_time) is
        updated whenever a position is entered, closed or flipped.
    """
    # Apply epsilon-greedy exploration (random action, no position bookkeeping)
    if explore and np.random.random() <= self.epsilon:
        return int(np.random.choice([0, 1]))

    # Get the dominant signal (ties resolve to BUY)
    dominant_action = 0 if sell_conf > buy_conf else 1
    dominant_confidence = max(sell_conf, buy_conf)

    # current_price is optional; formatting None with ':.4f' raises TypeError
    price_text = f"{current_price:.4f}" if current_price is not None else "n/a"

    # No position - need high confidence to enter
    if self.current_position == 0:
        if dominant_confidence < self.entry_confidence_threshold:
            # Not confident enough to enter position
            return None
        if dominant_action == 1:
            self.current_position = 1.0
            side = "LONG"
        else:
            self.current_position = -1.0
            side = "SHORT"
        self.position_entry_price = current_price
        self.position_entry_time = time.time()
        logger.info(f"ENTERING {side} position at {price_text} with confidence {dominant_confidence:.4f}")
        return dominant_action

    # In a position: only the opposite signal can close or flip it
    if self.current_position > 0:
        held_side, opposite_side = "LONG", "SHORT"
        exit_action, flipped_position = 0, -1.0
    elif self.current_position < 0:
        held_side, opposite_side = "SHORT", "LONG"
        exit_action, flipped_position = 1, 1.0
    else:
        # Position is not comparable (e.g. NaN): hold
        return None

    if dominant_action != exit_action:
        # Signal agrees with the held position: keep holding
        return None

    # Relative PnL of the held position; 0 when either price is missing or zero
    if current_price and self.position_entry_price:
        if exit_action == 0:
            pnl = (current_price - self.position_entry_price) / self.position_entry_price
        else:
            pnl = (self.position_entry_price - current_price) / self.position_entry_price
    else:
        pnl = 0

    # Check the stricter entry threshold first: when it is at or above the exit
    # threshold (the defaults), testing the exit threshold first would make the
    # flip branch unreachable.
    if dominant_confidence >= self.entry_confidence_threshold:
        # Very strong opposite signal - close the position and enter the other side
        logger.info(f"FLIPPING from {held_side} to {opposite_side} at {price_text} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
        self.current_position = flipped_position
        self.position_entry_price = current_price
        self.position_entry_time = time.time()
        return exit_action

    if dominant_confidence >= self.exit_confidence_threshold:
        # Opposite signal with enough confidence to close the position
        logger.info(f"CLOSING {held_side} position at {price_text} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
        self.current_position = 0.0
        self.position_entry_price = 0.0
        self.position_entry_time = None
        return exit_action

    # Hold the current position
    return None
|
||||
|
||||
def replay(self, experiences=None):
|
||||
"""Train the model using experiences from memory"""
|
||||
@ -658,10 +599,18 @@ class DQNAgent:
|
||||
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Get next Q values with target network
|
||||
# Enhanced Double DQN implementation
|
||||
with torch.no_grad():
|
||||
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
|
||||
next_q_values = next_q_values.max(1)[0]
|
||||
if self.use_double_dqn:
|
||||
# Double DQN: Use policy network to select actions, target network to evaluate
|
||||
policy_q_values, _, _, _, _ = self.policy_net(next_states)
|
||||
next_actions = policy_q_values.argmax(1)
|
||||
target_q_values_all, _, _, _, _ = self.target_net(next_states)
|
||||
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
|
||||
else:
|
||||
# Standard DQN: Use target network for both selection and evaluation
|
||||
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
|
||||
next_q_values = next_q_values.max(1)[0]
|
||||
|
||||
# Check for dimension mismatch between rewards and next_q_values
|
||||
if rewards.shape[0] != next_q_values.shape[0]:
|
||||
@ -699,16 +648,25 @@ class DQNAgent:
|
||||
# Backward pass
|
||||
total_loss.backward()
|
||||
|
||||
# Clip gradients to avoid exploding gradients
|
||||
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
|
||||
# Enhanced gradient clipping with configurable norm
|
||||
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.gradient_clip_norm)
|
||||
|
||||
# Update weights
|
||||
self.optimizer.step()
|
||||
|
||||
# Update target network if needed
|
||||
self.update_count += 1
|
||||
if self.update_count % self.target_update == 0:
|
||||
# Enhanced target network update tracking
|
||||
self.training_steps += 1
|
||||
if self.training_steps % self.target_update_freq == 0:
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
logger.debug(f"Target network updated at step {self.training_steps}")
|
||||
|
||||
# Enhanced statistics tracking
|
||||
self.epsilon_history.append(self.epsilon)
|
||||
|
||||
# Calculate and store TD error for analysis
|
||||
with torch.no_grad():
|
||||
td_error = torch.abs(current_q_values - target_q_values).mean().item()
|
||||
self.td_errors.append(td_error)
|
||||
|
||||
# Return loss
|
||||
return total_loss.item()
|
||||
@ -1168,4 +1126,40 @@ class DQNAgent:
|
||||
|
||||
logger.info(f"Agent state loaded from {path}_agent_state.pt")
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
|
||||
logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
|
||||
|
||||
def get_position_info(self):
    """Return a snapshot of the tracked position and the entry/exit thresholds.

    Returns:
        dict with keys 'position', 'entry_price', 'entry_time',
        'entry_threshold' and 'exit_threshold', read directly from the agent.
    """
    info = {}
    # Position bookkeeping
    info['position'] = self.current_position
    info['entry_price'] = self.position_entry_price
    info['entry_time'] = self.position_entry_time
    # Confidence thresholds used for entering and exiting positions
    info['entry_threshold'] = self.entry_confidence_threshold
    info['exit_threshold'] = self.exit_confidence_threshold
    return info
|
||||
|
||||
def get_enhanced_training_stats(self):
    """Collect RL training statistics and configuration flags into one dict.

    Returns:
        dict with buffer sizes, exploration state, reward tracking, recent
        TD-error / loss / epsilon windows, specialized buffer sizes and the
        agent's training configuration.
    """
    # Windows over the tracked histories (empty histories fall back to defaults)
    if self.td_errors:
        avg_td_error = np.mean(self.td_errors[-100:])
    else:
        avg_td_error = 0.0

    if self.losses:
        recent_losses = self.losses[-10:]
    else:
        recent_losses = []

    if self.epsilon_history:
        epsilon_trend = self.epsilon_history[-20:]
    else:
        epsilon_trend = []

    if hasattr(self, 'recent_rewards'):
        recent_rewards = list(self.recent_rewards)
    else:
        recent_rewards = []

    # Sizes of the specialized replay buffers
    specialized_buffers = {
        'extrema_memory': len(self.extrema_memory),
        'positive_memory': len(self.positive_memory),
        'price_movement_memory': len(self.price_movement_memory),
    }

    stats = {
        'buffer_size': len(self.memory),
        'epsilon': self.epsilon,
        'avg_reward': self.avg_reward,
        'best_reward': self.best_reward,
        'recent_rewards': recent_rewards,
        'no_improvement_count': self.no_improvement_count,
    }
    # Enhanced statistics from EnhancedDQNAgent
    stats['training_steps'] = self.training_steps
    stats['avg_td_error'] = avg_td_error
    stats['recent_losses'] = recent_losses
    stats['epsilon_trend'] = epsilon_trend
    stats['specialized_buffers'] = specialized_buffers
    stats['market_regime_weights'] = self.market_regime_weights
    stats['use_double_dqn'] = self.use_double_dqn
    stats['use_prioritized_replay'] = self.use_prioritized_replay
    stats['gradient_clip_norm'] = self.gradient_clip_norm
    stats['target_update_frequency'] = self.target_update_freq
    return stats
|
Reference in New Issue
Block a user