"""
Improved Reward Function for RL Trading Agent

This module provides a more sophisticated reward function for the RL trading agent
that incorporates realistic trading fees, penalties for excessive trading, and
rewards for successful holding of positions.
"""
|
|
|
|
import numpy as np
|
|
from datetime import datetime, timedelta
|
|
from collections import deque
|
|
|
|
class ImprovedRewardCalculator:
    """Reward calculator for an RL trading agent.

    Combines four signals into one scalar reward:

    * a per-transaction fee (``base_fee_rate``),
    * a progressive penalty for trading too frequently (capped at
      ``max_frequency_penalty``),
    * a small positive reward for holding profitable positions
      (``holding_reward_rate``), and
    * an optional risk adjustment based on a simplified Sharpe ratio of
      recent realized PnLs (``risk_adjusted``).

    Action encoding used throughout: 0 = Buy, 1 = Sell, 2 = Hold.
    """

    def __init__(self,
                 base_fee_rate=0.001,          # 0.1% per transaction
                 max_frequency_penalty=0.005,  # maximum 0.5% penalty for frequent trading
                 holding_reward_rate=0.0001,   # small reward for holding profitable positions
                 risk_adjusted=True):          # use Sharpe ratio for risk adjustment
        self.base_fee_rate = base_fee_rate
        self.max_frequency_penalty = max_frequency_penalty
        self.holding_reward_rate = holding_reward_rate
        self.risk_adjusted = risk_adjusted

        # Rolling window of recent trades, used for the frequency penalty.
        self.recent_trades = deque(maxlen=1000)
        # Rolling window of realized PnLs, used for the risk adjustment.
        self.trade_pnls = deque(maxlen=100)

    def record_trade(self, timestamp=None, action=None, price=None):
        """Record a trade for frequency tracking.

        Args:
            timestamp (datetime, optional): When the trade happened;
                defaults to ``datetime.now()``.
            action (int, optional): Action code of the trade.
            price (float, optional): Execution price, if known.
        """
        if timestamp is None:
            timestamp = datetime.now()

        self.recent_trades.append({
            'timestamp': timestamp,
            'action': action,
            'price': price
        })

    def record_pnl(self, pnl):
        """Record a realized PnL result for the risk-adjustment window."""
        self.trade_pnls.append(pnl)

    def _calculate_frequency_penalty(self):
        """Calculate the penalty for trading too frequently.

        Returns:
            float: 0.0 at a normal trading rate (at most one trade in
            the last minute); otherwise a penalty proportional to the
            number of trades in the last minute, capped at
            ``max_frequency_penalty``.
        """
        if len(self.recent_trades) < 2:
            return 0.0

        # Count trades recorded within the last minute.
        now = datetime.now()
        one_minute_ago = now - timedelta(minutes=1)
        trades_last_minute = sum(1 for trade in self.recent_trades
                                 if trade['timestamp'] > one_minute_ago)

        # No penalty for a normal trading rate.
        if trades_last_minute <= 1:
            return 0.0

        # Progressive penalty: grows with trade frequency, hard-capped.
        penalty = min(self.max_frequency_penalty,
                      self.base_fee_rate * trades_last_minute)
        return penalty

    def _calculate_holding_reward(self, position_held_time, price_change_pct):
        """Calculate the reward for holding a position for some time.

        Args:
            position_held_time (int): How long the position was held
                (in time units).
            price_change_pct (float): Percent price change while held.

        Returns:
            float: 0.0 for unprofitable holds; otherwise a reward that
            scales with both holding time (capped at 100 units) and the
            price change.
        """
        if position_held_time <= 0 or price_change_pct <= 0:
            return 0.0  # No reward for unprofitable holds.

        # Cap the time factor so very long holds don't dominate the reward.
        capped_time = min(position_held_time, 100)

        # Scale reward by both time and price change.
        return self.holding_reward_rate * capped_time * price_change_pct

    def _calculate_risk_adjustment(self, reward):
        """Adjust a reward using a simplified Sharpe ratio of recent PnLs.

        Args:
            reward (float): The raw reward to adjust.

        Returns:
            float: The reward scaled by a factor in [0.5, 2.0] derived
            from the Sharpe ratio; unchanged when fewer than 5 PnLs are
            recorded or their standard deviation is zero.
        """
        if len(self.trade_pnls) < 5:
            return reward  # Not enough data for adjustment.

        pnl_array = np.array(self.trade_pnls)
        mean_return = np.mean(pnl_array)
        std_return = np.std(pnl_array)

        if std_return == 0:
            return reward  # Avoid division by zero.

        # Simplified Sharpe ratio (no risk-free rate, unannualized).
        sharpe = mean_return / std_return

        # Scale factor normalized around 1.0, clipped to [0.5, 2.0].
        adjustment_factor = np.clip(1.0 + 0.5 * sharpe, 0.5, 2.0)

        # BUGFIX: the scaling must be sign-aware. Multiplying a *negative*
        # reward by a factor > 1 would punish the agent harder exactly when
        # its risk-adjusted performance is good. Amplify gains, shrink
        # penalties.
        if reward >= 0:
            return reward * adjustment_factor
        return reward / adjustment_factor

    def calculate_reward(self, action, price_change, position_held_time=0,
                         volatility=None, is_profitable=False):
        """
        Calculate the improved reward.

        Args:
            action (int): 0 = Buy, 1 = Sell, 2 = Hold
            price_change (float): Percent price change for the trade
            position_held_time (int): Time position was held (in time units)
            volatility (float, optional): Market volatility measure
                (currently unused; kept for interface compatibility)
            is_profitable (bool): Whether current position is profitable

        Returns:
            float: Calculated reward value
        """
        # Per-transaction fee and current over-trading penalty.
        fee = self.base_fee_rate
        frequency_penalty = self._calculate_frequency_penalty()

        if action == 0:  # Buy
            # Entry costs the transaction fee plus any frequency penalty.
            reward = -fee - frequency_penalty

        elif action == 1:  # Sell
            # Net profit after round-trip fees (entry + exit).
            net_profit = price_change - (fee * 2)

            # Scale up so profit dominates the fee-sized penalties.
            reward = net_profit * 10
            reward -= frequency_penalty

            # Feed the realized PnL into the risk-adjustment window.
            self.record_pnl(net_profit)

        else:  # Hold
            # Small reward for holding a profitable position, small cost otherwise.
            if is_profitable:
                reward = self._calculate_holding_reward(position_held_time, price_change)
            else:
                reward = -0.0001  # Very small negative reward.

        # Apply risk adjustment if enabled.
        if self.risk_adjusted:
            reward = self._calculate_risk_adjustment(reward)

        # BUGFIX: only actual transactions (buy/sell) count toward the
        # trading-frequency penalty; holds were previously recorded as
        # trades, incorrectly inflating the penalty for patient agents.
        if action in (0, 1):
            self.record_trade(action=action)

        return reward
|
|
|
|
|
|
# Example usage:
if __name__ == "__main__":
    import time

    # Build a calculator with the default fee/penalty settings.
    calc = ImprovedRewardCalculator()

    # Reward for a buy action.
    print(f"Buy action reward: {calc.calculate_reward(action=0, price_change=0):.5f}")

    # Register a trade so the frequency tracker has some history.
    calc.record_trade(action=0)

    # Short pause before the next trade to exercise the frequency penalty.
    time.sleep(0.1)

    # Reward for a profitable sell.
    sell = calc.calculate_reward(action=1, price_change=0.015, position_held_time=60)
    print(f"Sell action reward (with profit): {sell:.5f}")

    # Reward for holding a profitable position.
    hold_pos = calc.calculate_reward(action=2, price_change=0.01, position_held_time=30, is_profitable=True)
    print(f"Hold action reward (profitable): {hold_pos:.5f}")

    # Reward for holding an unprofitable position.
    hold_neg = calc.calculate_reward(action=2, price_change=-0.01, position_held_time=30, is_profitable=False)
    print(f"Hold action reward (unprofitable): {hold_neg:.5f}")