enhancements

@@ -7,12 +7,16 @@ import random
from typing import Tuple, List
import os
import sys
import logging

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from NN.models.simple_cnn import CNNModelPyTorch

# Configure logger
logger = logging.getLogger(__name__)

class DQNAgent:
    """
    Deep Q-Network agent for trading
@@ -72,14 +76,32 @@ class DQNAgent:
        # Initialize memory
        self.memory = deque(maxlen=memory_size)

        # Special memory for extrema samples to use for targeted learning
        self.extrema_memory = deque(maxlen=memory_size // 5)  # Smaller size for extrema examples

        # Training metrics
        self.update_count = 0
        self.losses = []
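
For a sense of the relative capacities (illustrative arithmetic only; memory_size is a constructor argument whose default is not shown in this hunk): with memory_size = 10000, self.memory holds up to 10000 transitions while self.extrema_memory holds up to 10000 // 5 = 2000.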

    def remember(self, state: np.ndarray, action: int, reward: float,
-                next_state: np.ndarray, done: bool):
-        """Store experience in memory"""
-        self.memory.append((state, action, reward, next_state, done))
+                next_state: np.ndarray, done: bool, is_extrema: bool = False):
+        """
+        Store experience in memory
+
+        Args:
+            state: Current state
+            action: Action taken
+            reward: Reward received
+            next_state: Next state
+            done: Whether episode is done
+            is_extrema: Whether this is a local extrema sample (for specialized learning)
+        """
+        experience = (state, action, reward, next_state, done)
+        self.memory.append(experience)
+
+        # If this is an extrema sample, also add to specialized memory
+        if is_extrema:
+            self.extrema_memory.append(experience)
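
A minimal usage sketch of the new is_extrema flag (hypothetical caller code; the remember_with_extrema helper, the aligned prices sequence, and the simple rolling-window extremum test are assumptions, not part of this commit):

def remember_with_extrema(agent, transitions, prices, lookback=5):
    """Store (state, action, reward, next_state, done) tuples with the agent,
    flagging those taken at a local price minimum or maximum."""
    # transitions and prices are assumed to be index-aligned sequences
    for t, (state, action, reward, next_state, done) in enumerate(transitions):
        window = prices[max(0, t - lookback):t + lookback + 1]
        at_extremum = prices[t] == min(window) or prices[t] == max(window)
        agent.remember(state, action, reward, next_state, done, is_extrema=at_extremum)

Flagged transitions then land in both self.memory and self.extrema_memory, which is what the mixed sampling in replay() below relies on.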

    def act(self, state: np.ndarray) -> int:
        """Choose action using epsilon-greedy policy"""
@@ -88,16 +110,39 @@ class DQNAgent:

        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
-            action_probs, _ = self.policy_net(state)
+            action_probs, extrema_pred = self.policy_net(state)
            return action_probs.argmax().item()
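
The unpacking above implies that CNNModelPyTorch's forward pass returns two outputs: Q-values (one per action) and an extrema prediction; that interface is inferred from how the network is called in this file, not shown in the diff. A minimal stand-in with the same calling convention (an MLP used only to illustrate the two-headed output, not the actual CNN architecture):

import torch
import torch.nn as nn

class TwoHeadedNet(nn.Module):
    """Illustrative stand-in: shared trunk with a Q-value head and an extrema head."""
    def __init__(self, state_dim: int, n_actions: int):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU())
        self.q_head = nn.Linear(64, n_actions)    # one Q-value per action
        self.extrema_head = nn.Linear(64, 3)      # e.g. bottom / top / neither logits (assumed)

    def forward(self, x: torch.Tensor):
        h = self.trunk(x)
        return self.q_head(h), self.extrema_head(h)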

-    def replay(self) -> float:
-        """Train on a batch of experiences"""
+    def replay(self, use_extrema=False) -> float:
+        """
+        Train on a batch of experiences
+
+        Args:
+            use_extrema: Whether to include extrema samples in training
+
+        Returns:
+            float: Loss value
+        """
        if len(self.memory) < self.batch_size:
            return 0.0

-        # Sample batch
-        batch = random.sample(self.memory, self.batch_size)
+        # Sample batch - mix regular and extrema samples
+        batch = []
+        if use_extrema and len(self.extrema_memory) > self.batch_size // 4:
+            # Get some extrema samples
+            extrema_count = min(self.batch_size // 3, len(self.extrema_memory))
+            extrema_samples = random.sample(list(self.extrema_memory), extrema_count)
+
+            # Get regular samples for the rest
+            regular_count = self.batch_size - extrema_count
+            regular_samples = random.sample(list(self.memory), regular_count)
+
+            # Combine samples
+            batch = extrema_samples + regular_samples
+        else:
+            # Standard sampling
+            batch = random.sample(self.memory, self.batch_size)
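
For a sense of the mix (illustrative arithmetic, assuming batch_size = 64): the extrema branch is taken only when use_extrema is True and extrema_memory holds more than 64 // 4 = 16 samples; extrema_count is then capped at 64 // 3 = 21 and regular_count is at least 43, so at most roughly a third of each batch comes from the extrema memory.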

        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert to tensors and move to device
@@ -108,7 +153,7 @@ class DQNAgent:
        dones = torch.FloatTensor(dones).to(self.device)

        # Get current Q values
-        current_q_values, _ = self.policy_net(states)
+        current_q_values, extrema_pred = self.policy_net(states)
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1))

        # Get next Q values from target network
@@ -117,8 +162,15 @@ class DQNAgent:
        next_q_values = next_q_values.max(1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

-        # Compute loss
-        loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
+        # Compute Q-learning loss
+        q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
+
+        # If we have extrema labels (not in this implementation yet),
+        # we could add an additional loss for extrema prediction
+        # This would require labels for whether each state is near an extrema
+
+        # Total loss is just Q-learning loss for now
+        loss = q_loss

        # Optimize
        self.optimizer.zero_grad()
@@ -135,6 +187,50 @@ class DQNAgent:

        return loss.item()
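
The comments in replay() note that an auxiliary loss on the extrema head could be added once per-state extrema labels exist. A minimal sketch of what that combined loss might look like (hypothetical; the extrema_labels tensor, the 3-class head, and the 0.1 weight are assumptions, not part of this commit):

import torch.nn as nn

def combined_loss(current_q_values, target_q_values, extrema_pred, extrema_labels,
                  extrema_weight=0.1):
    """Q-learning loss plus a weighted classification loss on the extrema head."""
    q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
    # extrema_labels: LongTensor of shape (batch,), e.g. 0 = bottom, 1 = top, 2 = neither (assumed encoding)
    extrema_loss = nn.CrossEntropyLoss()(extrema_pred, extrema_labels)
    return q_loss + extrema_weight * extrema_loss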

+    def train_on_extrema(self, states, actions, rewards, next_states, dones):
+        """
+        Special training method focused on extrema patterns
+
+        Args:
+            states: Array of states near extrema points
+            actions: Correct actions to take (buy at bottoms, sell at tops)
+            rewards: Rewards for each action
+            next_states: Next states
+            dones: Done flags
+        """
+        if len(states) == 0:
+            return 0.0
+
+        # Convert to tensors
+        states = torch.FloatTensor(np.array(states)).to(self.device)
+        actions = torch.LongTensor(actions).to(self.device)
+        rewards = torch.FloatTensor(rewards).to(self.device)
+        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
+        dones = torch.FloatTensor(dones).to(self.device)
+
+        # Forward pass
+        current_q_values, extrema_pred = self.policy_net(states)
+        current_q_values = current_q_values.gather(1, actions.unsqueeze(1))
+
+        # Get next Q values
+        with torch.no_grad():
+            next_q_values, _ = self.target_net(next_states)
+            next_q_values = next_q_values.max(1)[0]
+            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
+
+        # Higher weight for extrema training
+        q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
+
+        # Full loss is just Q-learning loss
+        loss = q_loss
+
+        # Optimize
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+        return loss.item()
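
A sketch of how the inputs to train_on_extrema might be assembled from detected swing points (hypothetical helper; the rolling-window extremum test, the buy/sell action encoding, and the flat reward are assumptions, not part of this commit):

def build_extrema_batch(states, prices, window=5, buy_action=0, sell_action=1, reward=1.0):
    """Collect (state, action, reward, next_state, done) lists at local price extrema.
    states and prices are assumed to be index-aligned sequences of equal length."""
    xs, acts, rews, nxts, dones = [], [], [], [], []
    for t in range(window, len(prices) - window):
        segment = prices[t - window:t + window + 1]
        if prices[t] == min(segment):        # local bottom -> buy
            action = buy_action
        elif prices[t] == max(segment):      # local top -> sell
            action = sell_action
        else:
            continue
        xs.append(states[t])
        acts.append(action)
        rews.append(reward)
        nxts.append(states[t + 1])
        dones.append(False)
    return xs, acts, rews, nxts, dones

agent.train_on_extrema(*build_extrema_batch(states, prices)) would then run the targeted update on those samples.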

    def save(self, path: str):
        """Save model and agent state"""
        os.makedirs(os.path.dirname(path), exist_ok=True)