gogo2/NN/models/dqn_agent.py
Dobromir Popov 4eac14022c RL training
2025-03-31 03:31:54 +03:00

170 lines
5.9 KiB
Python

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import sys
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from NN.models.simple_cnn import CNNModelPyTorch
class DQNAgent:
    """Deep Q-Network agent for trading that uses ``CNNModelPyTorch`` networks.

    The agent keeps two networks with identical constructor arguments: a
    policy network that is optimized, and a target network that supplies
    the bootstrap values and is overwritten with the policy weights every
    ``target_update`` training steps. Experiences are kept in a bounded
    ``deque`` and sampled uniformly at random for training.

    Both networks are called as ``q_values, _ = net(batch)``; only the first
    element of the returned pair is used here.
    # NOTE(review): the first output is treated as per-action Q-values with
    # shape (batch, action_size) -- confirm against CNNModelPyTorch.

    ``losses`` is persisted by ``save``/``load`` but is not appended to by
    any method of this class.
    """

    def __init__(self,
                 state_size: int,
                 action_size: int,
                 window_size: int,
                 num_features: int,
                 timeframes: List[str],
                 learning_rate: float = 0.001,
                 gamma: float = 0.99,
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.01,
                 epsilon_decay: float = 0.995,
                 memory_size: int = 10000,
                 batch_size: int = 64,
                 target_update: int = 10):
        """Build the networks, optimizer and replay memory.

        Args:
            state_size: Stored on the instance; not otherwise used here.
            action_size: Number of discrete actions; passed to the networks
                as ``output_size`` and used for random exploration.
            window_size: Forwarded to ``CNNModelPyTorch``.
            num_features: Forwarded to ``CNNModelPyTorch``.
            timeframes: Forwarded to ``CNNModelPyTorch``.
            learning_rate: Learning rate of the Adam optimizer.
            gamma: Discount factor applied to the next-state value.
            epsilon: Initial probability of choosing a random action.
            epsilon_min: Lower bound for ``epsilon`` after decay.
            epsilon_decay: Factor ``epsilon`` is multiplied by after each
                training step.
            memory_size: Maximum number of stored experiences.
            batch_size: Number of experiences sampled per training step.
            target_update: Number of training steps between target-network
                synchronizations.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.window_size = window_size
        self.num_features = num_features
        self.timeframes = timeframes
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.target_update = target_update

        # Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize networks
        self.policy_net = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=action_size,
            timeframes=timeframes
        ).to(self.device)
        self.target_net = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=action_size,
            timeframes=timeframes
        ).to(self.device)
        # Start both networks from identical weights.
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Initialize optimizer (only the policy network is trained)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)

        # Initialize memory; the oldest experiences are dropped once full
        self.memory = deque(maxlen=memory_size)

        # Training metrics
        self.update_count = 0
        self.losses = []

    def _to_tensor(self, values, dtype) -> torch.Tensor:
        """Convert a sequence to a tensor of NumPy ``dtype`` on ``self.device``."""
        # Going through NumPy with an explicit dtype accepts Python scalars,
        # NumPy scalars and arrays alike.
        return torch.as_tensor(np.asarray(values, dtype=dtype)).to(self.device)

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool):
        """Store one ``(state, action, reward, next_state, done)`` experience."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state: np.ndarray) -> int:
        """Choose an action using an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action index is
        returned; otherwise the index of the largest policy-network output
        for ``state`` is returned.

        Args:
            state: A single observation; a batch dimension is added before
                it is passed to the policy network.

        Returns:
            The chosen action index.
        """
        if random.random() < self.epsilon:
            return random.randrange(self.action_size)

        with torch.no_grad():
            state_batch = self._to_tensor(state, np.float32).unsqueeze(0)
            q_values, _ = self.policy_net(state_batch)
            return q_values.argmax().item()

    def replay(self) -> float:
        """Train the policy network on one random batch of experiences.

        Each call performs one optimizer step, increments ``update_count``,
        copies the policy weights into the target network when
        ``update_count`` is a multiple of ``target_update``, and decays
        ``epsilon`` (never below ``epsilon_min``).

        Returns:
            The batch loss as a float, or ``0.0`` without training when
            fewer than ``batch_size`` experiences are stored.
        """
        if len(self.memory) < self.batch_size:
            return 0.0

        # Sample batch
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert to tensors and move to device
        states = self._to_tensor(states, np.float32)
        actions = self._to_tensor(actions, np.int64)
        rewards = self._to_tensor(rewards, np.float32)
        next_states = self._to_tensor(next_states, np.float32)
        dones = self._to_tensor(dones, np.float32)

        # Q-value of the action actually taken in each experience.
        # squeeze(1) removes only the gathered column, so the result stays
        # 1-D (and matches the target's shape) even when batch_size == 1.
        q_all, _ = self.policy_net(states)
        current_q_values = q_all.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrap target from the target network; terminal transitions
        # (done == 1) contribute the reward only.
        with torch.no_grad():
            next_q_all, _ = self.target_net(next_states)
            next_q_values = next_q_all.max(1)[0]
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute loss
        loss = nn.MSELoss()(current_q_values, target_q_values)

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update target network if needed
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        return loss.item()

    def save(self, path: str):
        """Save both networks and the agent state.

        Writes through ``policy_net.save(f"{path}_policy")`` and
        ``target_net.save(f"{path}_target")``, and stores epsilon, the update
        counter, the loss list and the optimizer state in
        ``f"{path}_agent_state.pt"``.

        Args:
            path: Path prefix for the written files. Its parent directory is
                created when missing; a bare file name (no directory part)
                is accepted and writes into the current working directory.
        """
        directory = os.path.dirname(path)
        # os.makedirs('') raises FileNotFoundError, so skip it for bare names.
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Save policy network
        self.policy_net.save(f"{path}_policy")
        # Save target network
        self.target_net.save(f"{path}_target")

        # Save agent state
        state = {
            'epsilon': self.epsilon,
            'update_count': self.update_count,
            'losses': self.losses,
            'optimizer_state': self.optimizer.state_dict()
        }
        torch.save(state, f"{path}_agent_state.pt")

    def load(self, path: str):
        """Load both networks and the agent state written by ``save``.

        Args:
            path: The same path prefix that was passed to ``save``.
        """
        # Load policy network
        self.policy_net.load(f"{path}_policy")
        # Load target network
        self.target_net.load(f"{path}_target")

        # Load agent state. map_location remaps stored tensors onto this
        # agent's device, so a checkpoint written on a CUDA machine can be
        # restored on a CPU-only one.
        state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
        self.epsilon = state['epsilon']
        self.update_count = state['update_count']
        self.losses = state['losses']
        self.optimizer.load_state_dict(state['optimizer_state'])