gogo2/NN/utils/trading_env.py

import numpy as np
import gym
from gym import spaces
from typing import Dict, Tuple, List
import pandas as pd


class TradingEnvironment(gym.Env):
    """
    Custom trading environment for reinforcement learning.

    `data` is a pandas DataFrame whose rows are time steps and whose columns are
    numeric features; a 'close' column is required because it is used to compute PnL.
    """
def __init__(self,
data: pd.DataFrame,
initial_balance: float = 100.0,
fee_rate: float = 0.0002,
max_steps: int = 1000):
        super().__init__()
self.data = data
self.initial_balance = initial_balance
self.fee_rate = fee_rate
self.max_steps = max_steps
# Action space: 0 (SELL), 1 (HOLD), 2 (BUY)
self.action_space = spaces.Discrete(3)
        # Observation space: one row of market features (price data and technical indicators)
self.observation_space = spaces.Box(
low=-np.inf,
high=np.inf,
shape=(data.shape[1],), # Number of features
dtype=np.float32
)
# Initialize state
self.reset()

    def reset(self) -> np.ndarray:
        """Reset the environment to its initial state and return the first observation"""
self.current_step = 0
self.balance = self.initial_balance
        self.position = 0  # -1: short, 0: flat (no position), 1: long
self.entry_price = 0
self.total_trades = 0
self.winning_trades = 0
self.total_pnl = 0
self.balance_history = [self.initial_balance]
self.max_balance = self.initial_balance
return self._get_observation()

    def _get_observation(self) -> np.ndarray:
        """Get the current observation (one row of features) as float32"""
        return self.data.iloc[self.current_step].values.astype(np.float32)

    def _calculate_reward(self, action: int) -> float:
        """Calculate the reward for the given action and update position/balance state"""
current_price = self.data.iloc[self.current_step]['close']
# If we have an open position
if self.position != 0:
# Calculate PnL
pnl = self.position * (current_price - self.entry_price) / self.entry_price
fees = self.fee_rate * 2 # Entry and exit fees
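            # Illustrative numbers (not from the original code): a long opened at
            # 100.0 and closed at 102.0 gives pnl = 0.02; with the default
            # fee_rate of 0.0002, fees = 0.0004, so net_pnl = 0.0196 and the
            # balance grows by roughly 1.96%.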
# Close position
if (action == 0 and self.position > 0) or (action == 2 and self.position < 0):
net_pnl = pnl - fees
self.total_pnl += net_pnl
self.balance *= (1 + net_pnl)
self.balance_history.append(self.balance)
self.max_balance = max(self.max_balance, self.balance)
self.total_trades += 1
if net_pnl > 0:
self.winning_trades += 1
# Reward based on PnL
reward = net_pnl * 100 # Scale up for better learning
# Additional reward for win rate
win_rate = self.winning_trades / max(1, self.total_trades)
reward += win_rate * 0.1
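                # e.g. 3 winning trades out of 5 gives win_rate = 0.6 and adds 0.06 to the reward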
self.position = 0
return reward
# Hold position
            return pnl * 0.1  # Small reward (or penalty) for holding, proportional to unrealized PnL
# No position
if action == 1: # HOLD
return 0
# Open new position
if action in [0, 2]: # SELL or BUY
self.position = -1 if action == 0 else 1
self.entry_price = current_price
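            # With the default fee_rate of 0.0002, the return below is a -0.0002
            # reward, nudging the agent away from over-trading.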
return -self.fee_rate # Small penalty for trading
return 0

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
"""Execute one step in the environment"""
# Calculate reward
reward = self._calculate_reward(action)
# Move to next step
self.current_step += 1
# Check if episode is done
done = self.current_step >= min(self.max_steps - 1, len(self.data) - 1)
# Get next observation
observation = self._get_observation()
        # Calculate max drawdown from the balance history
        max_drawdown = self._calculate_max_drawdown()
# Additional info
info = {
'balance': self.balance,
'position': self.position,
'total_trades': self.total_trades,
'win_rate': self.winning_trades / max(1, self.total_trades),
'total_pnl': self.total_pnl,
'max_drawdown': max_drawdown
}
return observation, reward, done, info

    def render(self, mode='human'):
"""Render the environment"""
if mode == 'human':
print(f"Step: {self.current_step}")
print(f"Balance: ${self.balance:.2f}")
print(f"Position: {self.position}")
print(f"Total Trades: {self.total_trades}")
print(f"Win Rate: {self.winning_trades/max(1, self.total_trades):.2%}")
print(f"Total PnL: ${self.total_pnl:.2f}")
print(f"Max Drawdown: {self._calculate_max_drawdown():.2%}")
print("-" * 50)

    def _calculate_max_drawdown(self) -> float:
"""Calculate maximum drawdown from balance history"""
if len(self.balance_history) <= 1:
return 0.0
peak = self.balance_history[0]
max_drawdown = 0.0
for balance in self.balance_history:
peak = max(peak, balance)
drawdown = (peak - balance) / peak
max_drawdown = max(max_drawdown, drawdown)
return max_drawdown
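

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # build a synthetic DataFrame with a 'close' column plus a few extra
    # numeric features, then drive the environment with random actions.
    rng = np.random.default_rng(seed=0)
    prices = 100.0 + np.cumsum(rng.normal(0.0, 0.5, size=500))
    demo_data = pd.DataFrame({
        'open': prices,
        'high': prices + rng.random(500),
        'low': prices - rng.random(500),
        'close': prices,
        'volume': rng.random(500) * 1000.0,
    })

    env = TradingEnvironment(demo_data, initial_balance=100.0, max_steps=200)
    obs = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()  # replace with a trained policy
        obs, reward, done, info = env.step(action)
    env.render()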