# gogo2/NN/models/dqn_agent.py

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import sys
import logging
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from NN.models.simple_cnn import CNNModelPyTorch

# Configure logger
logger = logging.getLogger(__name__)

class DQNAgent:
    """
    Deep Q-Network agent for trading
    Uses CNN model as the base network
    """
    def __init__(self,
                 state_size: int,
                 action_size: int,
                 window_size: int,
                 num_features: int,
                 timeframes: List[str],
                 learning_rate: float = 0.001,
                 gamma: float = 0.99,
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.01,
                 epsilon_decay: float = 0.995,
                 memory_size: int = 10000,
                 batch_size: int = 64,
                 target_update: int = 10):
        self.state_size = state_size
        self.action_size = action_size
        self.window_size = window_size
        self.num_features = num_features
        self.timeframes = timeframes
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.target_update = target_update

        # Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize networks
        self.policy_net = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=action_size,
            timeframes=timeframes
        ).to(self.device)
        self.target_net = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=action_size,
            timeframes=timeframes
        ).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Initialize optimizer
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)

        # Initialize memory
        self.memory = deque(maxlen=memory_size)
        # Special memory for extrema samples to use for targeted learning
        self.extrema_memory = deque(maxlen=memory_size // 5)  # Smaller size for extrema examples

        # Training metrics
        self.update_count = 0
        self.losses = []

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool, is_extrema: bool = False):
        """
        Store experience in memory

        Args:
            state: Current state
            action: Action taken
            reward: Reward received
            next_state: Next state
            done: Whether episode is done
            is_extrema: Whether this is a local extrema sample (for specialized learning)
        """
        experience = (state, action, reward, next_state, done)
        self.memory.append(experience)

        # If this is an extrema sample, also add to specialized memory
        if is_extrema:
            self.extrema_memory.append(experience)

    def act(self, state: np.ndarray) -> int:
        """Choose action using epsilon-greedy policy"""
        if random.random() < self.epsilon:
            return random.randrange(self.action_size)

        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            action_probs, extrema_pred = self.policy_net(state)
            return action_probs.argmax().item()
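
    # act() takes a single state without a batch dimension; unsqueeze(0) above adds
    # the batch axis before the forward pass. Setting epsilon to epsilon_min (or 0)
    # makes the policy effectively greedy, e.g. for evaluation runs.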

    def replay(self, use_extrema: bool = False) -> float:
        """
        Train on a batch of experiences

        Args:
            use_extrema: Whether to include extrema samples in training

        Returns:
            float: Loss value
        """
        if len(self.memory) < self.batch_size:
            return 0.0

        # Sample batch - mix regular and extrema samples
        if use_extrema and len(self.extrema_memory) > self.batch_size // 4:
            # Get some extrema samples
            extrema_count = min(self.batch_size // 3, len(self.extrema_memory))
            extrema_samples = random.sample(list(self.extrema_memory), extrema_count)
            # Get regular samples for the rest
            regular_count = self.batch_size - extrema_count
            regular_samples = random.sample(list(self.memory), regular_count)
            # Combine samples
            batch = extrema_samples + regular_samples
        else:
            # Standard sampling
            batch = random.sample(self.memory, self.batch_size)

        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert to tensors and move to device
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # Get current Q values for the actions that were taken
        current_q_values, extrema_pred = self.policy_net(states)
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1))

        # Get next Q values from the target network
        # Q-learning target: r + gamma * max_a Q_target(s', a), zeroed for terminal states
        with torch.no_grad():
            next_q_values, _ = self.target_net(next_states)
            next_q_values = next_q_values.max(1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute Q-learning loss
        q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)

        # If we have extrema labels (not in this implementation yet),
        # we could add an additional loss for extrema prediction.
        # This would require labels for whether each state is near an extremum.

        # Total loss is just the Q-learning loss for now
        loss = q_loss

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update target network if needed
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        # Track loss for monitoring and checkpointing
        self.losses.append(loss.item())

        return loss.item()
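
    # When use_extrema is True and enough extrema samples exist, roughly a third of
    # the batch comes from extrema_memory and the remainder from the regular buffer;
    # the target network is re-synced every `target_update` calls to replay().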

    def train_on_extrema(self, states, actions, rewards, next_states, dones):
        """
        Special training method focused on extrema patterns

        Args:
            states: Array of states near extrema points
            actions: Correct actions to take (buy at bottoms, sell at tops)
            rewards: Rewards for each action
            next_states: Next states
            dones: Done flags
        """
        if len(states) == 0:
            return 0.0

        # Convert to tensors
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # Forward pass
        current_q_values, extrema_pred = self.policy_net(states)
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1))

        # Get next Q values
        with torch.no_grad():
            next_q_values, _ = self.target_net(next_states)
            next_q_values = next_q_values.max(1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Q-learning loss (no extra weighting is applied to extrema samples here)
        q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)

        # Full loss is just the Q-learning loss
        loss = q_loss

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()
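
    # Unlike replay(), this method does not sample from the replay buffers, does not
    # decay epsilon, and does not advance update_count toward a target-network sync;
    # it performs a single Q-update on the batch supplied by the caller.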

    def save(self, path: str):
        """Save model and agent state"""
        # Guard against paths with no directory component
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)

        # Save policy network
        self.policy_net.save(f"{path}_policy")
        # Save target network
        self.target_net.save(f"{path}_target")

        # Save agent state
        state = {
            'epsilon': self.epsilon,
            'update_count': self.update_count,
            'losses': self.losses,
            'optimizer_state': self.optimizer.state_dict()
        }
        torch.save(state, f"{path}_agent_state.pt")

    def load(self, path: str):
        """Load model and agent state"""
        # Load policy network
        self.policy_net.load(f"{path}_policy")
        # Load target network
        self.target_net.load(f"{path}_target")

        # Load agent state (map to the agent's device so CPU/GPU checkpoints interchange)
        state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
        self.epsilon = state['epsilon']
        self.update_count = state['update_count']
        self.losses = state['losses']
        self.optimizer.load_state_dict(state['optimizer_state'])
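

# --- Usage sketch (illustrative only, not part of the original module) ---
# Assumes CNNModelPyTorch accepts inputs shaped (batch, window_size, num_features)
# and returns (action_values, extrema_pred), as the agent's calls above imply.
# The data below is random noise, purely to exercise the agent's API.
if __name__ == "__main__":
    window_size, num_features = 20, 5
    timeframes = ["1m", "5m", "1h"]
    agent = DQNAgent(
        state_size=window_size * num_features,
        action_size=3,  # e.g. buy / hold / sell
        window_size=window_size,
        num_features=num_features,
        timeframes=timeframes,
    )

    state = np.random.randn(window_size, num_features).astype(np.float32)
    for step in range(200):
        action = agent.act(state)
        next_state = np.random.randn(window_size, num_features).astype(np.float32)
        reward = float(np.random.randn())   # placeholder reward
        done = (step + 1) % 50 == 0         # end an "episode" every 50 steps
        agent.remember(state, action, reward, next_state, done)
        loss = agent.replay()
        state = next_state

    print(f"Final epsilon: {agent.epsilon:.3f}, last loss: {loss:.5f}")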