# gogo2/NN/models/dqn_agent_enhanced.py
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import sys
import logging
import torch.nn.functional as F

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Import the EnhancedCNN model
from NN.models.enhanced_cnn import EnhancedCNN, ExampleSiftingDataset

# Configure logger
logger = logging.getLogger(__name__)


class EnhancedDQNAgent:
    """
    Enhanced Deep Q-Network agent for trading.
    Uses the improved EnhancedCNN model with residual connections and attention mechanisms.
    """

    def __init__(self,
                 state_shape: Tuple[int, ...],
                 n_actions: int,
                 learning_rate: float = 0.0003,      # Slightly reduced learning rate for stability
                 gamma: float = 0.95,                # Discount factor
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.05,
                 epsilon_decay: float = 0.995,       # Slower decay for more exploration
                 buffer_size: int = 50000,           # Larger memory buffer
                 batch_size: int = 128,              # Larger batch size
                 target_update: int = 10,            # More frequent target updates
                 confidence_threshold: float = 0.4,  # Lower confidence threshold
                 device=None):
        # Extract state dimensions
        if isinstance(state_shape, tuple) and len(state_shape) > 1:
            # Multi-dimensional state (like image or sequence)
            self.state_dim = state_shape
        else:
            # 1D state
            if isinstance(state_shape, tuple):
                self.state_dim = state_shape[0]
            else:
                self.state_dim = state_shape

        # Store parameters
        self.n_actions = n_actions
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.target_update = target_update
        self.confidence_threshold = confidence_threshold

        # Set device for computation
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # Initialize models with the enhanced CNN
        self.policy_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)
        self.target_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)

        # Initialize the target network with the same weights as the policy network
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Set the target network to eval mode (important for batch norm, dropout)
        self.target_net.eval()

        # Optimization components
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

        # Experience replay memory with example sifting
        self.memory = ExampleSiftingDataset(max_examples=buffer_size)
        self.update_count = 0

        # Confidence tracking
        self.confidence_history = []
        self.avg_confidence = 0.0
        self.max_confidence = 0.0
        self.min_confidence = 1.0

        # Performance tracking
        self.losses = []
        self.rewards = []
        self.avg_reward = 0.0

        # Check if mixed precision training should be used
        self.use_mixed_precision = False
        if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
            self.use_mixed_precision = True
            self.scaler = torch.cuda.amp.GradScaler()
            logger.info("Mixed precision training enabled")
        else:
            logger.info("Mixed precision training disabled")

        # For compatibility with old code
        self.action_size = n_actions

        logger.info(f"Enhanced DQN Agent using device: {self.device}")
        logger.info(f"Confidence threshold set to {self.confidence_threshold}")

    def move_models_to_device(self, device=None):
        """Move models to the specified device (GPU/CPU)"""
        if device is not None:
            self.device = device
        try:
            self.policy_net = self.policy_net.to(self.device)
            self.target_net = self.target_net.to(self.device)
            logger.info(f"Moved models to {self.device}")
            return True
        except Exception as e:
            logger.error(f"Failed to move models to {self.device}: {str(e)}")
            return False

    def _normalize_state(self, state):
        """Normalize state for better training stability"""
        try:
            # Convert to numpy array if needed
            if isinstance(state, list):
                state = np.array(state, dtype=np.float32)

            # Apply normalization based on state shape
            if len(state.shape) > 1:
                # Multi-dimensional state - normalize each feature dimension separately
                for i in range(state.shape[0]):
                    # Skip if all zeros (to avoid division by zero)
                    if np.sum(np.abs(state[i])) > 0:
                        # Standardize each feature dimension
                        mean = np.mean(state[i])
                        std = np.std(state[i])
                        if std > 0:
                            state[i] = (state[i] - mean) / std
            else:
                # 1D state vector - skip if all zeros (to avoid division by zero)
                if np.sum(np.abs(state)) > 0:
                    mean = np.mean(state)
                    std = np.std(state)
                    if std > 0:
                        state = (state - mean) / std

            return state
        except Exception as e:
            logger.warning(f"Error normalizing state: {str(e)}")
            return state

    def remember(self, state, action, reward, next_state, done):
        """Store experience in memory with example sifting"""
        self.memory.add_example(state, action, reward, next_state, done)

        # Also track rewards for monitoring
        self.rewards.append(reward)
        if len(self.rewards) > 100:
            self.rewards = self.rewards[-100:]
        self.avg_reward = np.mean(self.rewards)

    def act(self, state, explore=True):
        """Choose action using epsilon-greedy policy with built-in confidence thresholding"""
        if explore and random.random() < self.epsilon:
            return random.randrange(self.n_actions), 0.0  # Return action and zero confidence

        # Normalize state before inference
        normalized_state = self._normalize_state(state)

        # Use the EnhancedCNN's act method, which includes confidence thresholding
        action, confidence = self.policy_net.act(normalized_state, explore=explore)

        # Track confidence metrics
        self.confidence_history.append(confidence)
        if len(self.confidence_history) > 100:
            self.confidence_history = self.confidence_history[-100:]

        # Update confidence metrics
        self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
        self.max_confidence = max(self.max_confidence, confidence)
        self.min_confidence = min(self.min_confidence, confidence)

        # Log confidence metrics occasionally (about 1% of calls)
        if random.random() < 0.01:
            logger.info(f"Confidence metrics - Current: {confidence:.4f}, Avg: {self.avg_confidence:.4f}, "
                        f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")

        return action, confidence

    def replay(self):
        """Train the model using experience replay with high-quality examples"""
        # Check if there are enough samples in memory
        if len(self.memory) < self.batch_size:
            return 0.0

        # Get a batch of experiences
        batch = self.memory.get_batch(self.batch_size)
        if batch is None:
            return 0.0

        states = torch.FloatTensor(batch['states']).to(self.device)
        actions = torch.LongTensor(batch['actions']).to(self.device)
        rewards = torch.FloatTensor(batch['rewards']).to(self.device)
        next_states = torch.FloatTensor(batch['next_states']).to(self.device)
        dones = torch.FloatTensor(batch['dones']).to(self.device)

        # Set the policy network to training mode
        self.policy_net.train()

        if self.use_mixed_precision:
            with torch.cuda.amp.autocast():
                # Get current Q values
                q_values, _, _, _ = self.policy_net(states)
                current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

                # Compute target Q values
                with torch.no_grad():
                    self.target_net.eval()
                    next_q_values, _, _, _ = self.target_net(next_states)
                    next_q = next_q_values.max(1)[0]
                    target_q = rewards + (1 - dones) * self.gamma * next_q

                # Compute loss
                loss = self.criterion(current_q, target_q)

            # Perform backpropagation with mixed precision
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            # Standard precision training
            # Get current Q values
            q_values, _, _, _ = self.policy_net(states)
            current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

            # Compute target Q values
            with torch.no_grad():
                self.target_net.eval()
                next_q_values, _, _, _ = self.target_net(next_states)
                next_q = next_q_values.max(1)[0]
                target_q = rewards + (1 - dones) * self.gamma * next_q

            # Compute loss
            loss = self.criterion(current_q, target_q)

            # Perform backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.optimizer.step()

        # Track loss
        loss_value = loss.item()
        self.losses.append(loss_value)
        if len(self.losses) > 100:
            self.losses = self.losses[-100:]

        # Update target network
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
            logger.info(f"Updated target network (step {self.update_count})")

        # Decay epsilon
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        return loss_value

    def save(self, path):
        """Save agent state and models"""
        self.policy_net.save(f"{path}_policy")
        self.target_net.save(f"{path}_target")

        # Save agent state
        torch.save({
            'epsilon': self.epsilon,
            'confidence_threshold': self.confidence_threshold,
            'losses': self.losses,
            'rewards': self.rewards,
            'avg_reward': self.avg_reward,
            'confidence_history': self.confidence_history,
            'avg_confidence': self.avg_confidence,
            'max_confidence': self.max_confidence,
            'min_confidence': self.min_confidence,
            'update_count': self.update_count
        }, f"{path}_agent_state.pt")

        logger.info(f"Agent state saved to {path}_agent_state.pt")

    def load(self, path):
        """Load agent state and models"""
        policy_loaded = self.policy_net.load(f"{path}_policy")
        target_loaded = self.target_net.load(f"{path}_target")

        # Load agent state if available
        agent_state_path = f"{path}_agent_state.pt"
        if os.path.exists(agent_state_path):
            try:
                state = torch.load(agent_state_path)
                self.epsilon = state.get('epsilon', self.epsilon)
                self.confidence_threshold = state.get('confidence_threshold', self.confidence_threshold)
                self.policy_net.confidence_threshold = self.confidence_threshold
                self.target_net.confidence_threshold = self.confidence_threshold
                self.losses = state.get('losses', [])
                self.rewards = state.get('rewards', [])
                self.avg_reward = state.get('avg_reward', 0.0)
                self.confidence_history = state.get('confidence_history', [])
                self.avg_confidence = state.get('avg_confidence', 0.0)
                self.max_confidence = state.get('max_confidence', 0.0)
                self.min_confidence = state.get('min_confidence', 1.0)
                self.update_count = state.get('update_count', 0)
                logger.info(f"Agent state loaded from {agent_state_path}")
            except Exception as e:
                logger.error(f"Error loading agent state: {str(e)}")

        return policy_loaded and target_loaded
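

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of how this agent might be exercised end to end, assuming
# EnhancedCNN accepts a 1D state dimension. The feature count (50), action
# count (3), and random transition below are hypothetical placeholders chosen
# for illustration, not values taken from this project.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Hypothetical dimensions: 50-feature state vector, 3 discrete actions
    agent = EnhancedDQNAgent(state_shape=(50,), n_actions=3)

    # Fabricated random transition purely to exercise the API
    state = np.random.randn(50).astype(np.float32)
    next_state = np.random.randn(50).astype(np.float32)

    action, confidence = agent.act(state)
    agent.remember(state, action, reward=0.0, next_state=next_state, done=False)

    # replay() returns 0.0 until the sifting buffer holds at least batch_size examples
    loss = agent.replay()
    logger.info(f"action={action}, confidence={confidence:.4f}, loss={loss:.6f}")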