import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import sys
import logging

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from NN.models.simple_cnn import CNNModelPyTorch

# Configure logger
logger = logging.getLogger(__name__)

class DQNAgent:
    """
    Deep Q-Network agent for trading.

    Uses the CNN model as the base network.
    """
    def __init__(self,
                 state_size: int,
                 action_size: int,
                 window_size: int,
                 num_features: int,
                 timeframes: List[str],
                 learning_rate: float = 0.001,
                 gamma: float = 0.99,
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.01,
                 epsilon_decay: float = 0.995,
                 memory_size: int = 10000,
                 batch_size: int = 64,
                 target_update: int = 10):

        self.state_size = state_size
        self.action_size = action_size
        self.window_size = window_size
        self.num_features = num_features
        self.timeframes = timeframes
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.target_update = target_update

        # Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize networks
        self.policy_net = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=action_size,
            timeframes=timeframes
        ).to(self.device)

        self.target_net = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=action_size,
            timeframes=timeframes
        ).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Initialize optimizer
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)

        # Initialize replay memory
        self.memory = deque(maxlen=memory_size)

        # Special memory for extrema samples to use for targeted learning
        self.extrema_memory = deque(maxlen=memory_size // 5)  # Smaller size for extrema examples

        # Training metrics
        self.update_count = 0
        self.losses = []

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool, is_extrema: bool = False):
        """
        Store an experience in replay memory.

        Args:
            state: Current state
            action: Action taken
            reward: Reward received
            next_state: Next state
            done: Whether the episode is done
            is_extrema: Whether this is a local-extrema sample (for specialized learning)
        """
        experience = (state, action, reward, next_state, done)
        self.memory.append(experience)

        # If this is an extrema sample, also add it to the specialized memory
        if is_extrema:
            self.extrema_memory.append(experience)

    def act(self, state: np.ndarray) -> int:
        """Choose an action using an epsilon-greedy policy."""
        if random.random() < self.epsilon:
            return random.randrange(self.action_size)

        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            action_probs, extrema_pred = self.policy_net(state)
            return action_probs.argmax().item()

    def replay(self, use_extrema=False) -> float:
        """
        Train on a batch of experiences.

        Args:
            use_extrema: Whether to include extrema samples in training

        Returns:
            float: Loss value
        """
        if len(self.memory) < self.batch_size:
            return 0.0

        # Sample batch - mix regular and extrema samples
        batch = []
        if use_extrema and len(self.extrema_memory) > self.batch_size // 4:
            # Get some extrema samples
            extrema_count = min(self.batch_size // 3, len(self.extrema_memory))
            extrema_samples = random.sample(list(self.extrema_memory), extrema_count)

            # Get regular samples for the rest
            regular_count = self.batch_size - extrema_count
            regular_samples = random.sample(list(self.memory), regular_count)

            # Combine samples
            batch = extrema_samples + regular_samples
        else:
            # Standard sampling
            batch = random.sample(list(self.memory), self.batch_size)

        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert to tensors and move to device
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # Get current Q values for the actions actually taken
        current_q_values, extrema_pred = self.policy_net(states)
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1))

        # Get next Q values from the target network and form the Bellman target:
        # r + gamma * max_a' Q_target(s', a'), zeroed out for terminal states
        with torch.no_grad():
            next_q_values, _ = self.target_net(next_states)
            next_q_values = next_q_values.max(1)[0]
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute Q-learning loss
        q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)

        # If we had extrema labels (not in this implementation yet), we could add
        # an additional loss for the extrema prediction head. That would require
        # labels for whether each state is near an extremum.

        # Total loss is just the Q-learning loss for now
        loss = q_loss

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Track loss for monitoring
        self.losses.append(loss.item())

        # Update target network if needed
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        return loss.item()

    def train_on_extrema(self, states, actions, rewards, next_states, dones):
        """
        Special training method focused on extrema patterns.

        Args:
            states: Array of states near extrema points
            actions: Correct actions to take (buy at bottoms, sell at tops)
            rewards: Rewards for each action
            next_states: Next states
            dones: Done flags

        Returns:
            float: Loss value
        """
        if len(states) == 0:
            return 0.0

        # Convert to tensors
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # Forward pass
        current_q_values, extrema_pred = self.policy_net(states)
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1))

        # Get next Q values
        with torch.no_grad():
            next_q_values, _ = self.target_net(next_states)
            next_q_values = next_q_values.max(1)[0]
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Q-learning loss on the extrema batch (no extra weighting is applied yet)
        q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)

        # Full loss is just the Q-learning loss
        loss = q_loss

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def save(self, path: str):
        """Save model and agent state."""
        os.makedirs(os.path.dirname(path), exist_ok=True)

        # Save policy network
        self.policy_net.save(f"{path}_policy")

        # Save target network
        self.target_net.save(f"{path}_target")

        # Save agent state
        state = {
            'epsilon': self.epsilon,
            'update_count': self.update_count,
            'losses': self.losses,
            'optimizer_state': self.optimizer.state_dict()
        }
        torch.save(state, f"{path}_agent_state.pt")

    def load(self, path: str):
        """Load model and agent state."""
        # Load policy network
        self.policy_net.load(f"{path}_policy")

        # Load target network
        self.target_net.load(f"{path}_target")

        # Load agent state (map to the current device so GPU checkpoints load on CPU)
        state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
        self.epsilon = state['epsilon']
        self.update_count = state['update_count']
        self.losses = state['losses']
        self.optimizer.load_state_dict(state['optimizer_state'])
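

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The dimensions,
# timeframe labels, and random-walk "environment" below are hypothetical
# stand-ins, and it assumes CNNModelPyTorch accepts inputs shaped
# (batch, window_size, num_features) and provides the save/load methods
# used by DQNAgent.save/load above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    window_size = 20        # hypothetical lookback window
    num_features = 5        # hypothetical features per bar (e.g. OHLCV)
    timeframes = ["1m", "5m"]
    action_size = 3         # e.g. buy / hold / sell

    agent = DQNAgent(
        state_size=window_size * num_features,
        action_size=action_size,
        window_size=window_size,
        num_features=num_features,
        timeframes=timeframes,
    )

    # Dummy interaction loop with random states and rewards
    state = np.random.randn(window_size, num_features).astype(np.float32)
    for step in range(200):
        action = agent.act(state)
        next_state = np.random.randn(window_size, num_features).astype(np.float32)
        reward = float(np.random.randn())
        done = (step + 1) % 50 == 0
        agent.remember(state, action, reward, next_state, done)
        loss = agent.replay()          # returns 0.0 until the buffer holds a full batch
        state = next_state
        if loss:
            logger.info(f"step={step} loss={loss:.4f} epsilon={agent.epsilon:.3f}")

    # Persist and restore the agent
    agent.save("checkpoints/dqn_demo")
    agent.load("checkpoints/dqn_demo")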