430 lines
16 KiB
Python
430 lines
16 KiB
Python
import torch
|
|
import numpy as np
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
import logging
|
|
import time
|
|
from datetime import datetime
|
|
import os
|
|
import sys
|
|
import pandas as pd
|
|
import gym
|
|
import json
|
|
|
|
# Add parent directory to path
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from NN.utils.data_interface import DataInterface
|
|
from NN.utils.trading_env import TradingEnvironment
|
|
from NN.models.dqn_agent import DQNAgent
|
|
from NN.utils.signal_interpreter import SignalInterpreter
|
|
|
|
# Logging setup: a module-level logger, plus root handlers that write every
# INFO-or-higher record both to 'rl_training.log' and to the console.
logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler('rl_training.log'), logging.StreamHandler()],
)
|
|
|
|
class RLTradingEnvironment(gym.Env):
    """
    Reinforcement Learning environment for trading with technical indicators
    from multiple timeframes (1m, 5m, 15m).

    The agent is either flat (all capital in ``balance``) or fully invested
    (all capital converted into ``position`` units of the asset).
    Actions: 0 = Buy, 1 = Sell, 2 = Hold.

    Args:
        features_1m, features_5m, features_15m: 2-D arrays, one row per candle.
            The last column of ``features_1m`` is used as the close price.
        window_size: Number of rows taken from each timeframe per observation.
        trading_fee: Proportional fee charged on every buy and every sell.
        min_trade_interval: Trades closer together than this many steps are
            still executed, but receive an extra negative reward.
    """

    def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.0025, min_trade_interval=15):
        super().__init__()

        self.window_size = window_size
        # NOTE(review): num_features excludes the close-price column, but
        # _get_observation() stacks *every* column of each timeframe, so the
        # actual observation width is the sum of the three arrays' column
        # counts, not feature_dim. Confirm which width the agent expects.
        self.num_features = features_1m.shape[1] - 1  # Exclude close price
        self.num_timeframes = 3  # 1m, 5m, 15m
        self.feature_dim = self.num_features * self.num_timeframes

        # Store features from different timeframes
        self.features_1m = features_1m
        self.features_5m = features_5m
        self.features_15m = features_15m

        # Trading parameters
        self.initial_balance = 1.0
        self.trading_fee = trading_fee  # Default raised from 0.001 to 0.0025 (0.25%)
        self.min_trade_interval = min_trade_interval  # Minimum steps between trades

        # Define action and observation spaces
        self.action_space = gym.spaces.Discrete(3)  # 0: Buy, 1: Sell, 2: Hold
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.window_size, self.feature_dim),
            dtype=np.float32,
        )

        # Callback for visualization or external monitoring
        self.action_callback = None

        # Initialise all per-episode state variables
        self.reset()

    def reset(self):
        """Reset the environment to its initial state and return the first observation."""
        self.balance = self.initial_balance
        self.position = 0.0  # Units of the asset held (0 when flat)
        self.current_step = self.window_size
        self.trades = 0  # Number of positions opened
        self.wins = 0
        self.losses = 0
        self.win_rate = 0.0
        self.trade_history = []
        self.trade_entry_price = None
        # Start far enough in the past that the first trade is not penalised
        self.last_trade_step = -self.min_trade_interval

        return self._get_observation()

    def _left_pad(self, window):
        """Prepend zero rows so ``window`` has exactly ``window_size`` rows."""
        missing = self.window_size - len(window)
        if missing > 0:
            padding = np.zeros((missing, window.shape[1]))
            window = np.vstack([padding, window])
        return window

    def _get_observation(self):
        """
        Get the current state observation.

        Returns the last ``window_size`` rows of each timeframe, concatenated
        column-wise, as a float32 array of shape ``(window_size, total_columns)``.
        NaNs are replaced by 0. The 5m/15m windows are zero-padded at the top
        when not enough history exists.
        """
        idx_1m = self.current_step
        # NOTE(review): assumes row i of the 5m/15m arrays covers 1m rows
        # [5i, 5i+5) / [15i, 15i+15), i.e. all arrays start at the same time.
        # Confirm against how the data is loaded.
        idx_5m = idx_1m // 5
        idx_15m = idx_1m // 15

        window_1m = self.features_1m[idx_1m - self.window_size:idx_1m]
        window_5m = self._left_pad(self.features_5m[max(0, idx_5m - self.window_size):idx_5m])
        window_15m = self._left_pad(self.features_15m[max(0, idx_15m - self.window_size):idx_15m])

        combined_features = np.hstack([
            window_1m.reshape(self.window_size, -1),
            window_5m.reshape(self.window_size, -1),
            window_15m.reshape(self.window_size, -1),
        ])

        # Convert to float32 and handle any NaN values
        return np.nan_to_num(combined_features, nan=0.0).astype(np.float32)

    def _close_position(self, exit_price, trade_interval):
        """
        Sell the entire position at ``exit_price`` (net of fee) and record the trade.

        Returns:
            float: Gross return of the trade relative to the entry price.
        """
        self.balance = self.position * exit_price * (1 - self.trading_fee)
        self.position = 0

        profit_pct = (exit_price - self.trade_entry_price) / self.trade_entry_price
        if profit_pct > 0:
            self.wins += 1
        else:
            self.losses += 1

        self.trade_history.append({
            'entry_price': self.trade_entry_price,
            'exit_price': exit_price,
            'profit_pct': profit_pct,
            'trade_interval': trade_interval,
        })
        return profit_pct

    def step(self, action):
        """
        Take an action in the environment and return the next state, reward, done flag, and info.

        Buys execute at the current close; sells (and the forced close at the
        end of the data) execute at the next close.

        Args:
            action (int): 0 = Buy, 1 = Sell, 2 = Hold

        Returns:
            tuple: (observation, reward, done, info)
        """
        # Close price is the last column of the 1m features
        current_price = self.features_1m[self.current_step, -1]
        next_price = self.features_1m[self.current_step + 1, -1]

        # Handle zero or negative prices
        if current_price <= 0:
            current_price = 1e-8  # Small positive number
        if next_price <= 0:
            next_price = current_price  # Use current price if next price is invalid

        # Default reward is slightly negative to discourage inaction
        reward = -0.0001
        done = False
        profit_pct = None  # Set only when a position is closed during this step

        # Trades that come too soon after the previous one are still allowed,
        # but are charged a penalty proportional to the shortfall.
        trade_interval = self.current_step - self.last_trade_step
        frequency_penalty = 0
        if trade_interval < self.min_trade_interval:
            frequency_penalty = -0.002 * (self.min_trade_interval - trade_interval)

        if action == 0:  # BUY
            if self.position == 0:  # Only buy if not already in position
                # Convert all cash into asset units at the current price, net of fee
                self.position = self.balance * (1 - self.trading_fee) / current_price
                self.balance = 0
                self.trades += 1
                reward = -0.001 + frequency_penalty  # Small transaction cost + potential penalty
                self.trade_entry_price = current_price
                self.last_trade_step = self.current_step

        elif action == 1:  # SELL
            if self.position > 0:  # Only sell if in position
                profit_pct = self._close_position(next_price, trade_interval)
                # Scale reward by profit percentage and apply trade interval penalty
                reward = (profit_pct * 10) + frequency_penalty
                self.last_trade_step = self.current_step

        # else: (action == 2 - HOLD) - no position change

        self.current_step += 1

        if self.current_step >= len(self.features_1m) - 1:
            done = True
            # Force-close any open position so the final balance is realised
            if self.position > 0:
                profit_pct = self._close_position(next_price, trade_interval)
                reward += profit_pct * 10

        observation = self._get_observation()

        # Calculate metrics for info (position is in asset units)
        total_value = self.balance + self.position * next_price
        gain = (total_value - self.initial_balance) / self.initial_balance
        # Win rate over *closed* trades; an open position is neither a win nor a loss
        self.win_rate = self.wins / max(1, self.wins + self.losses)

        info = {
            'balance': self.balance,
            'position': self.position,
            'total_value': total_value,
            'gain': gain,
            'trades': self.trades,
            'win_rate': self.win_rate,
            'profit_pct': profit_pct,
            'current_price': current_price,
            'next_price': next_price,
        }

        # Call the callback if it exists
        if self.action_callback:
            self.action_callback(action, current_price, reward, info)

        return observation, reward, done, info

    def set_action_callback(self, callback):
        """
        Set a callback function to be called after each action

        Args:
            callback: Function with signature (action, price, reward, info)
        """
        self.action_callback = callback
|
|
|
|
def _to_feature_array(data):
    """Convert a DataFrame to a NumPy array without its 'timestamp' column; pass other inputs through."""
    if isinstance(data, pd.DataFrame):
        return data.drop('timestamp', axis=1, errors='ignore').values
    return data


def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
             action_callback=None, episode_callback=None, symbol="BTC/USDT"):
    """
    Train DQN agent for RL-based trading with extended training and monitoring

    Args:
        env_class: Optional environment class to use, defaults to RLTradingEnvironment
        num_episodes: Number of episodes to train
        max_steps: Maximum steps per episode
        save_path: Path prefix used when saving the model
        action_callback: Optional callback for each action (step, action, price, reward, info)
        episode_callback: Optional callback after each episode (episode, reward, info);
            info additionally carries the environment under the 'env' key
        symbol: Trading pair symbol (e.g., "BTC/USDT")

    Returns:
        DQNAgent: The trained agent, or None if the training data could not be loaded.
        Exceptions raised during training are logged and the (partially trained) agent
        is still returned.
    """
    logger.info("Starting DQN training for RL trading")

    # Create data interface with specified symbol
    data_interface = DataInterface(symbol=symbol)

    # Load and preprocess data
    logger.info(f"Loading data from multiple timeframes for {symbol}")
    features_1m = data_interface.get_training_data("1m", n_candles=2000)
    features_5m = data_interface.get_training_data("5m", n_candles=1000)
    features_15m = data_interface.get_training_data("15m", n_candles=500)

    if features_1m is None or features_5m is None or features_15m is None:
        logger.error("Failed to load training data from one or more timeframes")
        return None

    features_1m = _to_feature_array(features_1m)
    features_5m = _to_feature_array(features_5m)
    features_15m = _to_feature_array(features_15m)

    # Initialize environment or use provided class
    env_cls = RLTradingEnvironment if env_class is None else env_class
    env = env_cls(features_1m, features_5m, features_15m)

    if action_callback:
        def step_callback(action, price, reward, info):
            # NOTE(review): the environment calls this after advancing
            # current_step, so the step passed here is one past the candle
            # whose price is reported. Confirm consumers expect that.
            action_callback(env.current_step, action, price, reward, info)

        env.set_action_callback(step_callback)

    # Initialize agent
    window_size = env.window_size
    num_features = env.num_features * env.num_timeframes
    action_size = env.action_space.n
    timeframes = ['1m', '5m', '15m']  # Match the timeframes from the environment

    agent = DQNAgent(
        state_size=window_size * num_features,
        action_size=action_size,
        window_size=window_size,
        num_features=env.num_features,
        timeframes=timeframes,
        memory_size=100000,
        batch_size=64,
        learning_rate=0.0001,
        gamma=0.99,
        epsilon=1.0,
        epsilon_min=0.01,
        epsilon_decay=0.995
    )

    best_reward = -float('inf')
    episode_rewards = []

    # TensorBoard writer for logging; closed in `finally` even on KeyboardInterrupt
    writer = SummaryWriter(log_dir=f'runs/rl_trading_{int(time.time())}')

    logger.info(f"Starting training for {num_episodes} episodes...")
    logger.info(f"Starting training on device: {agent.device}")

    try:
        for episode in range(num_episodes):
            state = env.reset()
            total_reward = 0
            info = {}  # Stays empty if max_steps <= 0, so the episode callback still works

            for _ in range(max_steps):
                action = agent.act(state)
                next_state, reward, done, info = env.step(action)
                agent.remember(state, action, reward, next_state, done)

                state = next_state
                total_reward += reward

                # Train the agent by sampling from memory once a full batch is available
                if len(agent.memory) >= agent.batch_size:
                    agent.replay()

                if done:
                    break

            episode_rewards.append(total_reward)

            # Log progress
            avg_reward = np.mean(episode_rewards[-100:])
            logger.info(f"Episode {episode}/{num_episodes} - Reward: {total_reward:.4f}, "
                        f"Avg (100): {avg_reward:.4f}, Epsilon: {agent.epsilon:.4f}")

            # Trading metrics (custom env classes may not expose them)
            win_rate = getattr(env, 'win_rate', 0)
            trades = getattr(env, 'trades', 0)

            writer.add_scalar('Reward/Episode', total_reward, episode)
            writer.add_scalar('Reward/Average100', avg_reward, episode)
            writer.add_scalar('Trade/WinRate', win_rate, episode)
            writer.add_scalar('Trade/Count', trades, episode)

            # Save best model (skip the first few noisy episodes)
            if avg_reward > best_reward and episode > 10:
                logger.info(f"New best average reward: {avg_reward:.4f}, saving model")
                agent.save(save_path)
                best_reward = avg_reward

            # Periodic save every 100 episodes
            if episode % 100 == 0 and episode > 0:
                agent.save(f"{save_path}_episode_{episode}")

            if episode_callback:
                # Add environment to info dict to use for extrema training
                info_with_env = info.copy()
                info_with_env['env'] = env
                episode_callback(episode, total_reward, info_with_env)

        logger.info("Training completed, saving final model")
        agent.save(f"{save_path}_final")

    except Exception as e:
        # Best-effort: log message + traceback and still return the agent
        logger.exception(f"Training failed: {e}")

    finally:
        writer.close()

    return agent
|
|
|
|
if __name__ == "__main__":
    # train_rl() returns None when the training data could not be loaded;
    # report that as a non-zero exit status so shell scripts notice.
    sys.exit(0 if train_rl() is not None else 1)