import numpy as np
import gym
from gym import spaces
from typing import Dict, Tuple, List
import pandas as pd


class TradingEnvironment(gym.Env):
    """
    Custom trading environment for reinforcement learning
    """

    def __init__(self, data: pd.DataFrame, initial_balance: float = 100.0,
                 fee_rate: float = 0.0002, max_steps: int = 1000):
        super(TradingEnvironment, self).__init__()

        self.data = data
        self.initial_balance = initial_balance
        self.fee_rate = fee_rate
        self.max_steps = max_steps

        # Action space: 0 (SELL), 1 (HOLD), 2 (BUY)
        self.action_space = spaces.Discrete(3)

        # Observation space: price data, technical indicators, and account state
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(data.shape[1],),  # Number of features
            dtype=np.float32
        )

        # Initialize state
        self.reset()

    def reset(self) -> np.ndarray:
        """Reset the environment to initial state"""
        self.current_step = 0
        self.balance = self.initial_balance
        self.position = 0  # 0: no position, 1: long, -1: short
        self.entry_price = 0
        self.total_trades = 0
        self.winning_trades = 0
        self.total_pnl = 0
        self.balance_history = [self.initial_balance]
        self.max_balance = self.initial_balance

        return self._get_observation()

    def _get_observation(self) -> np.ndarray:
        """Get current observation state"""
        # Cast to float32 so the observation matches the declared space dtype
        return self.data.iloc[self.current_step].values.astype(np.float32)

    def _calculate_reward(self, action: int) -> float:
        """Calculate reward based on action and outcome"""
        current_price = self.data.iloc[self.current_step]['close']

        # If we have an open position
        if self.position != 0:
            # Calculate PnL as a fractional return on the entry price
            pnl = self.position * (current_price - self.entry_price) / self.entry_price
            fees = self.fee_rate * 2  # Entry and exit fees

            # Close position: SELL closes a long, BUY closes a short
            if (action == 0 and self.position > 0) or (action == 2 and self.position < 0):
                net_pnl = pnl - fees
                self.total_pnl += net_pnl
                self.balance *= (1 + net_pnl)
                self.balance_history.append(self.balance)
                self.max_balance = max(self.max_balance, self.balance)

                self.total_trades += 1
                if net_pnl > 0:
                    self.winning_trades += 1

                # Reward based on PnL
                reward = net_pnl * 100  # Scale up for better learning

                # Additional reward for win rate
                win_rate = self.winning_trades / max(1, self.total_trades)
                reward += win_rate * 0.1

                self.position = 0
                return reward

            # Hold position
            return pnl * 0.1  # Small reward for holding profitable positions

        # No position
        if action == 1:  # HOLD
            return 0

        # Open new position
        if action in [0, 2]:  # SELL or BUY
            self.position = -1 if action == 0 else 1
            self.entry_price = current_price
            return -self.fee_rate  # Small penalty for trading

        return 0

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
        """Execute one step in the environment"""
        # Calculate reward
        reward = self._calculate_reward(action)

        # Move to next step
        self.current_step += 1

        # Check if episode is done
        done = self.current_step >= min(self.max_steps - 1, len(self.data) - 1)

        # Get next observation
        observation = self._get_observation()

        # Additional info
        info = {
            'balance': self.balance,
            'position': self.position,
            'total_trades': self.total_trades,
            'win_rate': self.winning_trades / max(1, self.total_trades),
            'total_pnl': self.total_pnl,
            'max_drawdown': self._calculate_max_drawdown()
        }

        return observation, reward, done, info

    def render(self, mode='human'):
        """Render the environment"""
        if mode == 'human':
            print(f"Step: {self.current_step}")
            print(f"Balance: ${self.balance:.2f}")
            print(f"Position: {self.position}")
            print(f"Total Trades: {self.total_trades}")
            print(f"Win Rate: {self.winning_trades / max(1, self.total_trades):.2%}")
            print(f"Total PnL: {self.total_pnl:.2%}")  # Sum of per-trade fractional returns
            print(f"Max Drawdown: {self._calculate_max_drawdown():.2%}")
            print("-" * 50)

    def _calculate_max_drawdown(self) -> float:
        """Calculate maximum drawdown from balance history"""
        if len(self.balance_history) <= 1:
            return 0.0

        peak = self.balance_history[0]
        max_drawdown = 0.0
        for balance in self.balance_history:
            peak = max(peak, balance)
            drawdown = (peak - balance) / peak
            max_drawdown = max(max_drawdown, drawdown)

        return max_drawdown