import numpy as np
import pandas as pd
import logging
import gym
from gym import spaces

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TradingEnvironment(gym.Env):
    """
    Trading environment implementing the gym interface for reinforcement learning.

    Actions:
    - 0: Buy
    - 1: Sell
    - 2: Hold

    State:
    - OHLCV data from multiple timeframes
    - Technical indicators (planned; the current observation contains OHLCV and position data only)
    - Position data
    """

    def __init__(
        self,
        data_interface,
        initial_balance: float = 10000.0,
        transaction_fee: float = 0.0002,
        window_size: int = 20,
        max_position: float = 1.0,
        reward_scaling: float = 1.0,
    ):
        """
        Initialize the trading environment.

        Args:
            data_interface: DataInterface instance used to get market data
            initial_balance: Initial balance in the base currency
            transaction_fee: Fee for each transaction as a fraction of trade value
            window_size: Number of candles in the observation window
            max_position: Maximum position size as a fraction of balance
            reward_scaling: Scale factor for rewards
        """
        super().__init__()

        self.data_interface = data_interface
        self.initial_balance = initial_balance
        self.transaction_fee = transaction_fee
        self.window_size = window_size
        self.max_position = max_position
        self.reward_scaling = reward_scaling

        # Load data for the primary timeframe (the first one is assumed primary)
        self.timeframe = self.data_interface.timeframes[0]
        self.reset_data()

        # Define action and observation spaces
        self.action_space = spaces.Discrete(3)  # Buy, Sell, Hold

        # The observation combines OHLCV windows from every timeframe with
        # additional features such as position info and balance.
        n_timeframes = len(self.data_interface.timeframes)
        n_features = 5  # OHLCV data by default

        # Additional scalar features appended to the flattened windows
        additional_features = 3  # position, balance, unrealized_pnl

        # Calculate the total feature dimension
        total_features = (n_timeframes * n_features * self.window_size) + additional_features

        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(total_features,), dtype=np.float32
        )
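
        # Worked example (illustrative): with two timeframes, the default
        # window_size of 20 and 5 OHLCV features, the observation vector has
        # 2 * 5 * 20 + 3 = 203 entries.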

        # Use a tuple for state_shape, as expected by EnhancedCNN
        self.state_shape = (total_features,)

        # Initialize state
        self.reset()

    def reset_data(self):
        """Reset data and load the price series used for training"""
        # Get data for each timeframe
        self.data = {}
        for tf in self.data_interface.timeframes:
            df = self.data_interface.dataframes[tf]
            if df is not None and not df.empty:
                self.data[tf] = df

        if not self.data:
            raise ValueError("No data available for training")
        if self.timeframe not in self.data:
            raise ValueError(f"No data available for primary timeframe {self.timeframe}")

        # Use the primary timeframe for the step count
        self.prices = self.data[self.timeframe]['close'].values
        self.timestamps = self.data[self.timeframe].index.values
        self.max_steps = len(self.prices) - self.window_size - 1

    def reset(self):
        """Reset the environment to its initial state"""
        # Reset trading variables
        self.balance = self.initial_balance
        self.position = 0.0  # No position initially
        self.entry_price = 0.0
        self.total_pnl = 0.0
        self.trades = []
        self.rewards = []

        # Reset the step counter
        self.current_step = self.window_size

        # Get the initial observation
        observation = self._get_observation()

        return observation
    def step(self, action):
        """
        Take a step in the environment.

        Args:
            action: Action to take (0: Buy, 1: Sell, 2: Hold)

        Returns:
            tuple: (observation, reward, done, info)
        """
        # Record the state before taking the action
        prev_balance = self.balance
        prev_position = self.position
        prev_price = self.prices[self.current_step - 1]  # Close of the previous candle

        # Take the action
        info = {}
        reward = 0
        last_position_info = None

        # Get the current and next prices
        current_price = self.prices[self.current_step]
        next_price = self.prices[self.current_step + 1] if self.current_step + 1 < len(self.prices) else current_price

        # Process the action
        if action == 0:  # Buy
            if self.position <= 0:  # Only buy if not already long
                # Close any existing short position
                if self.position < 0:
                    close_pnl, last_position_info = self._close_position(current_price)
                    reward += close_pnl * self.reward_scaling

                # Open a new long position
                self._open_position(1.0 * self.max_position, current_price)
                logger.info(f"Buy at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")

        elif action == 1:  # Sell
            if self.position >= 0:  # Only sell if not already short
                # Close any existing long position
                if self.position > 0:
                    close_pnl, last_position_info = self._close_position(current_price)
                    reward += close_pnl * self.reward_scaling

                # Open a new short position
                self._open_position(-1.0 * self.max_position, current_price)
                logger.info(f"Sell at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")

        elif action == 2:  # Hold
            # No action; unrealized PnL still shapes the reward below
            pass

        # Add a scaled-down unrealized PnL term to the reward
        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(next_price)
            reward += unrealized_pnl * self.reward_scaling * 0.1  # Scale down unrealized PnL

        # Apply a penalty for holding a position
        if self.position != 0:
            # Small holding fee/interest
            holding_penalty = abs(self.position) * 0.0001  # 0.01% per step
            reward -= holding_penalty * self.reward_scaling
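
        # Example (illustrative): a full position (|position| = 1.0) pays
        # 0.0001 * reward_scaling per step held, so holding through 100 steps
        # costs 0.01 in shaped reward before any PnL.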

        # Move to the next step
        self.current_step += 1

        # Get the new observation
        observation = self._get_observation()

        # Check if the episode is done
        done = self.current_step >= len(self.prices) - 1

        # If done, close any remaining position
        if done and self.position != 0:
            final_pnl, last_position_info = self._close_position(current_price)
            reward += final_pnl * self.reward_scaling
            info['final_pnl'] = final_pnl
            info['final_balance'] = self.balance
            logger.info(f"Episode ended. Final balance: {self.balance:.4f}, Return: {(self.balance/self.initial_balance-1)*100:.2f}%")

        # Track the trade result if the position changed or a position was closed
        if prev_position != self.position or last_position_info is not None:
            # Determine the realized PnL for this step
            realized_pnl = 0
            position_info = {}

            if last_position_info is not None:
                # Use the position information from the close
                realized_pnl = last_position_info['pnl']
                position_info = last_position_info
            else:
                # Fall back to the balance change
                realized_pnl = self.balance - prev_balance if prev_position != 0 else 0

            # Record detailed trade information
            trade_result = {
                'step': self.current_step,
                'timestamp': self.timestamps[self.current_step],
                'action': action,
                'action_name': ['BUY', 'SELL', 'HOLD'][action],
                'price': current_price,
                'position_changed': prev_position != self.position,
                'prev_position': prev_position,
                'new_position': self.position,
                'position_size': abs(self.position) if self.position != 0 else abs(prev_position),
                'entry_price': position_info.get('entry_price', self.entry_price),
                'exit_price': position_info.get('exit_price', current_price),
                'realized_pnl': realized_pnl,
                'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0,
                'pnl': realized_pnl,  # Total PnL (realized for this step)
                'balance_before': prev_balance,
                'balance_after': self.balance,
                # Fee as a fraction of position notional, matching _close_position
                'trade_fee': position_info.get('fee', abs(self.position - prev_position) * self.transaction_fee)
            }
            info['trade_result'] = trade_result
            self.trades.append(trade_result)

            # Log the trade details
            logger.info(f"Trade executed - Action: {['BUY', 'SELL', 'HOLD'][action]}, "
                        f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
                        f"Balance: {self.balance:.4f}")

        # Store the reward
        self.rewards.append(reward)

        # Update the info dict with the current state
        info.update({
            'step': self.current_step,
            'price': current_price,
            'prev_price': prev_price,
            'price_change': (current_price - prev_price) / prev_price if prev_price != 0 else 0,
            'balance': self.balance,
            'position': self.position,
            'entry_price': self.entry_price,
            'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0.0,
            'total_trades': len(self.trades),
            'total_pnl': self.total_pnl,
            'return_pct': (self.balance/self.initial_balance-1)*100
        })

        return observation, reward, done, info

    def _calculate_unrealized_pnl(self, current_price):
        """Calculate the unrealized PnL for the current position.

        Note: the value is a return expressed as a fraction of position
        notional, not a currency amount.
        """
        if self.position == 0 or self.entry_price == 0:
            return 0.0

        if self.position > 0:  # Long position
            return self.position * (current_price / self.entry_price - 1.0)
        else:  # Short position
            return -self.position * (1.0 - current_price / self.entry_price)
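
    # Worked example (illustrative): a half-size long (position = 0.5) opened
    # at entry_price = 100.0 and marked at current_price = 110.0 yields
    # 0.5 * (110 / 100 - 1.0) = 0.05, i.e. a 10% move scaled by the 0.5
    # position fraction.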

    def _open_position(self, position_size, price):
        """Open a new position"""
        self.position = position_size
        self.entry_price = price

    def _close_position(self, price):
        """Close the current position and return the PnL"""
        pnl = self._calculate_unrealized_pnl(price)

        # Apply the transaction fee. Like the PnL, the fee is expressed as a
        # fraction of position notional so the two are in the same units.
        fee = abs(self.position) * self.transaction_fee
        pnl -= fee

        # Update the balance
        self.balance += pnl
        self.total_pnl += pnl

        # Store the position details before resetting
        last_position = {
            'position_size': self.position,
            'entry_price': self.entry_price,
            'exit_price': price,
            'pnl': pnl,
            'fee': fee
        }

        # Reset the position
        self.position = 0.0
        self.entry_price = 0.0

        # Log the position closure
        logger.info(f"Closed position - Size: {last_position['position_size']:.4f}, "
                    f"Entry: {last_position['entry_price']:.4f}, Exit: {last_position['exit_price']:.4f}, "
                    f"PnL: {last_position['pnl']:.4f}, Fee: {last_position['fee']:.4f}")

        return pnl, last_position
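
    # Worked example (illustrative): closing the half-size long above at 110.0
    # with transaction_fee = 0.0002 deducts a fee of 0.5 * 0.0002 = 0.0001, so
    # the realized PnL credited to the balance is 0.05 - 0.0001 = 0.0499.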

    def _get_observation(self):
        """
        Get the current observation.

        Returns:
            np.array: The observation vector
        """
        observations = []

        # Get data from each timeframe
        for tf in self.data_interface.timeframes:
            if tf in self.data:
                # Get the window of data for this timeframe
                df = self.data[tf]
                start_idx = self._align_timeframe_index(tf)

                if start_idx is not None and start_idx >= 0 and start_idx + self.window_size <= len(df):
                    window = df.iloc[start_idx:start_idx + self.window_size]

                    # Extract OHLCV data
                    ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values

                    # Normalize OHLC data against the last close in the window
                    last_close = ohlcv[-1, 3]  # Last close price
                    ohlcv_normalized = np.zeros_like(ohlcv)
                    ohlcv_normalized[:, 0] = ohlcv[:, 0] / last_close - 1.0  # open
                    ohlcv_normalized[:, 1] = ohlcv[:, 1] / last_close - 1.0  # high
                    ohlcv_normalized[:, 2] = ohlcv[:, 2] / last_close - 1.0  # low
                    ohlcv_normalized[:, 3] = ohlcv[:, 3] / last_close - 1.0  # close
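
                    # Example (illustrative): with last_close = 100.0, an open
                    # of 99.0 maps to 99.0 / 100.0 - 1.0 = -0.01, so each OHLC
                    # value becomes a fractional offset from the window's
                    # final close.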

                    # Normalize volume relative to the window's mean volume
                    if 'volume' in window.columns:
                        volume_ma = ohlcv[:, 4].mean()
                        if volume_ma > 0:
                            ohlcv_normalized[:, 4] = ohlcv[:, 4] / volume_ma - 1.0
                        else:
                            ohlcv_normalized[:, 4] = 0.0
                    else:
                        ohlcv_normalized[:, 4] = 0.0

                    # Flatten and add to the observations
                    observations.append(ohlcv_normalized.flatten())
                else:
                    # Fill with zeros if there is not enough data
                    observations.append(np.zeros(self.window_size * 5))
            else:
                # Fill with zeros for timeframes with no data at all, so the
                # observation keeps the dimension declared in __init__
                observations.append(np.zeros(self.window_size * 5))

        # Add position and balance information
        current_price = self.prices[self.current_step]
        position_info = np.array([
            self.position / self.max_position,  # Normalized position (-1 to 1)
            self.balance / self.initial_balance - 1.0,  # Normalized balance change
            self._calculate_unrealized_pnl(current_price)  # Unrealized PnL
        ])

        observations.append(position_info)

        # Concatenate all observations
        observation = np.concatenate(observations)
        return observation

    def _align_timeframe_index(self, timeframe):
        """
        Align the index of a higher timeframe with the current step in the primary timeframe.

        Args:
            timeframe: The timeframe to align

        Returns:
            int: The starting index of the window in the given timeframe
        """
        if timeframe == self.timeframe:
            return self.current_step - self.window_size

        # Get the timestamp of the current primary-timeframe step
        primary_ts = self.timestamps[self.current_step]

        # Find the insertion index in the higher timeframe
        higher_ts = self.data[timeframe].index.values
        idx = np.searchsorted(higher_ts, primary_ts)

        # Step back one window to get the starting index
        start_idx = max(0, idx - self.window_size)
        return start_idx
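
    # Example (illustrative): if the primary step's timestamp falls between
    # candles 41 and 42 of a higher timeframe, np.searchsorted returns 42,
    # and with window_size = 20 the window covers candles 22..41.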

    def get_last_positions(self, n=5):
        """
        Get detailed information about the last n closed positions.

        Args:
            n: Number of positions to return

        Returns:
            list: List of dictionaries containing position details
        """
        if not self.trades:
            return []

        # Keep only trades that closed a position
        position_trades = [
            t for t in self.trades
            if t.get('realized_pnl', 0) != 0
            or (t.get('prev_position', 0) != 0 and t.get('new_position', 0) == 0)
        ]

        positions = []
        last_n_trades = position_trades[-n:]

        for trade in last_n_trades:
            position_info = {
                'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
                'action': trade.get('action_name', ['BUY', 'SELL', 'HOLD'][trade['action']]),
                'entry_price': trade.get('entry_price', 0.0),
                'exit_price': trade.get('exit_price', trade['price']),
                'position_size': trade.get('position_size', self.max_position),
                'realized_pnl': trade.get('realized_pnl', 0.0),
                'fee': trade.get('trade_fee', 0.0),
                'pnl': trade.get('pnl', 0.0),
                'pnl_percentage': (trade.get('pnl', 0.0) / self.initial_balance) * 100,
                'balance_before': trade.get('balance_before', 0.0),
                'balance_after': trade.get('balance_after', 0.0),
                'duration': trade.get('duration', 'N/A')
            }
            positions.append(position_info)

        return positions

    def render(self, mode='human'):
        """Render the environment"""
        current_step = self.current_step
        current_price = self.prices[current_step]

        # Display basic information
        print("\nTrading Environment Status:")
        print("============================")
        print(f"Step: {current_step}/{len(self.prices)-1}")
        print(f"Current Price: {current_price:.4f}")
        print(f"Current Balance: {self.balance:.4f}")
        print(f"Current Position: {self.position:.4f}")

        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(current_price)
            print(f"Entry Price: {self.entry_price:.4f}")
            print(f"Unrealized PnL: {unrealized_pnl:.4f} ({unrealized_pnl/self.balance*100:.2f}%)")

        print(f"Total PnL: {self.total_pnl:.4f} ({self.total_pnl/self.initial_balance*100:.2f}%)")
        print(f"Total Trades: {len(self.trades)}")

        if len(self.trades) > 0:
            win_trades = [t for t in self.trades if t.get('realized_pnl', 0) > 0]
            win_count = len(win_trades)
            # Count trades that closed positions (not just changed them)
            closed_positions = [t for t in self.trades if t.get('realized_pnl', 0) != 0]
            closed_count = len(closed_positions)
            win_rate = win_count / closed_count if closed_count > 0 else 0
            print(f"Positions Closed: {closed_count}")
            print(f"Winning Positions: {win_count}")
            print(f"Win Rate: {win_rate:.2f}")

        # Display the last 5 positions
        print("\nLast 5 Positions:")
        print("================")
        last_positions = self.get_last_positions(5)

        if not last_positions:
            print("No closed positions yet.")

        for pos in last_positions:
            print(f"Time: {pos['timestamp']}")
            print(f"Action: {pos['action']}")
            print(f"Entry: {pos['entry_price']:.4f}, Exit: {pos['exit_price']:.4f}")
            print(f"Size: {pos['position_size']:.4f}")
            print(f"PnL: {pos['realized_pnl']:.4f} ({pos['pnl_percentage']:.2f}%)")
            print(f"Fee: {pos['fee']:.4f}")
            print(f"Balance: {pos['balance_before']:.4f} -> {pos['balance_after']:.4f}")
            print("----------------")

        return

    def close(self):
        """Close the environment"""
        pass
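

if __name__ == "__main__":
    # Minimal smoke-test sketch. DataInterface is defined elsewhere, so
    # _StubDataInterface is a hypothetical stand-in that provides only the two
    # attributes the environment reads: `timeframes` and `dataframes`.
    class _StubDataInterface:
        def __init__(self):
            self.timeframes = ['1m']
            n = 200
            rng = np.random.default_rng(0)
            # Synthetic random-walk prices around 100
            close = 100.0 + np.cumsum(rng.normal(0.0, 0.5, n))
            self.dataframes = {
                '1m': pd.DataFrame(
                    {
                        'open': close + rng.normal(0.0, 0.1, n),
                        'high': close + 0.5,
                        'low': close - 0.5,
                        'close': close,
                        'volume': rng.uniform(1.0, 10.0, n),
                    },
                    index=pd.date_range('2024-01-01', periods=n, freq='1min'),
                )
            }

    env = TradingEnvironment(_StubDataInterface())
    observation = env.reset()
    done = False
    while not done:
        # Random policy, just to exercise the step/reward/done plumbing
        observation, reward, done, info = env.step(env.action_space.sample())
    env.render()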