import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import sys
import logging

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Import the EnhancedCNN model and example-sifting replay buffer
from NN.models.enhanced_cnn import EnhancedCNN, ExampleSiftingDataset

# Configure logger
logger = logging.getLogger(__name__)

class EnhancedDQNAgent:
    """
    Enhanced Deep Q-Network agent for trading.

    Uses the improved EnhancedCNN model with residual connections and attention mechanisms.
    """

    def __init__(self,
                 state_shape: Tuple[int, ...],
                 n_actions: int,
                 learning_rate: float = 0.0003,      # Slightly reduced learning rate for stability
                 gamma: float = 0.95,                # Discount factor
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.05,
                 epsilon_decay: float = 0.995,       # Slower decay for more exploration
                 buffer_size: int = 50000,           # Larger memory buffer
                 batch_size: int = 128,              # Larger batch size
                 target_update: int = 10,            # More frequent target updates
                 confidence_threshold: float = 0.4,  # Lower confidence threshold
                 device=None):

        # Extract state dimensions
        if isinstance(state_shape, tuple) and len(state_shape) > 1:
            # Multi-dimensional state (like an image or a sequence)
            self.state_dim = state_shape
        else:
            # 1D state
            if isinstance(state_shape, tuple):
                self.state_dim = state_shape[0]
            else:
                self.state_dim = state_shape

        # Store parameters
        self.n_actions = n_actions
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.target_update = target_update
        self.confidence_threshold = confidence_threshold

        # Set device for computation
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # Initialize policy and target networks with the enhanced CNN
        self.policy_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)
        self.target_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)

        # Initialize the target network with the same weights as the policy network
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Set the target network to eval mode (important for batch norm, dropout)
        self.target_net.eval()

        # Optimization components
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

        # Experience replay memory with example sifting
        self.memory = ExampleSiftingDataset(max_examples=buffer_size)
        self.update_count = 0

        # Confidence tracking
        self.confidence_history = []
        self.avg_confidence = 0.0
        self.max_confidence = 0.0
        self.min_confidence = 1.0

        # Performance tracking
        self.losses = []
        self.rewards = []
        self.avg_reward = 0.0

        # Check whether mixed precision training should be used
        self.use_mixed_precision = False
        if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
            self.use_mixed_precision = True
            self.scaler = torch.cuda.amp.GradScaler()
            logger.info("Mixed precision training enabled")
        else:
            logger.info("Mixed precision training disabled")

        # For compatibility with old code
        self.action_size = n_actions

        logger.info(f"Enhanced DQN Agent using device: {self.device}")
        logger.info(f"Confidence threshold set to {self.confidence_threshold}")

    def move_models_to_device(self, device=None):
        """Move models to the specified device (GPU/CPU)"""
        if device is not None:
            self.device = device

        try:
            self.policy_net = self.policy_net.to(self.device)
            self.target_net = self.target_net.to(self.device)
            logger.info(f"Moved models to {self.device}")
            return True
        except Exception as e:
            logger.error(f"Failed to move models to {self.device}: {str(e)}")
            return False

    def _normalize_state(self, state):
        """Normalize state for better training stability"""
        try:
            # Convert to numpy array if needed
            if isinstance(state, list):
                state = np.array(state, dtype=np.float32)

            # Apply normalization based on state shape
            if len(state.shape) > 1:
                # Multi-dimensional state - normalize each feature dimension separately
                for i in range(state.shape[0]):
                    # Skip all-zero rows (to avoid division by zero)
                    if np.sum(np.abs(state[i])) > 0:
                        # Standardize each feature dimension
                        mean = np.mean(state[i])
                        std = np.std(state[i])
                        if std > 0:
                            state[i] = (state[i] - mean) / std
            else:
                # 1D state vector
                # Skip if all zeros
                if np.sum(np.abs(state)) > 0:
                    mean = np.mean(state)
                    std = np.std(state)
                    if std > 0:
                        state = (state - mean) / std

            return state
        except Exception as e:
            logger.warning(f"Error normalizing state: {str(e)}")
            return state

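    # Worked example of the standardization above (illustrative values only):
    # a feature row [1.0, 2.0, 3.0] has mean 2.0 and std ~0.816, so it maps to
    # roughly [-1.22, 0.0, 1.22]; an all-zero row is left untouched to avoid
    # dividing by zero.
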
    def remember(self, state, action, reward, next_state, done):
        """Store experience in memory with example sifting"""
        self.memory.add_example(state, action, reward, next_state, done)

        # Also track rewards for monitoring (keep the most recent 100)
        self.rewards.append(reward)
        if len(self.rewards) > 100:
            self.rewards = self.rewards[-100:]
        self.avg_reward = np.mean(self.rewards)

    def act(self, state, explore=True):
        """Choose action using epsilon-greedy policy with built-in confidence thresholding"""
        if explore and random.random() < self.epsilon:
            return random.randrange(self.n_actions), 0.0  # Return action and zero confidence

        # Normalize state before inference
        normalized_state = self._normalize_state(state)

        # Use the EnhancedCNN's act method, which includes confidence thresholding
        action, confidence = self.policy_net.act(normalized_state, explore=explore)

        # Track confidence history (keep the most recent 100)
        self.confidence_history.append(confidence)
        if len(self.confidence_history) > 100:
            self.confidence_history = self.confidence_history[-100:]

        # Update confidence metrics
        self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
        self.max_confidence = max(self.max_confidence, confidence)
        self.min_confidence = min(self.min_confidence, confidence)

        # Log confidence metrics occasionally
        if random.random() < 0.01:  # 1% of the time
            logger.info(f"Confidence metrics - Current: {confidence:.4f}, Avg: {self.avg_confidence:.4f}, " +
                        f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")

        return action, confidence

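    # Illustrative call (the variable names are hypothetical): during training,
    #     action, confidence = agent.act(state, explore=True)
    # returns a random action with confidence 0.0 with probability epsilon, and
    # otherwise the policy network's chosen action together with its confidence score.
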
    def replay(self):
        """Train the model using experience replay with high-quality examples"""
        # Check if there are enough samples in memory
        if len(self.memory) < self.batch_size:
            return 0.0

        # Get a batch of experiences
        batch = self.memory.get_batch(self.batch_size)
        if batch is None:
            return 0.0

        states = torch.FloatTensor(batch['states']).to(self.device)
        actions = torch.LongTensor(batch['actions']).to(self.device)
        rewards = torch.FloatTensor(batch['rewards']).to(self.device)
        next_states = torch.FloatTensor(batch['next_states']).to(self.device)
        dones = torch.FloatTensor(batch['dones']).to(self.device)

        # Set the policy network to training mode
        self.policy_net.train()

        if self.use_mixed_precision:
            with torch.cuda.amp.autocast():
                # Get current Q values
                q_values, _, _, _ = self.policy_net(states)
                current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

                # Compute target Q values
                with torch.no_grad():
                    self.target_net.eval()
                    next_q_values, _, _, _ = self.target_net(next_states)
                    next_q = next_q_values.max(1)[0]
                    target_q = rewards + (1 - dones) * self.gamma * next_q

                # Compute loss
                loss = self.criterion(current_q, target_q)

            # Perform backpropagation with mixed precision
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            # Standard precision training
            # Get current Q values
            q_values, _, _, _ = self.policy_net(states)
            current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

            # Compute target Q values
            with torch.no_grad():
                self.target_net.eval()
                next_q_values, _, _, _ = self.target_net(next_states)
                next_q = next_q_values.max(1)[0]
                target_q = rewards + (1 - dones) * self.gamma * next_q

            # Compute loss
            loss = self.criterion(current_q, target_q)

            # Perform backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.optimizer.step()

        # Track loss (keep the most recent 100)
        loss_value = loss.item()
        self.losses.append(loss_value)
        if len(self.losses) > 100:
            self.losses = self.losses[-100:]

        # Update target network periodically
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
            logger.info(f"Updated target network (step {self.update_count})")

        # Decay epsilon
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        return loss_value

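    # The target above is the standard one-step DQN Bellman target:
    #     target_q = reward + gamma * (1 - done) * max_a' Q_target(next_state, a')
    # For example, with reward 1.0, gamma 0.95, done = 0 and a best next-state
    # Q value of 2.0, the target is 1.0 + 0.95 * 2.0 = 2.9; for a terminal
    # transition (done = 1) the target is just the reward.
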
    def save(self, path):
        """Save agent state and models"""
        self.policy_net.save(f"{path}_policy")
        self.target_net.save(f"{path}_target")

        # Save agent state
        torch.save({
            'epsilon': self.epsilon,
            'confidence_threshold': self.confidence_threshold,
            'losses': self.losses,
            'rewards': self.rewards,
            'avg_reward': self.avg_reward,
            'confidence_history': self.confidence_history,
            'avg_confidence': self.avg_confidence,
            'max_confidence': self.max_confidence,
            'min_confidence': self.min_confidence,
            'update_count': self.update_count
        }, f"{path}_agent_state.pt")

        logger.info(f"Agent state saved to {path}_agent_state.pt")

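    # For example, save("checkpoints/dqn") (a hypothetical path) writes the policy
    # and target network files produced by EnhancedCNN.save (whatever extension that
    # method uses) plus "checkpoints/dqn_agent_state.pt" with the agent's bookkeeping.
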
    def load(self, path):
        """Load agent state and models"""
        policy_loaded = self.policy_net.load(f"{path}_policy")
        target_loaded = self.target_net.load(f"{path}_target")

        # Load agent state if available
        agent_state_path = f"{path}_agent_state.pt"
        if os.path.exists(agent_state_path):
            try:
                # map_location lets a checkpoint saved on GPU be restored on CPU-only machines
                state = torch.load(agent_state_path, map_location=self.device)
                self.epsilon = state.get('epsilon', self.epsilon)
                self.confidence_threshold = state.get('confidence_threshold', self.confidence_threshold)
                self.policy_net.confidence_threshold = self.confidence_threshold
                self.target_net.confidence_threshold = self.confidence_threshold
                self.losses = state.get('losses', [])
                self.rewards = state.get('rewards', [])
                self.avg_reward = state.get('avg_reward', 0.0)
                self.confidence_history = state.get('confidence_history', [])
                self.avg_confidence = state.get('avg_confidence', 0.0)
                self.max_confidence = state.get('max_confidence', 0.0)
                self.min_confidence = state.get('min_confidence', 1.0)
                self.update_count = state.get('update_count', 0)
                logger.info(f"Agent state loaded from {agent_state_path}")
            except Exception as e:
                logger.error(f"Error loading agent state: {str(e)}")

        return policy_loaded and target_loaded
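

# Minimal usage sketch (illustrative only). It assumes EnhancedCNN accepts flat 1D
# states of the given size; the state size, reward signal, episode length, and
# checkpoint path below are hypothetical, with random data standing in for a real
# trading environment.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    state_size = 64  # hypothetical feature-vector length
    agent = EnhancedDQNAgent(state_shape=(state_size,), n_actions=3)

    state = np.random.randn(state_size).astype(np.float32)
    for step in range(256):
        # Epsilon-greedy action selection with confidence score
        action, confidence = agent.act(state, explore=True)

        # Placeholder transition standing in for an environment step
        next_state = np.random.randn(state_size).astype(np.float32)
        reward = float(np.random.randn())
        done = (step % 64 == 63)

        agent.remember(state, action, reward, next_state, done)
        loss = agent.replay()  # returns 0.0 until the buffer holds a full batch
        if step % 64 == 0:
            logger.info(f"step={step} loss={loss:.4f} epsilon={agent.epsilon:.3f}")

        state = np.random.randn(state_size).astype(np.float32) if done else next_state

    agent.save("checkpoints/enhanced_dqn")  # hypothetical checkpoint path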