implementations
This commit is contained in:
49
crypto/gogo/training/rl_agent.py
Normal file
49
crypto/gogo/training/rl_agent.py
Normal file
@ -0,0 +1,49 @@
|
||||
# training/rl_agent.py
|
||||
import random
|
||||
from collections import deque
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
class ContinuousRLAgent:
|
||||
def __init__(self, model, optimizer, replay_buffer, batch_size=32, gamma=0.99):
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.replay_buffer = replay_buffer
|
||||
self.batch_size = batch_size
|
||||
self.gamma = gamma
|
||||
|
||||
def act(self, state, epsilon=0.0):
|
||||
"""
|
||||
Select an action based on the state, using an epsilon-greedy policy.
|
||||
"""
|
||||
if random.random() < epsilon:
|
||||
# Exploration: choose a random action.
|
||||
action = np.random.choice([0, 1, 2]) # SELL, HOLD, BUY
|
||||
else:
|
||||
# Exploitation: choose the action with the highest Q-value.
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
|
||||
q_values = self.model(state_tensor)
|
||||
action = torch.argmax(q_values).item()
|
||||
return action
|
||||
|
||||
class ReplayBuffer:
|
||||
def __init__(self, capacity):
|
||||
self.buffer = deque(maxlen=capacity)
|
||||
|
||||
def push(self, state, action, reward, next_state, done):
|
||||
"""
|
||||
Store an experience tuple into the replay buffer.
|
||||
"""
|
||||
self.buffer.append((state, action, reward, next_state, done))
|
||||
|
||||
def sample(self, batch_size):
|
||||
"""
|
||||
Randomly sample a batch of experiences from the replay buffer.
|
||||
"""
|
||||
batch = random.sample(self.buffer, batch_size)
|
||||
states, actions, rewards, next_states, dones = zip(*batch)
|
||||
return states, actions, rewards, next_states, dones
|
||||
|
||||
def __len__(self):
|
||||
return len(self.buffer)
|
Reference in New Issue
Block a user