import logging

import gym
import numpy as np
import pandas as pd
from gym import spaces

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TradingEnvironment(gym.Env):
    """
    Trading environment implementing the gym interface for reinforcement learning.

    Actions:
    - 0: Buy
    - 1: Sell
    - 2: Hold

    State:
    - OHLCV data from multiple timeframes
    - Technical indicators
    - Position data
    """

    def __init__(
        self,
        data_interface,
        initial_balance: float = 10000.0,
        transaction_fee: float = 0.0002,
        window_size: int = 20,
        max_position: float = 1.0,
        reward_scaling: float = 1.0,
    ):
        """
        Initialize the trading environment.

        Args:
            data_interface: DataInterface instance used to get market data
            initial_balance: Initial balance in the base currency
            transaction_fee: Fee per transaction as a fraction of trade value
            window_size: Number of candles in the observation window
            max_position: Maximum position size as a fraction of balance
            reward_scaling: Scale factor for rewards
        """
        super().__init__()

        self.data_interface = data_interface
        self.initial_balance = initial_balance
        self.transaction_fee = transaction_fee
        self.window_size = window_size
        self.max_position = max_position
        self.reward_scaling = reward_scaling

        # Load data for the primary timeframe (the first one is treated as primary)
        self.timeframe = self.data_interface.timeframes[0]
        self.reset_data()

        # Define action and observation spaces
        self.action_space = spaces.Discrete(3)  # Buy, Sell, Hold

        # The observation covers an OHLCV window for every timeframe plus
        # account features (position, balance, unrealized PnL).
        n_timeframes = len(self.data_interface.timeframes)
        n_features = 5  # OHLCV
        additional_features = 3  # position, balance, unrealized_pnl
        total_features = (n_timeframes * n_features * self.window_size) + additional_features

        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(total_features,), dtype=np.float32
        )

        # Tuple state_shape in the form EnhancedCNN expects
        self.state_shape = (total_features,)

        # Initialize state
        self.reset()

    def reset_data(self):
        """Reload market data from the data interface for every timeframe."""
        self.data = {}
        for tf in self.data_interface.timeframes:
            df = self.data_interface.dataframes[tf]
            if df is not None and not df.empty:
                self.data[tf] = df

        if not self.data:
            raise ValueError("No data available for training")

        # Use the primary timeframe for the step count
        self.prices = self.data[self.timeframe]['close'].values
        self.timestamps = self.data[self.timeframe].index.values
        self.max_steps = len(self.prices) - self.window_size - 1

    def reset(self):
        """Reset the environment to its initial state."""
        # Reset trading variables
        self.balance = self.initial_balance
        self.position = 0.0  # No position initially
        self.entry_price = 0.0
        self.total_pnl = 0.0
        self.trades = []
        self.rewards = []

        # Reset step counter
        self.current_step = self.window_size

        return self._get_observation()
    def step(self, action):
        """
        Take a step in the environment.

        Args:
            action: Action to take (0: Buy, 1: Sell, 2: Hold)

        Returns:
            tuple: (observation, reward, done, info)
        """
        # Record state before taking the action
        prev_balance = self.balance
        prev_position = self.position
        prev_price = self.prices[self.current_step - 1]  # previous candle, for price_change

        info = {}
        reward = 0.0
        last_position_info = None

        # Current and next prices (next falls back to current at the data edge)
        current_price = self.prices[self.current_step]
        next_price = (
            self.prices[self.current_step + 1]
            if self.current_step + 1 < len(self.prices)
            else current_price
        )

        # Process the action. Fees are only debited when a position is closed
        # (see _close_position); opening is fee-free here.
        if action == 0:  # Buy
            if self.position <= 0:  # Only buy if not already long
                # Close any existing short position first
                if self.position < 0:
                    close_pnl, last_position_info = self._close_position(current_price)
                    reward += close_pnl * self.reward_scaling

                # Open a new long position
                self._open_position(1.0 * self.max_position, current_price)
                logger.info(f"Buy at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")

        elif action == 1:  # Sell
            if self.position >= 0:  # Only sell if not already short
                # Close any existing long position first
                if self.position > 0:
                    close_pnl, last_position_info = self._close_position(current_price)
                    reward += close_pnl * self.reward_scaling

                # Open a new short position
                self._open_position(-1.0 * self.max_position, current_price)
                logger.info(f"Sell at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")

        elif action == 2:  # Hold
            # No trade; unrealized PnL still contributes to the reward below
            pass

        # Add scaled-down unrealized PnL at the next price to the reward
        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(next_price)
            reward += unrealized_pnl * self.reward_scaling * 0.1

        # Apply a small holding penalty (interest-like fee) per step
        if self.position != 0:
            holding_penalty = abs(self.position) * 0.0001  # 0.01% per step
            reward -= holding_penalty * self.reward_scaling

        # Move to the next step
        self.current_step += 1

        # Get the new observation
        observation = self._get_observation()

        # Check whether the episode is done
        done = self.current_step >= len(self.prices) - 1

        # If done, close any remaining position at the final observed price
        if done and self.position != 0:
            final_pnl, last_position_info = self._close_position(self.prices[self.current_step])
            reward += final_pnl * self.reward_scaling
            info['final_pnl'] = final_pnl
            info['final_balance'] = self.balance
            logger.info(f"Episode ended. Final balance: {self.balance:.4f}, Return: {(self.balance / self.initial_balance - 1) * 100:.2f}%")
        # Track the trade if the position changed or a position was closed
        if prev_position != self.position or last_position_info is not None:
            realized_pnl = 0
            position_info = {}

            if last_position_info is not None:
                # Use the details recorded when the position was closed
                realized_pnl = last_position_info['pnl']
                position_info = last_position_info
            else:
                # Fall back to the balance change
                realized_pnl = self.balance - prev_balance if prev_position != 0 else 0

            # Record detailed trade information
            trade_result = {
                'step': self.current_step,
                'timestamp': self.timestamps[self.current_step],
                'action': action,
                'action_name': ['BUY', 'SELL', 'HOLD'][action],
                'price': current_price,
                'position_changed': prev_position != self.position,
                'prev_position': prev_position,
                'new_position': self.position,
                'position_size': abs(self.position) if self.position != 0 else abs(prev_position),
                'entry_price': position_info.get('entry_price', self.entry_price),
                'exit_price': position_info.get('exit_price', current_price),
                'realized_pnl': realized_pnl,
                'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0,
                'pnl': realized_pnl,  # Total PnL realized on this step
                'balance_before': prev_balance,
                'balance_after': self.balance,
                # If no close occurred, estimate the fee for the size change
                # (no fee is actually debited on open)
                'trade_fee': position_info.get('fee', abs(self.position - prev_position) * current_price * self.transaction_fee),
            }
            info['trade_result'] = trade_result
            self.trades.append(trade_result)

            # Log trade details
            logger.info(f"Trade executed - Action: {['BUY', 'SELL', 'HOLD'][action]}, "
                        f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
                        f"Balance: {self.balance:.4f}")

        # Store the reward
        self.rewards.append(reward)

        # Update the info dict with the current state
        info.update({
            'step': self.current_step,
            'price': current_price,
            'prev_price': prev_price,
            'price_change': (current_price - prev_price) / prev_price if prev_price != 0 else 0,
            'balance': self.balance,
            'position': self.position,
            'entry_price': self.entry_price,
            'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0.0,
            'total_trades': len(self.trades),
            'total_pnl': self.total_pnl,
            'return_pct': (self.balance / self.initial_balance - 1) * 100,
        })

        return observation, reward, done, info

    def _calculate_unrealized_pnl(self, current_price):
        """Calculate the unrealized PnL of the current position as a fractional return."""
        if self.position == 0 or self.entry_price == 0:
            return 0.0

        # position * (price ratio - 1) covers both directions: a negative
        # (short) position profits when the ratio falls below 1.
        return self.position * (current_price / self.entry_price - 1.0)

    def _open_position(self, position_size, price):
        """Open a new position; no fee is charged until the position is closed."""
        self.position = position_size
        self.entry_price = price

    def _close_position(self, price):
        """Close the current position and return (pnl, position details)."""
        pnl = self._calculate_unrealized_pnl(price)

        # Apply the transaction fee
        fee = abs(self.position) * price * self.transaction_fee
        pnl -= fee

        # Update the balance
        self.balance += pnl
        self.total_pnl += pnl

        # Store the position details before resetting
        last_position = {
            'position_size': self.position,
            'entry_price': self.entry_price,
            'exit_price': price,
            'pnl': pnl,
            'fee': fee,
        }

        # Reset the position
        self.position = 0.0
        self.entry_price = 0.0

        # Log the position closure
        logger.info(f"Closed position - Size: {last_position['position_size']:.4f}, "
                    f"Entry: {last_position['entry_price']:.4f}, Exit: {last_position['exit_price']:.4f}, "
                    f"PnL: {last_position['pnl']:.4f}, Fee: {last_position['fee']:.4f}")

        return pnl, last_position
    def _get_observation(self):
        """
        Get the current observation.

        Returns:
            np.array: The observation vector
        """
        observations = []

        # Gather a window of data from each timeframe
        for tf in self.data_interface.timeframes:
            if tf in self.data:
                df = self.data[tf]
                start_idx = self._align_timeframe_index(tf)

                if start_idx is not None and start_idx >= 0 and start_idx + self.window_size <= len(df):
                    window = df.iloc[start_idx:start_idx + self.window_size]

                    # Extract OHLCV data
                    ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values

                    # Normalize prices relative to the last close in the window
                    last_close = ohlcv[-1, 3]
                    ohlcv_normalized = np.zeros_like(ohlcv)
                    ohlcv_normalized[:, 0] = ohlcv[:, 0] / last_close - 1.0  # open
                    ohlcv_normalized[:, 1] = ohlcv[:, 1] / last_close - 1.0  # high
                    ohlcv_normalized[:, 2] = ohlcv[:, 2] / last_close - 1.0  # low
                    ohlcv_normalized[:, 3] = ohlcv[:, 3] / last_close - 1.0  # close

                    # Normalize volume relative to its mean over the window
                    volume_ma = ohlcv[:, 4].mean()
                    if volume_ma > 0:
                        ohlcv_normalized[:, 4] = ohlcv[:, 4] / volume_ma - 1.0
                    else:
                        ohlcv_normalized[:, 4] = 0.0

                    observations.append(ohlcv_normalized.flatten())
                else:
                    # Not enough data for a full window: pad with zeros
                    observations.append(np.zeros(self.window_size * 5))
            else:
                # Timeframe unavailable: pad with zeros so the observation
                # shape always matches the observation space
                observations.append(np.zeros(self.window_size * 5))

        # Append position and balance information
        current_price = self.prices[self.current_step]
        position_info = np.array([
            self.position / self.max_position,              # normalized position (-1 to 1)
            self.balance / self.initial_balance - 1.0,      # normalized balance change
            self._calculate_unrealized_pnl(current_price),  # unrealized PnL
        ])
        observations.append(position_info)

        return np.concatenate(observations)

    def _align_timeframe_index(self, timeframe):
        """
        Align the index of another timeframe with the current step in the
        primary timeframe.

        Args:
            timeframe: The timeframe to align

        Returns:
            int: The starting index of the window in that timeframe
        """
        if timeframe == self.timeframe:
            return self.current_step - self.window_size

        # Timestamp of the current primary-timeframe step
        primary_ts = self.timestamps[self.current_step]

        # Find the closest index in the other timeframe
        higher_ts = self.data[timeframe].index.values
        idx = np.searchsorted(higher_ts, primary_ts)

        # Step back to get the window's starting index
        return max(0, idx - self.window_size)
    def get_last_positions(self, n=5):
        """
        Get detailed information about the last n closed positions.

        Args:
            n: Number of positions to return

        Returns:
            list: List of dictionaries containing position details
        """
        if not self.trades:
            return []

        # Keep only trades that closed a position
        position_trades = [
            t for t in self.trades
            if t.get('realized_pnl', 0) != 0
            or (t.get('prev_position', 0) != 0 and t.get('new_position', 0) == 0)
        ]

        positions = []
        for trade in position_trades[-n:]:
            positions.append({
                'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
                'action': trade.get('action_name', ['BUY', 'SELL', 'HOLD'][trade['action']]),
                'entry_price': trade.get('entry_price', 0.0),
                'exit_price': trade.get('exit_price', trade['price']),
                'position_size': trade.get('position_size', self.max_position),
                'realized_pnl': trade.get('realized_pnl', 0.0),
                'fee': trade.get('trade_fee', 0.0),
                'pnl': trade.get('pnl', 0.0),
                'pnl_percentage': (trade.get('pnl', 0.0) / self.initial_balance) * 100,
                'balance_before': trade.get('balance_before', 0.0),
                'balance_after': trade.get('balance_after', 0.0),
                'duration': trade.get('duration', 'N/A'),
            })

        return positions

    def render(self, mode='human'):
        """Render the environment as text."""
        current_step = self.current_step
        current_price = self.prices[current_step]

        # Display basic information
        print("\nTrading Environment Status:")
        print("============================")
        print(f"Step: {current_step}/{len(self.prices) - 1}")
        print(f"Current Price: {current_price:.4f}")
        print(f"Current Balance: {self.balance:.4f}")
        print(f"Current Position: {self.position:.4f}")

        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(current_price)
            print(f"Entry Price: {self.entry_price:.4f}")
            print(f"Unrealized PnL: {unrealized_pnl:.4f} ({unrealized_pnl / self.balance * 100:.2f}%)")

        print(f"Total PnL: {self.total_pnl:.4f} ({self.total_pnl / self.initial_balance * 100:.2f}%)")
        print(f"Total Trades: {len(self.trades)}")

        if len(self.trades) > 0:
            win_trades = [t for t in self.trades if t.get('realized_pnl', 0) > 0]
            win_count = len(win_trades)

            # Count trades that closed positions (not just changed them)
            closed_positions = [t for t in self.trades if t.get('realized_pnl', 0) != 0]
            closed_count = len(closed_positions)
            win_rate = win_count / closed_count if closed_count > 0 else 0

            print(f"Positions Closed: {closed_count}")
            print(f"Winning Positions: {win_count}")
            print(f"Win Rate: {win_rate:.2f}")

        # Display the last 5 positions
        print("\nLast 5 Positions:")
        print("================")
        last_positions = self.get_last_positions(5)
        if not last_positions:
            print("No closed positions yet.")

        for pos in last_positions:
            print(f"Time: {pos['timestamp']}")
            print(f"Action: {pos['action']}")
            print(f"Entry: {pos['entry_price']:.4f}, Exit: {pos['exit_price']:.4f}")
            print(f"Size: {pos['position_size']:.4f}")
            print(f"PnL: {pos['realized_pnl']:.4f} ({pos['pnl_percentage']:.2f}%)")
            print(f"Fee: {pos['fee']:.4f}")
            print(f"Balance: {pos['balance_before']:.4f} -> {pos['balance_after']:.4f}")
            print("----------------")

    def close(self):
        """Close the environment."""
        pass
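
# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the original module).
# The project presumably wires this environment to its own DataInterface;
# `_StubDataInterface` below is a hypothetical stand-in that only provides the
# two attributes this class actually reads (`timeframes` and `dataframes`),
# fed with a synthetic random walk. Run the module directly to step a random
# policy through one episode.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _StubDataInterface:
        """Hypothetical stand-in for the project's DataInterface."""

        def __init__(self, df, timeframe='1m'):
            self.timeframes = [timeframe]
            self.dataframes = {timeframe: df}

    # Synthetic OHLCV random walk with a datetime index
    rng = np.random.default_rng(42)
    n = 200
    close = 100.0 * np.exp(np.cumsum(rng.normal(0.0, 0.001, n)))
    df = pd.DataFrame({
        'open': close * (1 + rng.normal(0.0, 0.0005, n)),
        'high': close * (1 + np.abs(rng.normal(0.0, 0.001, n))),
        'low': close * (1 - np.abs(rng.normal(0.0, 0.001, n))),
        'close': close,
        'volume': rng.uniform(1.0, 10.0, n),
    }, index=pd.date_range('2024-01-01', periods=n, freq='1min'))

    env = TradingEnvironment(_StubDataInterface(df))
    obs = env.reset()
    # 1 timeframe * 5 features * 20 candles + 3 account features = 103
    assert obs.shape == env.state_shape

    done = False
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())
    env.render()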