import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import sys

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from NN.models.simple_cnn import CNNModelPyTorch


class DQNAgent:
    """
    Deep Q-Network agent for trading
    Uses CNN model as the base network
    """
    def __init__(self,
                 state_size: int,
                 action_size: int,
                 window_size: int,
                 num_features: int,
                 timeframes: List[str],
                 learning_rate: float = 0.001,
                 gamma: float = 0.99,
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.01,
                 epsilon_decay: float = 0.995,
                 memory_size: int = 10000,
                 batch_size: int = 64,
                 target_update: int = 10):

        self.state_size = state_size
        self.action_size = action_size
        self.window_size = window_size
        self.num_features = num_features
        self.timeframes = timeframes
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.target_update = target_update

        # Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize networks
        self.policy_net = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=action_size,
            timeframes=timeframes
        ).to(self.device)

        self.target_net = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=action_size,
            timeframes=timeframes
        ).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Initialize optimizer
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)

        # Initialize memory
        self.memory = deque(maxlen=memory_size)

        # Training metrics
        self.update_count = 0
        self.losses = []

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool):
        """Store experience in memory"""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state: np.ndarray) -> int:
        """Choose action using epsilon-greedy policy"""
        if random.random() < self.epsilon:
            return random.randrange(self.action_size)

        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            action_probs, _ = self.policy_net(state)
            return action_probs.argmax().item()

    def replay(self) -> float:
        """Train on a batch of experiences"""
        if len(self.memory) < self.batch_size:
            return 0.0

        # Sample batch
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert to tensors and move to device
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # Get current Q values for the actions actually taken
        current_q_values, _ = self.policy_net(states)
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1))

        # Get next Q values from the target network and form the Bellman target:
        # r + gamma * max_a' Q_target(s', a'), zeroed out for terminal states
        with torch.no_grad():
            next_q_values, _ = self.target_net(next_states)
            next_q_values = next_q_values.max(1)[0]
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute loss
        loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Record loss so the training metrics saved in save() are populated
        self.losses.append(loss.item())

        # Update target network if needed
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        return loss.item()

    def save(self, path: str):
        """Save model and agent state"""
        # Guard against a bare filename with no directory component
        if os.path.dirname(path):
            os.makedirs(os.path.dirname(path), exist_ok=True)

        # Save policy network
        self.policy_net.save(f"{path}_policy")

        # Save target network
        self.target_net.save(f"{path}_target")

        # Save agent state
        state = {
            'epsilon': self.epsilon,
            'update_count': self.update_count,
            'losses': self.losses,
            'optimizer_state': self.optimizer.state_dict()
        }
        torch.save(state, f"{path}_agent_state.pt")

    def load(self, path: str):
        """Load model and agent state"""
        # Load policy network
        self.policy_net.load(f"{path}_policy")

        # Load target network
        self.target_net.load(f"{path}_target")

        # Load agent state (map to the agent's device so checkpoints saved on
        # GPU can be restored on CPU-only machines)
        state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
        self.epsilon = state['epsilon']
        self.update_count = state['update_count']
        self.losses = state['losses']
        self.optimizer.load_state_dict(state['optimizer_state'])
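

# Illustrative usage sketch: a minimal example of how the agent might be
# driven, assuming CNNModelPyTorch accepts states shaped
# (window_size, num_features) and that the action space is 3 (e.g. BUY/SELL/
# HOLD). The window size, feature count, timeframes, and random arrays below
# are hypothetical stand-ins for a real trading environment.
if __name__ == "__main__":
    agent = DQNAgent(
        state_size=20 * 5,
        action_size=3,
        window_size=20,
        num_features=5,
        timeframes=["1m", "5m", "1h"],
    )

    # Fill the replay buffer with dummy transitions, then run a single update.
    for _ in range(agent.batch_size):
        state = np.random.randn(20, 5).astype(np.float32)
        next_state = np.random.randn(20, 5).astype(np.float32)
        action = agent.act(state)
        agent.remember(state, action, reward=0.0, next_state=next_state, done=False)

    loss = agent.replay()
    print(f"Replay loss: {loss:.6f}, epsilon: {agent.epsilon:.3f}")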