import logging
from typing import Dict, Tuple

import gym
import numpy as np
from gym import spaces

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TradingEnvironment(gym.Env):
    """
    Trading environment implementing the gym interface for reinforcement learning.

    2-Action System:
    - 0: SELL (or close a long position)
    - 1: BUY (or close a short position)

    Intelligent Position Management:
    - When neutral: actions enter positions
    - When positioned: actions can close or flip positions
    - Different thresholds for entry vs. exit decisions

    State:
    - OHLCV data from multiple timeframes
    - Technical indicators
    - Position data and unrealized PnL
    """

    def __init__(
        self,
        data_interface,
        initial_balance: float = 10000.0,
        transaction_fee: float = 0.0002,
        window_size: int = 20,
        max_position: float = 1.0,
        reward_scaling: float = 1.0,
        entry_threshold: float = 0.6,  # Higher threshold for entering positions
        exit_threshold: float = 0.3,   # Lower threshold for exiting positions
    ):
        """
        Initialize the trading environment with the 2-action system.

        Args:
            data_interface: DataInterface instance used to get market data
            initial_balance: Initial balance in the base currency
            transaction_fee: Fee for each transaction as a fraction of trade value
            window_size: Number of candles in the observation window
            max_position: Maximum position size as a fraction of balance
            reward_scaling: Scale factor for rewards
            entry_threshold: Confidence threshold for entering new positions
            exit_threshold: Confidence threshold for exiting positions
        """
        super().__init__()

        self.data_interface = data_interface
        self.initial_balance = initial_balance
        self.transaction_fee = transaction_fee
        self.window_size = window_size
        self.max_position = max_position
        self.reward_scaling = reward_scaling
        self.entry_threshold = entry_threshold
        self.exit_threshold = exit_threshold

        # Load data for the primary timeframe (the first one is assumed to be primary)
        self.timeframe = self.data_interface.timeframes[0]
        self.reset_data()

        # Define action and observation spaces for the 2-action system
        self.action_space = spaces.Discrete(2)  # 0=SELL, 1=BUY

        # The observation covers OHLCV data from multiple timeframes plus
        # additional features such as position info and unrealized PnL.
        n_timeframes = len(self.data_interface.timeframes)
        n_features = 5  # OHLCV data by default

        # Additional features: position, balance, unrealized_pnl, entry_price, position_duration
        additional_features = 5

        # Calculate total feature dimension
        total_features = (n_timeframes * n_features * self.window_size) + additional_features
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(total_features,), dtype=np.float32
        )

        # Use a tuple for state_shape, as expected by EnhancedCNN
        self.state_shape = (total_features,)

        # Position tracking for the 2-action system
        self.position = 0.0     # -1 (short), 0 (neutral), 1 (long)
        self.entry_price = 0.0  # Price at which the position was entered
        self.entry_step = 0     # Step at which the position was entered

        # Initialize state
        self.reset()

    def reset_data(self):
        """Reset and load the price data used for training."""
        # Get data for each timeframe
        self.data = {}
        for tf in self.data_interface.timeframes:
            df = self.data_interface.dataframes[tf]
            if df is not None and not df.empty:
                self.data[tf] = df

        if not self.data:
            raise ValueError("No data available for training")

        # Use the primary timeframe for the step count
        self.prices = self.data[self.timeframe]['close'].values
        self.timestamps = self.data[self.timeframe].index.values
        self.max_steps = len(self.prices) - self.window_size - 1

    def reset(self):
        """Reset the environment to its initial state."""
        # Reset trading variables
        self.balance = self.initial_balance
        self.trades = []
        self.rewards = []
        self.total_pnl = 0.0
        self.winning_trades = 0
        self.losing_trades = 0

        # Reset position tracking
        self.position = 0.0
        self.entry_price = 0.0
        self.entry_step = 0

        # Reset step counter
        self.current_step = self.window_size

        # Get initial observation
        observation = self._get_observation()
        return observation

    def step(self, action):
        """
        Take a step in the environment using the 2-action system with
        intelligent position management.

        Args:
            action: Action to take (0: SELL, 1: BUY)

        Returns:
            tuple: (observation, reward, done, info)
        """
        # Capture the state before taking the action
        prev_balance = self.balance
        prev_position = self.position
        prev_price = self.prices[self.current_step]

        info = {}
        reward = 0.0
        last_position_info = None

        # Get current and next prices
        current_price = self.prices[self.current_step]
        next_price = (
            self.prices[self.current_step + 1]
            if self.current_step + 1 < len(self.prices)
            else current_price
        )

        # Implement the 2-action system with position management
        if action == 0:  # SELL action
            if self.position == 0:
                # No position - enter short
                self._open_position(-1.0 * self.max_position, current_price)
                logger.info(f"ENTER SHORT at step {self.current_step}, price: {current_price:.4f}")
                reward = -self.transaction_fee  # Entry cost
            elif self.position > 0:
                # Long position - close it
                close_pnl, last_position_info = self._close_position(current_price)
                reward += close_pnl * self.reward_scaling
                logger.info(
                    f"CLOSE LONG at step {self.current_step}, price: {current_price:.4f}, "
                    f"PnL: {close_pnl:.4f}"
                )
            elif self.position < 0:
                # Already short - could flip to long on a very strong signal.
                # For now, just hold the short position (no action).
                pass

        elif action == 1:  # BUY action
            if self.position == 0:
                # No position - enter long
                self._open_position(1.0 * self.max_position, current_price)
                logger.info(f"ENTER LONG at step {self.current_step}, price: {current_price:.4f}")
                reward = -self.transaction_fee  # Entry cost
            elif self.position < 0:
                # Short position - close it
                close_pnl, last_position_info = self._close_position(current_price)
                reward += close_pnl * self.reward_scaling
                logger.info(
                    f"CLOSE SHORT at step {self.current_step}, price: {current_price:.4f}, "
                    f"PnL: {close_pnl:.4f}"
                )
            elif self.position > 0:
                # Already long - could flip to short on a very strong signal.
                # For now, just hold the long position (no action).
                pass
        # Add unrealized PnL to the reward while holding a position
        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(next_price)
            reward += unrealized_pnl * self.reward_scaling * 0.1  # Scale down unrealized PnL

            # Apply a time-based holding penalty to encourage decisive actions
            position_duration = self.current_step - self.entry_step
            holding_penalty = min(position_duration * 0.0001, 0.01)  # Max 1% penalty
            reward -= holding_penalty
        else:
            # Small reward for staying neutral when there is no clear setup
            reward += 0.0001

        # Move to the next step
        self.current_step += 1

        # Get the new observation
        observation = self._get_observation()

        # Check whether the episode is done
        done = self.current_step >= len(self.prices) - 1

        # If done, close any remaining position
        if done and self.position != 0:
            final_pnl, last_position_info = self._close_position(current_price)
            reward += final_pnl * self.reward_scaling
            info['final_pnl'] = final_pnl
            info['final_balance'] = self.balance
            logger.info(
                f"Episode ended. Final balance: {self.balance:.4f}, "
                f"Return: {(self.balance / self.initial_balance - 1) * 100:.2f}%"
            )

        # Track the trade result if the position changed or was closed
        if prev_position != self.position or last_position_info is not None:
            # Determine the realized PnL if a position was closed
            realized_pnl = 0.0
            position_info = {}
            if last_position_info is not None:
                # Use the information from the closed position
                realized_pnl = last_position_info['pnl']
                position_info = last_position_info
            else:
                # Fall back to the balance change
                realized_pnl = self.balance - prev_balance if prev_position != 0 else 0.0

            # Record detailed trade information
            trade_result = {
                'step': self.current_step,
                'timestamp': self.timestamps[self.current_step],
                'action': action,
                'action_name': ['SELL', 'BUY'][action],
                'price': current_price,
                'position_changed': prev_position != self.position,
                'prev_position': prev_position,
                'new_position': self.position,
                'position_size': abs(self.position) if self.position != 0 else abs(prev_position),
                'entry_price': position_info.get('entry_price', self.entry_price),
                'exit_price': position_info.get('exit_price', current_price),
                'realized_pnl': realized_pnl,
                'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0.0,
                'pnl': realized_pnl,  # Total PnL realized on this step
                'balance_before': prev_balance,
                'balance_after': self.balance,
                'trade_fee': position_info.get(
                    'fee',
                    abs(self.position - prev_position) * current_price * self.transaction_fee,
                ),
            }

            info['trade_result'] = trade_result
            self.trades.append(trade_result)

            # Log trade details
            logger.info(
                f"Trade executed - Action: {['SELL', 'BUY'][action]}, "
                f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
                f"Balance: {self.balance:.4f}"
            )

        # Store the reward
        self.rewards.append(reward)

        # Update the info dict with the current state
        info.update({
            'step': self.current_step,
            'price': current_price,
            'prev_price': prev_price,
            'price_change': (current_price - prev_price) / prev_price if prev_price != 0 else 0.0,
            'balance': self.balance,
            'position': self.position,
            'entry_price': self.entry_price,
            'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0.0,
            'total_trades': len(self.trades),
            'total_pnl': self.total_pnl,
            'return_pct': (self.balance / self.initial_balance - 1) * 100,
        })

        return observation, reward, done, info

    def _calculate_unrealized_pnl(self, current_price):
        """Calculate the unrealized PnL for the current position."""
        if self.position == 0 or self.entry_price == 0:
            return 0.0
        if self.position > 0:
            # Long position
            return self.position * (current_price / self.entry_price - 1.0)
        else:
            # Short position
            return -self.position * (1.0 - current_price / self.entry_price)

    def _open_position(self, position_size: float, entry_price: float):
        """Open a new position."""
        self.position = position_size
        self.entry_price = entry_price
        self.entry_step = self.current_step

        # Apply the transaction fee on the opened position value
        position_value = abs(position_size) * entry_price
        fee = position_value * self.transaction_fee
        self.balance -= fee

        logger.info(f"Opened position: {position_size:.4f} at {entry_price:.4f}, fee: {fee:.4f}")

    def _close_position(self, exit_price: float) -> Tuple[float, Dict]:
        """Close the current position and return its PnL."""
        if self.position == 0:
            return 0.0, {}

        # Calculate the gross PnL
        if self.position > 0:
            # Long position
            pnl = (exit_price - self.entry_price) / self.entry_price
        else:
            # Short position
            pnl = (self.entry_price - exit_price) / self.entry_price

        # Apply the exit transaction fee (the entry fee was charged when opening)
        position_value = abs(self.position) * exit_price
        exit_fee = position_value * self.transaction_fee
        total_fees = exit_fee

        # Net PnL after fees
        net_pnl = pnl - (total_fees / (abs(self.position) * self.entry_price))

        # Update balance and cumulative PnL
        self.balance *= (1 + net_pnl)
        self.total_pnl += net_pnl

        # Track the closed position
        position_info = {
            'position_size': self.position,
            'entry_price': self.entry_price,
            'exit_price': exit_price,
            'pnl': net_pnl,
            'duration': self.current_step - self.entry_step,
            'entry_step': self.entry_step,
            'exit_step': self.current_step,
        }
        self.trades.append(position_info)

        # Update trade statistics
        if net_pnl > 0:
            self.winning_trades += 1
        else:
            self.losing_trades += 1

        logger.info(
            f"Closed position: {self.position:.4f}, PnL: {net_pnl:.4f}, "
            f"Duration: {position_info['duration']} steps"
        )

        # Reset the position state
        self.position = 0.0
        self.entry_price = 0.0
        self.entry_step = 0

        return net_pnl, position_info
    def _get_observation(self):
        """
        Get the current observation.

        Returns:
            np.array: The observation vector
        """
        observations = []

        # Get data from each timeframe
        for tf in self.data_interface.timeframes:
            if tf in self.data:
                # Get the window of data for this timeframe
                df = self.data[tf]
                start_idx = self._align_timeframe_index(tf)
                if start_idx is not None and start_idx >= 0 and start_idx + self.window_size <= len(df):
                    window = df.iloc[start_idx:start_idx + self.window_size]

                    # Extract OHLCV data
                    ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values

                    # Normalize OHLC relative to the last close price
                    last_close = ohlcv[-1, 3]
                    ohlcv_normalized = np.zeros_like(ohlcv)
                    ohlcv_normalized[:, 0] = ohlcv[:, 0] / last_close - 1.0  # open
                    ohlcv_normalized[:, 1] = ohlcv[:, 1] / last_close - 1.0  # high
                    ohlcv_normalized[:, 2] = ohlcv[:, 2] / last_close - 1.0  # low
                    ohlcv_normalized[:, 3] = ohlcv[:, 3] / last_close - 1.0  # close

                    # Normalize volume relative to its mean over the window
                    volume_ma = ohlcv[:, 4].mean()
                    if volume_ma > 0:
                        ohlcv_normalized[:, 4] = ohlcv[:, 4] / volume_ma - 1.0
                    else:
                        ohlcv_normalized[:, 4] = 0.0

                    # Flatten and add to the observations
                    observations.append(ohlcv_normalized.flatten())
                else:
                    # Fill with zeros if there is not enough data
                    observations.append(np.zeros(self.window_size * 5))
            else:
                # Fill with zeros for timeframes without data
                observations.append(np.zeros(self.window_size * 5))

        # Add position, balance, and PnL information.
        # These five entries must match `additional_features` in __init__.
        current_price = self.prices[self.current_step]
        position_duration = self.current_step - self.entry_step if self.position != 0 else 0
        position_info = np.array([
            self.position / self.max_position,                                           # Normalized position (-1 to 1)
            self.balance / self.initial_balance - 1.0,                                   # Normalized balance change
            self._calculate_unrealized_pnl(current_price),                               # Unrealized PnL
            (self.entry_price / current_price - 1.0) if self.entry_price != 0 else 0.0,  # Relative entry price
            position_duration / self.max_steps if self.max_steps > 0 else 0.0,           # Normalized position duration
        ])
        observations.append(position_info)

        # Concatenate all observations
        observation = np.concatenate(observations)
        return observation

    def _align_timeframe_index(self, timeframe):
        """
        Align the index of a higher timeframe with the current step in the
        primary timeframe.

        Args:
            timeframe: The timeframe to align

        Returns:
            int: The starting index in the higher timeframe
        """
        if timeframe == self.timeframe:
            return self.current_step - self.window_size

        # Timestamp of the current primary-timeframe step
        primary_ts = self.timestamps[self.current_step]

        # Find the closest index in the higher timeframe
        higher_ts = self.data[timeframe].index.values
        idx = np.searchsorted(higher_ts, primary_ts)

        # Adjust to get the starting index of the window
        start_idx = max(0, idx - self.window_size)
        return start_idx
    def get_last_positions(self, n=5):
        """
        Get detailed information about the last n positions.

        Args:
            n: Number of last positions to return

        Returns:
            list: List of dictionaries containing position details
        """
        if not self.trades:
            return []

        # Keep only trades that closed a position
        position_trades = [
            t for t in self.trades
            if t.get('realized_pnl', 0) != 0
            or (t.get('prev_position', 0) != 0 and t.get('new_position', 0) == 0)
        ]

        positions = []
        for trade in position_trades[-n:]:
            position_info = {
                'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
                'action': trade.get('action_name', ['SELL', 'BUY'][trade['action']]),
                'entry_price': trade.get('entry_price', 0.0),
                'exit_price': trade.get('exit_price', trade['price']),
                'position_size': trade.get('position_size', self.max_position),
                'realized_pnl': trade.get('realized_pnl', 0.0),
                'fee': trade.get('trade_fee', 0.0),
                'pnl': trade.get('pnl', 0.0),
                'pnl_percentage': (trade.get('pnl', 0.0) / self.initial_balance) * 100,
                'balance_before': trade.get('balance_before', 0.0),
                'balance_after': trade.get('balance_after', 0.0),
                'duration': trade.get('duration', 'N/A'),
            }
            positions.append(position_info)

        return positions

    def render(self, mode='human'):
        """Render the environment."""
        current_step = self.current_step
        current_price = self.prices[current_step]

        # Display basic information
        print("\nTrading Environment Status:")
        print("============================")
        print(f"Step: {current_step}/{len(self.prices) - 1}")
        print(f"Current Price: {current_price:.4f}")
        print(f"Current Balance: {self.balance:.4f}")
        print(f"Current Position: {self.position:.4f}")

        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(current_price)
            print(f"Entry Price: {self.entry_price:.4f}")
            print(f"Unrealized PnL: {unrealized_pnl:.4f} ({unrealized_pnl / self.balance * 100:.2f}%)")

        print(f"Total PnL: {self.total_pnl:.4f} ({self.total_pnl / self.initial_balance * 100:.2f}%)")
        print(f"Total Trades: {len(self.trades)}")

        if len(self.trades) > 0:
            # Count only trades that actually closed positions (not just changed them)
            win_trades = [t for t in self.trades if t.get('realized_pnl', 0) > 0]
            win_count = len(win_trades)

            closed_positions = [t for t in self.trades if t.get('realized_pnl', 0) != 0]
            closed_count = len(closed_positions)

            win_rate = win_count / closed_count if closed_count > 0 else 0.0
            print(f"Positions Closed: {closed_count}")
            print(f"Winning Positions: {win_count}")
            print(f"Win Rate: {win_rate:.2f}")

        # Display the last 5 positions
        print("\nLast 5 Positions:")
        print("================")
        last_positions = self.get_last_positions(5)
        if not last_positions:
            print("No closed positions yet.")

        for pos in last_positions:
            print(f"Time: {pos['timestamp']}")
            print(f"Action: {pos['action']}")
            print(f"Entry: {pos['entry_price']:.4f}, Exit: {pos['exit_price']:.4f}")
            print(f"Size: {pos['position_size']:.4f}")
            print(f"PnL: {pos['realized_pnl']:.4f} ({pos['pnl_percentage']:.2f}%)")
            print(f"Fee: {pos['fee']:.4f}")
            print(f"Balance: {pos['balance_before']:.4f} -> {pos['balance_after']:.4f}")
            print("----------------")

        return

    def close(self):
        """Close the environment."""
        pass
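
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal random-action rollout.
# `MockDataInterface` is a hypothetical stand-in for the real DataInterface;
# it only exposes the `timeframes` and `dataframes` attributes this
# environment reads. The real DataInterface may differ.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pandas as pd

    class MockDataInterface:
        """Hypothetical data source exposing the attributes used above."""

        def __init__(self, n_candles: int = 200):
            self.timeframes = ['1m']
            index = pd.date_range('2024-01-01', periods=n_candles, freq='1min')
            close = 100 + np.cumsum(np.random.randn(n_candles))
            self.dataframes = {
                '1m': pd.DataFrame({
                    'open': close + np.random.randn(n_candles) * 0.1,
                    'high': close + 0.5,
                    'low': close - 0.5,
                    'close': close,
                    'volume': np.random.rand(n_candles) * 1000,
                }, index=index)
            }

    env = TradingEnvironment(MockDataInterface(), window_size=20)
    obs = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()  # Replace with a trained policy
        obs, reward, done, info = env.step(action)
    env.render()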