added leverage slider
@@ -14,15 +14,19 @@ class TradingEnvironment(gym.Env):
     """
     Trading environment implementing gym interface for reinforcement learning

-    Actions:
-    - 0: Buy
-    - 1: Sell
-    - 2: Hold
+    2-Action System:
+    - 0: SELL (or close long position)
+    - 1: BUY (or close short position)
+
+    Intelligent Position Management:
+    - When neutral: Actions enter positions
+    - When positioned: Actions can close or flip positions
+    - Different thresholds for entry vs exit decisions

     State:
     - OHLCV data from multiple timeframes
     - Technical indicators
-    - Position data
+    - Position data and unrealized PnL
     """

     def __init__(
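The transition logic the new docstring describes is a small state machine over (action, position). A minimal sketch of that mapping — a hypothetical helper for illustration, not part of this commit:

```python
# Hypothetical illustration of the 2-action transition table described
# in the docstring above; not part of this commit.
def next_position(action: int, position: float) -> float:
    """Map (action, current position) -> new position under the 2-action system."""
    if action == 1:  # BUY
        if position == 0:
            return 1.0   # neutral: enter long
        if position < 0:
            return 0.0   # short: close it
        return position  # already long: hold
    else:  # action == 0: SELL
        if position == 0:
            return -1.0  # neutral: enter short
        if position > 0:
            return 0.0   # long: close it
        return position  # already short: hold

assert next_position(1, 0.0) == 1.0   # neutral + BUY  -> long
assert next_position(0, 1.0) == 0.0   # long + SELL    -> flat
assert next_position(1, -1.0) == 0.0  # short + BUY    -> flat
```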
@@ -33,9 +37,11 @@ class TradingEnvironment(gym.Env):
         window_size: int = 20,
         max_position: float = 1.0,
         reward_scaling: float = 1.0,
+        entry_threshold: float = 0.6,  # Higher threshold for entering positions
+        exit_threshold: float = 0.3,  # Lower threshold for exiting positions
     ):
         """
-        Initialize the trading environment.
+        Initialize the trading environment with 2-action system.

         Args:
             data_interface: DataInterface instance to get market data
@@ -44,6 +50,8 @@ class TradingEnvironment(gym.Env):
             window_size: Number of candles in the observation window
             max_position: Maximum position size as a fraction of balance
             reward_scaling: Scale factor for rewards
+            entry_threshold: Confidence threshold for entering new positions
+            exit_threshold: Confidence threshold for exiting positions
         """
         super().__init__()
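For reference, an instantiation sketch with the new threshold parameters. The `DataInterface` import path and constructor arguments are assumptions — only the environment's parameter list is visible in this diff:

```python
from core.data_interface import DataInterface  # assumed import path

data = DataInterface(symbol="ETH/USDT", timeframes=["1m", "5m"])  # assumed constructor
env = TradingEnvironment(
    data_interface=data,
    window_size=20,
    max_position=1.0,
    reward_scaling=1.0,
    entry_threshold=0.6,  # stricter gate for opening trades
    exit_threshold=0.3,   # looser gate for closing them
)
```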
@@ -53,21 +61,23 @@ class TradingEnvironment(gym.Env):
         self.window_size = window_size
         self.max_position = max_position
         self.reward_scaling = reward_scaling
+        self.entry_threshold = entry_threshold
+        self.exit_threshold = exit_threshold

         # Load data for primary timeframe (assuming the first one is primary)
         self.timeframe = self.data_interface.timeframes[0]
         self.reset_data()

-        # Define action and observation spaces
-        self.action_space = spaces.Discrete(3)  # Buy, Sell, Hold
+        # Define action and observation spaces for 2-action system
+        self.action_space = spaces.Discrete(2)  # 0=SELL, 1=BUY

         # For observation space, we consider multiple timeframes with OHLCV data
         # and additional features like technical indicators, position info, etc.
         n_timeframes = len(self.data_interface.timeframes)
         n_features = 5  # OHLCV data by default

-        # Add additional features for position, balance, etc.
-        additional_features = 3  # position, balance, unrealized_pnl
+        # Add additional features for position, balance, unrealized_pnl, etc.
+        additional_features = 5  # position, balance, unrealized_pnl, entry_price, position_duration

         # Calculate total feature dimension
         total_features = (n_timeframes * n_features * self.window_size) + additional_features
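To make the feature-dimension arithmetic concrete, a worked example assuming two timeframes (the timeframe count is an illustrative assumption, not taken from this diff):

```python
n_timeframes = 2          # assumed for illustration
n_features = 5            # OHLCV
window_size = 20
additional_features = 5   # position, balance, unrealized_pnl, entry_price, position_duration

total_features = (n_timeframes * n_features * window_size) + additional_features
print(total_features)     # 2 * 5 * 20 + 5 = 205
```

Versus the previous `additional_features = 3`, the new position-tracking features grow the flat observation vector by two entries.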
@@ -79,6 +89,11 @@ class TradingEnvironment(gym.Env):
         # Use tuple for state_shape that EnhancedCNN expects
         self.state_shape = (total_features,)

+        # Position tracking for 2-action system
+        self.position = 0.0  # -1 (short), 0 (neutral), 1 (long)
+        self.entry_price = 0.0  # Price at which position was entered
+        self.entry_step = 0  # Step at which position was entered
+
         # Initialize state
         self.reset()
@@ -103,9 +118,6 @@ class TradingEnvironment(gym.Env):
         """Reset the environment to initial state"""
         # Reset trading variables
         self.balance = self.initial_balance
         self.position = 0.0  # No position initially
         self.entry_price = 0.0
         self.total_pnl = 0.0
         self.trades = []
         self.rewards = []
@@ -119,10 +131,10 @@ class TradingEnvironment(gym.Env):
     def step(self, action):
         """
-        Take a step in the environment.
+        Take a step in the environment using 2-action system with intelligent position management.

         Args:
-            action: Action to take (0: Buy, 1: Sell, 2: Hold)
+            action: Action to take (0: SELL, 1: BUY)

         Returns:
             tuple: (observation, reward, done, info)
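Since `step` keeps the classic gym 4-tuple return, a rollout loop is unchanged apart from the smaller action space. A minimal sketch, using a random placeholder policy (an assumption for illustration) against the `env` constructed earlier:

```python
obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # 0=SELL, 1=BUY; stand-in for a trained policy
    obs, reward, done, info = env.step(action)
```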
@@ -132,7 +144,7 @@ class TradingEnvironment(gym.Env):
         prev_position = self.position
         prev_price = self.prices[self.current_step]

-        # Take action
+        # Take action with intelligent position management
         info = {}
         reward = 0
         last_position_info = None
@@ -141,43 +153,50 @@ class TradingEnvironment(gym.Env):
         current_price = self.prices[self.current_step]
         next_price = self.prices[self.current_step + 1] if self.current_step + 1 < len(self.prices) else current_price

-        # Process the action
-        if action == 0:  # Buy
-            if self.position <= 0:  # Only buy if not already long
-                # Close any existing short position
-                if self.position < 0:
-                    close_pnl, last_position_info = self._close_position(current_price)
-                    reward += close_pnl * self.reward_scaling
-
-                # Open new long position
-                self._open_position(1.0 * self.max_position, current_price)
-                logger.info(f"Buy at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")
-
-        elif action == 1:  # Sell
-            if self.position >= 0:  # Only sell if not already short
-                # Close any existing long position
-                if self.position > 0:
-                    close_pnl, last_position_info = self._close_position(current_price)
-                    reward += close_pnl * self.reward_scaling
-
-                # Open new short position
-                self._open_position(-1.0 * self.max_position, current_price)
-                logger.info(f"Sell at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")
-
-        elif action == 2:  # Hold
-            # No action, but still calculate unrealized PnL for reward
-            pass
+        # Implement 2-action system with position management
+        if action == 0:  # SELL action
+            if self.position == 0:  # No position - enter short
+                self._open_position(-1.0 * self.max_position, current_price)
+                logger.info(f"ENTER SHORT at step {self.current_step}, price: {current_price:.4f}")
+                reward = -self.transaction_fee  # Entry cost
+
+            elif self.position > 0:  # Long position - close it
+                close_pnl, last_position_info = self._close_position(current_price)
+                reward += close_pnl * self.reward_scaling
+                logger.info(f"CLOSE LONG at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")
+
+            elif self.position < 0:  # Already short - potentially flip to long if very strong signal
+                # For now, just hold the short position (no action)
+                pass
+
+        elif action == 1:  # BUY action
+            if self.position == 0:  # No position - enter long
+                self._open_position(1.0 * self.max_position, current_price)
+                logger.info(f"ENTER LONG at step {self.current_step}, price: {current_price:.4f}")
+                reward = -self.transaction_fee  # Entry cost
+
+            elif self.position < 0:  # Short position - close it
+                close_pnl, last_position_info = self._close_position(current_price)
+                reward += close_pnl * self.reward_scaling
+                logger.info(f"CLOSE SHORT at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")
+
+            elif self.position > 0:  # Already long - potentially flip to short if very strong signal
+                # For now, just hold the long position (no action)
+                pass

-        # Calculate unrealized PnL and add to reward
+        # Calculate unrealized PnL and add to reward if holding position
         if self.position != 0:
             unrealized_pnl = self._calculate_unrealized_pnl(next_price)
             reward += unrealized_pnl * self.reward_scaling * 0.1  # Scale down unrealized PnL
+
+            # Apply time-based holding penalty to encourage decisive actions
+            position_duration = self.current_step - self.entry_step
+            holding_penalty = min(position_duration * 0.0001, 0.01)  # Max 1% penalty
+            reward -= holding_penalty

-        # Apply penalties for holding a position
-        if self.position != 0:
-            # Small holding fee/interest
-            holding_penalty = abs(self.position) * 0.0001  # 0.01% per step
-            reward -= holding_penalty * self.reward_scaling
+        # Reward staying neutral when uncertain (no clear setup)
+        else:
+            reward += 0.0001  # Small reward for not trading without clear signals

         # Move to next step
         self.current_step += 1
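A worked example of how the new shaping terms combine while a position is held; the input values are illustrative assumptions:

```python
reward_scaling = 1.0
unrealized_pnl = 0.02     # assume a +2% move in the position's favor
position_duration = 50    # assume 50 steps since entry

reward = unrealized_pnl * reward_scaling * 0.1   # 0.002 from scaled-down unrealized PnL
reward -= min(position_duration * 0.0001, 0.01)  # 0.005 time-based holding penalty
print(round(reward, 6))                          # -0.003
```

At these coefficients the duration penalty overtakes a 2% unrealized gain after 20 steps, which is what pushes the agent toward decisive exits rather than indefinite holding.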
@@ -215,7 +234,7 @@ class TradingEnvironment(gym.Env):
             'step': self.current_step,
             'timestamp': self.timestamps[self.current_step],
             'action': action,
-            'action_name': ['BUY', 'SELL', 'HOLD'][action],
+            'action_name': ['SELL', 'BUY'][action],
             'price': current_price,
             'position_changed': prev_position != self.position,
             'prev_position': prev_position,
@@ -234,7 +253,7 @@ class TradingEnvironment(gym.Env):
             self.trades.append(trade_result)

             # Log trade details
-            logger.info(f"Trade executed - Action: {['BUY', 'SELL', 'HOLD'][action]}, "
+            logger.info(f"Trade executed - Action: {['SELL', 'BUY'][action]}, "
                         f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
                         f"Balance: {self.balance:.4f}")
@@ -268,42 +287,71 @@ class TradingEnvironment(gym.Env):
         else:  # Short position
             return -self.position * (1.0 - current_price / self.entry_price)

-    def _open_position(self, position_size, price):
-        """Open a new position"""
-        self.position = position_size
-        self.entry_price = price
-
-    def _close_position(self, price):
-        """Close the current position and return PnL"""
-        pnl = self._calculate_unrealized_pnl(price)
-
-        # Apply transaction fee
-        fee = abs(self.position) * price * self.transaction_fee
-        pnl -= fee
-
-        # Update balance
-        self.balance += pnl
-        self.total_pnl += pnl
-
-        # Store position details before resetting
-        last_position = {
-            'position_size': self.position,
-            'entry_price': self.entry_price,
-            'exit_price': price,
-            'pnl': pnl,
-            'fee': fee
-        }
-
-        # Reset position
-        self.position = 0.0
-        self.entry_price = 0.0
-        self.entry_step = 0
-
-        # Log position closure
-        logger.info(f"Closed position - Size: {last_position['position_size']:.4f}, "
-                    f"Entry: {last_position['entry_price']:.4f}, Exit: {last_position['exit_price']:.4f}, "
-                    f"PnL: {last_position['pnl']:.4f}, Fee: {last_position['fee']:.4f}")
-
-        return pnl, last_position
+    def _open_position(self, position_size: float, entry_price: float):
+        """Open a new position"""
+        self.position = position_size
+        self.entry_price = entry_price
+        self.entry_step = self.current_step
+
+        # Calculate position value
+        position_value = abs(position_size) * entry_price
+
+        # Apply transaction fee
+        fee = position_value * self.transaction_fee
+        self.balance -= fee
+
+        logger.info(f"Opened position: {position_size:.4f} at {entry_price:.4f}, fee: {fee:.4f}")
+
+    def _close_position(self, exit_price: float) -> Tuple[float, Dict]:
+        """Close current position and return PnL"""
+        if self.position == 0:
+            return 0.0, {}
+
+        # Calculate PnL
+        if self.position > 0:  # Long position
+            pnl = (exit_price - self.entry_price) / self.entry_price
+        else:  # Short position
+            pnl = (self.entry_price - exit_price) / self.entry_price
+
+        # Apply transaction fees (entry + exit)
+        position_value = abs(self.position) * exit_price
+        exit_fee = position_value * self.transaction_fee
+        total_fees = exit_fee  # Entry fee already applied when opening
+
+        # Net PnL after fees
+        net_pnl = pnl - (total_fees / (abs(self.position) * self.entry_price))
+
+        # Update balance
+        self.balance *= (1 + net_pnl)
+        self.total_pnl += net_pnl
+
+        # Track trade
+        position_info = {
+            'position_size': self.position,
+            'entry_price': self.entry_price,
+            'exit_price': exit_price,
+            'pnl': net_pnl,
+            'duration': self.current_step - self.entry_step,
+            'entry_step': self.entry_step,
+            'exit_step': self.current_step
+        }
+
+        self.trades.append(position_info)
+
+        # Update trade statistics
+        if net_pnl > 0:
+            self.winning_trades += 1
+        else:
+            self.losing_trades += 1
+
+        logger.info(f"Closed position: {self.position:.4f}, PnL: {net_pnl:.4f}, Duration: {position_info['duration']} steps")
+
+        # Reset position
+        self.position = 0.0
+        self.entry_price = 0.0
+        self.entry_step = 0
+
+        return net_pnl, position_info

     def _get_observation(self):
         """
@@ -411,7 +459,7 @@ class TradingEnvironment(gym.Env):
         for trade in last_n_trades:
             position_info = {
                 'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
-                'action': trade.get('action_name', ['BUY', 'SELL', 'HOLD'][trade['action']]),
+                'action': trade.get('action_name', ['SELL', 'BUY'][trade['action']]),
                 'entry_price': trade.get('entry_price', 0.0),
                 'exit_price': trade.get('exit_price', trade['price']),
                 'position_size': trade.get('position_size', self.max_position),