RL training
NN/models/dqn_agent.py (new normal file, 170 lines added)
@@ -0,0 +1,170 @@
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import sys

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from NN.models.simple_cnn import CNNModelPyTorch


class DQNAgent:
    """
    Deep Q-Network agent for trading
    Uses CNN model as the base network
    """
    def __init__(self,
                 state_size: int,
                 action_size: int,
                 window_size: int,
                 num_features: int,
                 timeframes: List[str],
                 learning_rate: float = 0.001,
                 gamma: float = 0.99,
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.01,
                 epsilon_decay: float = 0.995,
                 memory_size: int = 10000,
                 batch_size: int = 64,
                 target_update: int = 10):

        self.state_size = state_size
        self.action_size = action_size
        self.window_size = window_size
        self.num_features = num_features
        self.timeframes = timeframes
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.target_update = target_update

        # Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize networks
        self.policy_net = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=action_size,
            timeframes=timeframes
        ).to(self.device)

        self.target_net = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=action_size,
            timeframes=timeframes
        ).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Initialize optimizer
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)

        # Initialize replay memory
        self.memory = deque(maxlen=memory_size)

        # Training metrics
        self.update_count = 0
        self.losses = []

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool):
        """Store experience in memory"""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state: np.ndarray) -> int:
        """Choose action using epsilon-greedy policy"""
        if random.random() < self.epsilon:
            return random.randrange(self.action_size)

        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            action_probs, _ = self.policy_net(state)
            return action_probs.argmax().item()

    def replay(self) -> float:
        """Train on a batch of experiences"""
        if len(self.memory) < self.batch_size:
            return 0.0

        # Sample batch
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert to tensors and move to device
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # Get current Q values for the actions that were taken
        current_q_values, _ = self.policy_net(states)
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1))

        # Get next Q values from target network and build the Bellman targets
        # (the bootstrap term is zeroed for terminal transitions)
        with torch.no_grad():
            next_q_values, _ = self.target_net(next_states)
            next_q_values = next_q_values.max(1)[0]
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute loss
        loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Record loss so it can be saved with the agent state
        self.losses.append(loss.item())

        # Update target network if needed
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        return loss.item()
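
    # Illustrative walk-through of the Bellman target computed in replay() above
    # (the numbers are made up for illustration, not taken from this commit):
    # with gamma = 0.99, reward = 0.5, done = 0 and max_a' Q_target(s', a') = 2.0,
    # the target is 0.5 + 0.99 * 2.0 = 2.48; for a terminal transition (done = 1)
    # the (1 - dones) factor removes the bootstrap term and the target is just 0.5.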

    def save(self, path: str):
        """Save model and agent state"""
        os.makedirs(os.path.dirname(path), exist_ok=True)

        # Save policy network
        self.policy_net.save(f"{path}_policy")

        # Save target network
        self.target_net.save(f"{path}_target")

        # Save agent state
        state = {
            'epsilon': self.epsilon,
            'update_count': self.update_count,
            'losses': self.losses,
            'optimizer_state': self.optimizer.state_dict()
        }
        torch.save(state, f"{path}_agent_state.pt")

    def load(self, path: str):
        """Load model and agent state"""
        # Load policy network
        self.policy_net.load(f"{path}_policy")

        # Load target network
        self.target_net.load(f"{path}_target")

        # Load agent state (map tensors to the current device so a checkpoint
        # saved on GPU can be restored on CPU and vice versa)
        state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
        self.epsilon = state['epsilon']
        self.update_count = state['update_count']
        self.losses = state['losses']
        self.optimizer.load_state_dict(state['optimizer_state'])
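For context, here is a minimal sketch of how this agent might be driven from a training loop. It is not part of the commit: DummyEnv is a hypothetical stand-in that emits random states and rewards purely to exercise the API, the constructor arguments mirror the signature added in this diff, and the (window_size, num_features) state shape is an assumption about what CNNModelPyTorch expects.

import numpy as np
from NN.models.dqn_agent import DQNAgent

class DummyEnv:
    """Hypothetical stand-in environment: random states and rewards, fixed-length episodes."""
    def __init__(self, window_size=20, num_features=5, episode_len=50):
        self.shape = (window_size, num_features)
        self.episode_len = episode_len
        self.t = 0

    def reset(self):
        self.t = 0
        return np.random.randn(*self.shape).astype(np.float32)

    def step(self, action: int):
        self.t += 1
        next_state = np.random.randn(*self.shape).astype(np.float32)
        reward = float(np.random.randn())      # placeholder reward signal
        done = self.t >= self.episode_len
        return next_state, reward, done

env = DummyEnv(window_size=20, num_features=5)
agent = DQNAgent(
    state_size=20 * 5,                 # assumed to be window_size * num_features
    action_size=3,                     # e.g. sell / hold / buy
    window_size=20,
    num_features=5,
    timeframes=['1m', '5m', '1h'],
)

for episode in range(5):
    state = env.reset()
    done = False
    while not done:
        action = agent.act(state)                        # epsilon-greedy choice
        next_state, reward, done = env.step(action)
        agent.remember(state, action, reward, next_state, done)
        agent.replay()                                   # trains once memory >= batch_size
        state = next_state

agent.save('checkpoints/dqn')   # writes the *_policy, *_target and *_agent_state.pt files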