charts more or less OK

This commit is contained in:
Dobromir Popov 2025-04-01 21:37:08 +03:00
parent b0a57c5330
commit 938eef8bc9
3 changed files with 465 additions and 226 deletions

View File

@ -34,7 +34,7 @@ class RLTradingEnvironment(gym.Env):
Reinforcement Learning environment for trading with technical indicators Reinforcement Learning environment for trading with technical indicators
from multiple timeframes from multiple timeframes
""" """
def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.001): def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.0025, min_trade_interval=15):
super().__init__() super().__init__()
# Initialize attributes before parent class # Initialize attributes before parent class
@ -50,7 +50,8 @@ class RLTradingEnvironment(gym.Env):
# Trading parameters # Trading parameters
self.initial_balance = 1.0 self.initial_balance = 1.0
self.trading_fee = trading_fee self.trading_fee = trading_fee # Increased from 0.001 to 0.0025 (0.25%)
self.min_trade_interval = min_trade_interval # Minimum steps between trades
# Define action and observation spaces # Define action and observation spaces
self.action_space = gym.spaces.Discrete(3) # 0: Buy, 1: Sell, 2: Hold self.action_space = gym.spaces.Discrete(3) # 0: Buy, 1: Sell, 2: Hold
@ -76,6 +77,7 @@ class RLTradingEnvironment(gym.Env):
self.wins = 0 self.wins = 0
self.losses = 0 self.losses = 0
self.trade_history = [] self.trade_history = []
self.last_trade_step = -self.min_trade_interval # Initialize to allow immediate first trade
# Get initial observation # Get initial observation
observation = self._get_observation() observation = self._get_observation()
@ -150,24 +152,40 @@ class RLTradingEnvironment(gym.Env):
done = False done = False
profit_pct = None # Initialize profit_pct variable profit_pct = None # Initialize profit_pct variable
# Check if enough time has passed since last trade
trade_interval = self.current_step - self.last_trade_step
trade_interval_penalty = 0
# Execute action # Execute action
if action == 0: # BUY if action == 0: # BUY
if self.position == 0: # Only buy if not already in position if self.position == 0: # Only buy if not already in position
# Apply extra penalty for trading too frequently
if trade_interval < self.min_trade_interval:
trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)
# Still allow the trade but with penalty
self.position = self.balance * (1 - self.trading_fee) self.position = self.balance * (1 - self.trading_fee)
self.balance = 0 self.balance = 0
self.trades += 1 self.trades += 1
reward = 0 # Neutral reward for entering position reward = -0.001 + trade_interval_penalty # Small cost for transaction + potential penalty
self.trade_entry_price = current_price self.trade_entry_price = current_price
self.last_trade_step = self.current_step
elif action == 1: # SELL elif action == 1: # SELL
if self.position > 0: # Only sell if in position if self.position > 0: # Only sell if in position
# Apply extra penalty for trading too frequently
if trade_interval < self.min_trade_interval:
trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)
# Still allow the trade but with penalty
# Calculate position value at current price # Calculate position value at current price
position_value = self.position * (1 + price_change) position_value = self.position * (1 + price_change)
self.balance = position_value * (1 - self.trading_fee) self.balance = position_value * (1 - self.trading_fee)
# Calculate profit/loss from trade # Calculate profit/loss from trade
profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price
reward = profit_pct * 10 # Scale reward by profit percentage # Scale reward by profit percentage and apply trade interval penalty
reward = (profit_pct * 10) + trade_interval_penalty
# Update win/loss count # Update win/loss count
if profit_pct > 0: if profit_pct > 0:
@ -179,11 +197,13 @@ class RLTradingEnvironment(gym.Env):
self.trade_history.append({ self.trade_history.append({
'entry_price': self.trade_entry_price, 'entry_price': self.trade_entry_price,
'exit_price': next_price, 'exit_price': next_price,
'profit_pct': profit_pct 'profit_pct': profit_pct,
'trade_interval': trade_interval
}) })
# Reset position # Reset position and update last trade step
self.position = 0 self.position = 0
self.last_trade_step = self.current_step
# else: (action == 2 - HOLD) - no position change # else: (action == 2 - HOLD) - no position change

View File

@ -33,6 +33,8 @@ class BinanceHistoricalData:
self.cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cache') self.cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cache')
if not os.path.exists(self.cache_dir): if not os.path.exists(self.cache_dir):
os.makedirs(self.cache_dir) os.makedirs(self.cache_dir)
# Timestamp of last data update
self.last_update = None
def get_historical_candles(self, symbol, interval_seconds=3600, limit=1000): def get_historical_candles(self, symbol, interval_seconds=3600, limit=1000):
""" """
@ -61,15 +63,18 @@ class BinanceHistoricalData:
interval = interval_map.get(interval_seconds, "1h") interval = interval_map.get(interval_seconds, "1h")
# Format symbol for Binance API (remove slash) # Format symbol for Binance API (remove slash)
formatted_symbol = symbol.replace("/", "") formatted_symbol = symbol.replace("/", "").lower()
# Check if we have cached data first # Check if we have cached data first
cache_file = self._get_cache_filename(formatted_symbol, interval) cache_file = self._get_cache_filename(formatted_symbol, interval)
cached_data = self._load_from_cache(formatted_symbol, interval) cached_data = self._load_from_cache(formatted_symbol, interval)
# If we have cached data that's recent enough, use it
if cached_data is not None and len(cached_data) >= limit: if cached_data is not None and len(cached_data) >= limit:
logger.info(f"Using cached historical data for {symbol} ({interval})") cache_age_minutes = (datetime.now() - self.last_update).total_seconds() / 60 if self.last_update else 60
return cached_data if cache_age_minutes < 15: # Only use cache if it's less than 15 minutes old
logger.info(f"Using cached historical data for {symbol} ({interval})")
return cached_data
try: try:
# Build URL for klines endpoint # Build URL for klines endpoint
@ -106,6 +111,7 @@ class BinanceHistoricalData:
# Save to cache for future use # Save to cache for future use
self._save_to_cache(df, formatted_symbol, interval) self._save_to_cache(df, formatted_symbol, interval)
self.last_update = datetime.now()
logger.info(f"Fetched {len(df)} candles for {symbol} ({interval})") logger.info(f"Fetched {len(df)} candles for {symbol} ({interval})")
return df return df
@ -594,11 +600,13 @@ class TickStorage:
logger.error(f"Error loading ticks from file: {str(e)}") logger.error(f"Error loading ticks from file: {str(e)}")
def load_historical_data(self, historical_data, symbol): def load_historical_data(self, historical_data, symbol):
"""Load historical data""" """Load historical data for all timeframes"""
try: try:
# Load data for different timeframes # Load data for different timeframes
timeframes = [ timeframes = [
(1, '1s'), # 1 second (1, '1s'), # 1 second - limit to 20 minutes (1200 seconds)
(5, '5s'), # 5 seconds
(15, '15s'), # 15 seconds
(60, '1m'), # 1 minute (60, '1m'), # 1 minute
(300, '5m'), # 5 minutes (300, '5m'), # 5 minutes
(900, '15m'), # 15 minutes (900, '15m'), # 15 minutes
@ -611,9 +619,9 @@ class TickStorage:
# Set appropriate limits based on timeframe # Set appropriate limits based on timeframe
limit = 1000 # Default limit = 1000 # Default
if interval_seconds == 1: if interval_seconds == 1:
limit = 500 # 1s is too much data, limit to 500 limit = 1200 # 1s data - limit to 20 minutes as requested
elif interval_seconds < 60: elif interval_seconds < 60:
limit = 750 # For seconds-level data limit = 500 # For seconds-level data
elif interval_seconds < 300: elif interval_seconds < 300:
limit = 1000 # 1m limit = 1000 # 1m
elif interval_seconds < 900: elif interval_seconds < 900:
@ -623,34 +631,86 @@ class TickStorage:
else: else:
limit = 200 # hourly/daily data limit = 200 # hourly/daily data
df = historical_data.get_historical_candles(symbol, interval_seconds, limit) try:
if df is not None and not df.empty: # For 1s data, we might need to generate it from 1m data
logger.info(f"Loaded {len(df)} historical candles for {symbol} ({interval_key})") if interval_seconds == 1:
# Get 1m data first
df_1m = historical_data.get_historical_candles(symbol, 60, 60) # Get 60 minutes of 1m data
if df_1m is not None and not df_1m.empty:
# Create simulated 1s data from 1m data
simulated_1s = []
for _, row in df_1m.iterrows():
# For each 1m candle, create 60 1s candles
start_time = row['timestamp']
for i in range(60):
# Calculate second-level timestamp
second_time = start_time + timedelta(seconds=i)
# Convert to our candle format and store # Create candle with random price movement around close price
for _, row in df.iterrows(): close_price = row['close']
candle = { price_range = (row['high'] - row['low']) / 60 # Reduced range
'timestamp': row['timestamp'],
'open': row['open'],
'high': row['high'],
'low': row['low'],
'close': row['close'],
'volume': row['volume']
}
self.candles[interval_key].append(candle)
# For 1m and above, also use the close price to simulate ticks # Interpolate price - gradual movement from open to close
# but don't do this for seconds-level data as it creates too many ticks progress = i / 60
if interval_seconds >= 60 and interval_key == '1m': interp_price = row['open'] + (row['close'] - row['open']) * progress
self.add_tick(
price=row['close'],
volume=row['volume'],
timestamp=row['timestamp']
)
# Update latest price from most recent candle # Add some small random movement
if len(df) > 0: random_factor = np.random.normal(0, price_range * 0.5)
self.latest_price = df.iloc[-1]['close'] s_price = max(0, interp_price + random_factor)
# Create 1s candle
s_candle = {
'timestamp': second_time,
'open': s_price,
'high': s_price * 1.0001, # Tiny movement
'low': s_price * 0.9999, # Tiny movement
'close': s_price,
'volume': row['volume'] / 60 # Distribute volume
}
simulated_1s.append(s_candle)
# Add the simulated 1s candles to our candles storage
self.candles['1s'] = simulated_1s
logger.info(f"Generated {len(simulated_1s)} simulated 1s candles for {symbol}")
else:
# Load normal historical data
df = historical_data.get_historical_candles(symbol, interval_seconds, limit)
if df is not None and not df.empty:
logger.info(f"Loaded {len(df)} historical candles for {symbol} ({interval_key})")
# Convert to our candle format and store
candles = []
for _, row in df.iterrows():
candle = {
'timestamp': row['timestamp'],
'open': row['open'],
'high': row['high'],
'low': row['low'],
'close': row['close'],
'volume': row['volume']
}
candles.append(candle)
# Set the candles for this timeframe
self.candles[interval_key] = candles
# For 1m data, also use it to generate tick data
if interval_key == '1m':
for candle in candles[-20:]: # Use only the last 20 candles for tick data
self.add_tick(
price=candle['close'],
volume=candle['volume'] / 10, # Distribute volume
timestamp=candle['timestamp']
)
# Set latest price from most recent candle
if candles:
self.latest_price = candles[-1]['close']
logger.info(f"Set latest price to ${self.latest_price:.2f} from historical data")
except Exception as e:
logger.error(f"Error loading {interval_key} data: {e}")
continue
logger.info(f"Completed loading historical data for {symbol}") logger.info(f"Completed loading historical data for {symbol}")
@ -926,125 +986,137 @@ class RealTimeChart:
return {}, {}, [], "Error", "$0.00", "$0.00" return {}, {}, [], "Error", "$0.00", "$0.00"
def _update_main_chart(self, interval=1): def _update_main_chart(self, interval=1):
"""Update the main chart with the selected timeframe""" """Update the main chart with OHLC data"""
try: try:
# Get candle data for the selected interval # Get candles for the interval
candles = self.get_candles(interval_seconds=interval) interval_key = self._get_interval_key(interval)
if not candles or len(candles) == 0: # Make sure we have data for this interval
# Return empty chart if no data if interval_key not in self.tick_storage.candles or not self.tick_storage.candles[interval_key]:
logger.warning(f"No candle data available for {interval_key}")
# Return empty figure with a message
fig = go.Figure()
fig.add_annotation(
text=f"No data available for {interval_key}",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False
)
fig.update_layout(title=f"{self.symbol} - {interval_key}")
return fig
# For rendering, limit to the last 500 candles for performance
candles = self.tick_storage.candles[interval_key][-500:]
# Ensure we have at least 1 candle
if not candles:
logger.warning(f"No historical candles available for {interval_key}")
return go.Figure() return go.Figure()
# Create the candlestick chart # Extract OHLC values
fig = go.Figure() timestamps = [candle['timestamp'] for candle in candles]
opens = [candle['open'] for candle in candles]
highs = [candle['high'] for candle in candles]
lows = [candle['low'] for candle in candles]
closes = [candle['close'] for candle in candles]
volumes = [candle['volume'] for candle in candles]
# Create figure
fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
vertical_spacing=0.02,
row_heights=[0.8, 0.2],
specs=[[{"type": "candlestick"}],
[{"type": "bar"}]])
# Add candlestick trace # Add candlestick trace
fig.add_trace(go.Candlestick( fig.add_trace(go.Candlestick(
x=[c['timestamp'] for c in candles], x=timestamps,
open=[c['open'] for c in candles], open=opens,
high=[c['high'] for c in candles], high=highs,
low=[c['low'] for c in candles], low=lows,
close=[c['close'] for c in candles], close=closes,
name='Price' name='OHLC',
)) increasing_line_color='rgba(0, 180, 0, 0.7)',
decreasing_line_color='rgba(255, 0, 0, 0.7)',
), row=1, col=1)
# Add volume as a bar chart below # Add volume bars
fig.add_trace(go.Bar( fig.add_trace(go.Bar(
x=[c['timestamp'] for c in candles], x=timestamps,
y=[c['volume'] for c in candles], y=volumes,
name='Volume', name='Volume',
marker=dict( marker_color='rgba(100, 100, 255, 0.5)'
color='rgba(0, 0, 255, 0.5)', ), row=2, col=1)
),
opacity=0.5,
yaxis='y2'
))
# Add buy/sell markers for trades # Add trading markers if available
if hasattr(self, 'positions') and self.positions: if hasattr(self, 'positions') and self.positions:
buy_times = [] # Get last 100 positions for display (to avoid too many markers)
positions = self.positions[-100:]
buy_timestamps = []
buy_prices = [] buy_prices = []
sell_times = [] sell_timestamps = []
sell_prices = [] sell_prices = []
# Use all positions for chart display for pos in positions:
# Filter recent ones based on visible time range if pos.action == 'BUY':
now = datetime.now() buy_timestamps.append(pos.entry_timestamp)
time_limit = now - timedelta(hours=24) # Show at most 24h of trades buy_prices.append(pos.entry_price)
elif pos.action == 'SELL':
sell_timestamps.append(pos.entry_timestamp) # Using entry_time for consistency
sell_prices.append(pos.entry_price) # Using entry_price for consistency
for position in self.positions: # Add buy markers
if position.entry_timestamp > time_limit: if buy_timestamps:
if position.action == "BUY":
buy_times.append(position.entry_timestamp)
buy_prices.append(position.entry_price)
elif position.action == "SELL" and position.exit_timestamp:
sell_times.append(position.exit_timestamp)
sell_prices.append(position.exit_price)
# Add buy markers (green triangles pointing up)
if buy_times:
fig.add_trace(go.Scatter( fig.add_trace(go.Scatter(
x=buy_times, x=buy_timestamps,
y=buy_prices, y=buy_prices,
mode='markers', mode='markers',
name='Buy', name='Buy',
marker=dict( marker=dict(
symbol='triangle-up', symbol='triangle-up',
size=10, size=15,
color='green', color='rgba(0, 180, 0, 0.8)',
line=dict(width=1, color='black') line=dict(width=1, color='rgba(0, 180, 0, 1)')
) )
)) ), row=1, col=1)
# Add sell markers (red triangles pointing down) # Add sell markers
if sell_times: if sell_timestamps:
fig.add_trace(go.Scatter( fig.add_trace(go.Scatter(
x=sell_times, x=sell_timestamps,
y=sell_prices, y=sell_prices,
mode='markers', mode='markers',
name='Sell', name='Sell',
marker=dict( marker=dict(
symbol='triangle-down', symbol='triangle-down',
size=10, size=15,
color='red', color='rgba(255, 0, 0, 0.8)',
line=dict(width=1, color='black') line=dict(width=1, color='rgba(255, 0, 0, 1)')
) )
)) ), row=1, col=1)
# Update layout # Update layout
timeframe_label = f"{interval}s" if interval < 60 else f"{interval//60}m" if interval < 3600 else f"{interval//3600}h"
fig.update_layout( fig.update_layout(
title=f'{self.symbol} Price ({timeframe_label})', title=f"{self.symbol} - {interval_key}",
xaxis_title='Time', xaxis_title="Time",
yaxis_title='Price', yaxis_title="Price",
template='plotly_dark',
xaxis_rangeslider_visible=False,
height=600, height=600,
hovermode='x unified', template="plotly_dark",
legend=dict( showlegend=True,
orientation="h", margin=dict(l=0, r=0, t=50, b=20),
yanchor="bottom", legend=dict(orientation="h", y=1.02, x=0.5, xanchor="center"),
y=1.02, uirevision='true' # To maintain zoom level on updates
xanchor="right",
x=1
),
yaxis=dict(
domain=[0.25, 1]
),
yaxis2=dict(
domain=[0, 0.2],
title='Volume'
),
) )
# Add timestamp to show when chart was last updated # Format Y-axis with enough decimal places for cryptocurrency
fig.add_annotation( fig.update_yaxes(tickformat=".2f")
text=f"Last updated: {datetime.now().strftime('%H:%M:%S')}",
xref="paper", yref="paper", # Format X-axis with date/time
x=0.98, y=0.01, fig.update_xaxes(
showarrow=False, rangeslider_visible=False,
font=dict(size=10, color="gray") rangebreaks=[
dict(bounds=["sat", "mon"]) # hide weekends
]
) )
return fig return fig
@ -1053,71 +1125,90 @@ class RealTimeChart:
logger.error(f"Error updating main chart: {str(e)}") logger.error(f"Error updating main chart: {str(e)}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return go.Figure() # Return empty figure on error # Return empty figure on error
return go.Figure()
def _update_secondary_charts(self): def _update_secondary_charts(self):
"""Create secondary charts with multiple timeframes (1m, 1h, 1d)""" """Update the secondary charts for other timeframes"""
try: try:
# Create subplot with 3 rows # For each timeframe, create a small chart
secondary_timeframes = ['1m', '5m', '15m', '1h']
if not all(tf in self.tick_storage.candles for tf in secondary_timeframes):
logger.warning("Not all secondary timeframes available")
# Return empty figure with a message
fig = make_subplots(rows=1, cols=4)
for i, tf in enumerate(secondary_timeframes, 1):
fig.add_annotation(
text=f"No data for {tf}",
xref=f"x{i}", yref=f"y{i}",
x=0.5, y=0.5, showarrow=False
)
return fig
# Create subplots for each timeframe
fig = make_subplots( fig = make_subplots(
rows=3, cols=1, rows=1, cols=4,
shared_xaxes=False, subplot_titles=secondary_timeframes,
vertical_spacing=0.05, shared_yaxes=True
subplot_titles=('1 Minute', '1 Hour', '1 Day')
) )
# Get data for each timeframe # Loop through each timeframe
candles_1m = self.get_candles(interval_seconds=60) for i, timeframe in enumerate(secondary_timeframes, 1):
candles_1h = self.get_candles(interval_seconds=3600) interval_key = timeframe
candles_1d = self.get_candles(interval_seconds=86400)
# 1-minute chart (row 1) # Get candles for this timeframe
if candles_1m and len(candles_1m) > 0: if interval_key in self.tick_storage.candles and self.tick_storage.candles[interval_key]:
fig.add_trace(go.Candlestick( # For rendering, limit to the last 100 candles for performance
x=[c['timestamp'] for c in candles_1m], candles = self.tick_storage.candles[interval_key][-100:]
open=[c['open'] for c in candles_1m],
high=[c['high'] for c in candles_1m],
low=[c['low'] for c in candles_1m],
close=[c['close'] for c in candles_1m],
name='1m Price',
showlegend=False
), row=1, col=1)
# 1-hour chart (row 2) if candles:
if candles_1h and len(candles_1h) > 0: # Extract OHLC values
fig.add_trace(go.Candlestick( timestamps = [candle['timestamp'] for candle in candles]
x=[c['timestamp'] for c in candles_1h], opens = [candle['open'] for candle in candles]
open=[c['open'] for c in candles_1h], highs = [candle['high'] for candle in candles]
high=[c['high'] for c in candles_1h], lows = [candle['low'] for candle in candles]
low=[c['low'] for c in candles_1h], closes = [candle['close'] for candle in candles]
close=[c['close'] for c in candles_1h],
name='1h Price',
showlegend=False
), row=2, col=1)
# 1-day chart (row 3) # Add candlestick trace
if candles_1d and len(candles_1d) > 0: fig.add_trace(go.Candlestick(
fig.add_trace(go.Candlestick( x=timestamps,
x=[c['timestamp'] for c in candles_1d], open=opens,
open=[c['open'] for c in candles_1d], high=highs,
high=[c['high'] for c in candles_1d], low=lows,
low=[c['low'] for c in candles_1d], close=closes,
close=[c['close'] for c in candles_1d], name=interval_key,
name='1d Price', increasing_line_color='rgba(0, 180, 0, 0.7)',
showlegend=False decreasing_line_color='rgba(255, 0, 0, 0.7)',
), row=3, col=1) showlegend=False
), row=1, col=i)
else:
# Add empty annotation if no data
fig.add_annotation(
text=f"No data for {interval_key}",
xref=f"x{i}", yref=f"y{i}",
x=0.5, y=0.5, showarrow=False
)
# Update layout # Update layout
fig.update_layout( fig.update_layout(
height=500, height=250,
template='plotly_dark', template="plotly_dark",
margin=dict(l=50, r=50, t=30, b=30),
showlegend=False, showlegend=False,
hovermode='x unified' margin=dict(l=0, r=0, t=30, b=0),
) )
# Disable rangesliders for cleaner look # Format Y-axis with 2 decimal places
fig.update_xaxes(rangeslider_visible=False) fig.update_yaxes(tickformat=".2f")
# Format X-axis to show only the date (no time)
for i in range(1, 5):
fig.update_xaxes(
row=1, col=i,
rangeslider_visible=False,
rangebreaks=[dict(bounds=["sat", "mon"])], # hide weekends
tickformat="%m-%d" # Show month-day only
)
return fig return fig
@ -1125,7 +1216,8 @@ class RealTimeChart:
logger.error(f"Error updating secondary charts: {str(e)}") logger.error(f"Error updating secondary charts: {str(e)}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return go.Figure() # Return empty figure on error # Return empty figure on error
return make_subplots(rows=1, cols=4)
def _get_position_list_rows(self): def _get_position_list_rows(self):
"""Generate HTML for the positions list (last 10 positions only)""" """Generate HTML for the positions list (last 10 positions only)"""
@ -1259,6 +1351,19 @@ class RealTimeChart:
async def start_websocket(self): async def start_websocket(self):
"""Start the websocket connection for real-time data""" """Start the websocket connection for real-time data"""
try: try:
# Load historical data first to ensure we have candles for all timeframes
logger.info(f"Loading historical data for {self.symbol}")
# Initialize a BinanceHistoricalData instance
historical_data = BinanceHistoricalData()
# Load historical data for display
self.tick_storage.load_historical_data(historical_data, self.symbol)
# Make sure we update the charts once with historical data before websocket starts
# Update all the charts with the initial historical data
self._update_chart_and_positions()
# Initialize websocket # Initialize websocket
self.websocket = ExchangeWebSocket(self.symbol) self.websocket = ExchangeWebSocket(self.symbol)
await self.websocket.connect() await self.websocket.connect()

View File

@ -299,14 +299,30 @@ class RLTrainingIntegrator:
# Create a custom environment class that includes our reward function modification # Create a custom environment class that includes our reward function modification
class EnhancedRLTradingEnvironment(RLTradingEnvironment): class EnhancedRLTradingEnvironment(RLTradingEnvironment):
def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.001): def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.0025, min_trade_interval=15):
"""Initialize with normalization parameters""" super().__init__(features_1m, features_5m, features_15m, window_size, trading_fee, min_trade_interval)
super().__init__(features_1m, features_5m, features_15m, window_size, trading_fee)
# Initialize integrator and chart references # Reference to integrator for tracking
self.integrator = None # Will be set after initialization self.integrator = None
self.chart = None # Will be set after initialization
# Make writer accessible to integrator callbacks # Store the original data for extrema analysis
self.writer = None # Will be set by train_rl self.original_data = None
# RNN signal integration
self.signal_interpreter = None
self.last_rnn_signals = []
self.rnn_signal_weight = 0.3 # Weight for RNN signals in decision making
# TensorBoard writer
self.writer = None
def set_integrator(self, integrator):
"""Set reference to integrator for callbacks"""
self.integrator = integrator
def set_signal_interpreter(self, signal_interpreter):
"""Set reference to signal interpreter for RNN signal integration"""
self.signal_interpreter = signal_interpreter
def set_tensorboard_writer(self, writer): def set_tensorboard_writer(self, writer):
"""Set the TensorBoard writer""" """Set the TensorBoard writer"""
@ -314,50 +330,133 @@ class RLTrainingIntegrator:
def _calculate_reward(self, action): def _calculate_reward(self, action):
"""Override the reward calculation with our enhanced version""" """Override the reward calculation with our enhanced version"""
# Get the original reward calculation result # Get current and next price
reward, pnl = super()._calculate_reward(action)
# Get current price (normalized from training data)
current_price = self.features_1m[self.current_step, -1] current_price = self.features_1m[self.current_step, -1]
next_price = self.features_1m[self.current_step + 1, -1]
# Get real market price if available # Default values
pnl = 0.0
reward = -0.0001 # Small negative reward to discourage excessive actions
# Get real market price if available (from integrator)
real_market_price = None real_market_price = None
if hasattr(self, 'chart') and self.chart and hasattr(self.chart, 'latest_price'): if self.integrator and hasattr(self.integrator, 'chart') and self.integrator.chart:
real_market_price = self.chart.latest_price if hasattr(self.integrator.chart, 'tick_storage'):
real_market_price = self.integrator.chart.tick_storage.get_latest_price()
# Pass through the integrator's reward modifier # Calculate base reward based on position and price change
if hasattr(self, 'integrator') and self.integrator is not None: if action == 0: # BUY
# Add price to history - use real market price if available # Apply fee directly as negative reward to discourage excessive trading
if real_market_price is not None: reward -= self.trading_fee
# For extrema detection, use a normalized version of the real price
# to keep scale consistent with the model's price history # Check if we already have a position
self.integrator.price_history.append(current_price) if self.integrator and self.integrator.current_position_size > 0:
reward -= 0.002 # Additional penalty for trying to buy when already in position
# If RNN signal available, incorporate it
if self.signal_interpreter and len(self.last_rnn_signals) > 0:
last_signal = self.last_rnn_signals[-1]
if last_signal['action'] == 'BUY':
# RNN also suggests BUY - boost reward
reward += 0.003 * self.rnn_signal_weight * last_signal.get('confidence', 1.0)
elif last_signal['action'] == 'SELL':
# RNN suggests opposite - reduce reward
reward -= 0.003 * self.rnn_signal_weight * last_signal.get('confidence', 1.0)
elif action == 1: # SELL
if self.integrator and self.integrator.current_position_size > 0:
# Calculate potential profit/loss
if self.integrator.entry_price:
price_to_use = real_market_price if real_market_price else current_price
pnl = (price_to_use - self.integrator.entry_price) / self.integrator.entry_price
# Base reward on actual PnL
reward = pnl * 10
# Apply fee as negative component
reward -= self.trading_fee
# If RNN signal available, incorporate it
if self.signal_interpreter and len(self.last_rnn_signals) > 0:
last_signal = self.last_rnn_signals[-1]
if last_signal['action'] == 'SELL':
# RNN also suggests SELL - boost reward
reward += 0.003 * self.rnn_signal_weight * last_signal.get('confidence', 1.0)
elif last_signal['action'] == 'BUY':
# RNN suggests opposite - reduce reward
reward -= 0.003 * self.rnn_signal_weight * last_signal.get('confidence', 1.0)
else: else:
self.integrator.price_history.append(current_price) # No position to sell - penalize
reward = -0.005
# Apply extrema-based reward modifications elif action == 2: # HOLD
if len(self.integrator.price_history) > 20: # Check if we're holding a profitable position
# Detect local extrema if self.integrator and self.integrator.current_position_size > 0 and self.integrator.entry_price:
tops_indices, bottoms_indices = self.integrator.extrema_detector.find_extrema( price_to_use = real_market_price if real_market_price else current_price
self.integrator.price_history pnl = (price_to_use - self.integrator.entry_price) / self.integrator.entry_price
)
# Calculate additional rewards based on extrema # Encourage holding profitable positions
if action == 0 and bottoms_indices and bottoms_indices[-1] > len(self.integrator.price_history) - 5: if pnl > 0:
# Bonus for buying near bottoms reward = 0.0001 * pnl * 5 # Small positive reward for holding winner
reward += 0.01
if self.integrator.session_step % 50 == 0: # Log less frequently
# Display the real market price if available
display_price = real_market_price if real_market_price is not None else current_price
logger.info(f"BUY signal near bottom detected at price {display_price:.2f}! Adding bonus reward.")
elif action == 1 and tops_indices and tops_indices[-1] > len(self.integrator.price_history) - 5: # If position is very profitable, increase hold reward
# Bonus for selling near tops if pnl > 0.01: # Over 1% profit
reward += 0.01 reward *= 2
if self.integrator.session_step % 50 == 0: # Log less frequently else:
# Display the real market price if available # Small negative reward for holding losing position
display_price = real_market_price if real_market_price is not None else current_price reward = -0.0001 * abs(pnl) * 2
logger.info(f"SELL signal near top detected at price {display_price:.2f}! Adding bonus reward.")
# If RNN signal suggests HOLD, add small reward
if self.signal_interpreter and len(self.last_rnn_signals) > 0:
last_signal = self.last_rnn_signals[-1]
if last_signal['action'] == 'HOLD':
reward += 0.0001 * self.rnn_signal_weight
# Add price to history - use real market price if available
if real_market_price is not None:
# For extrema detection, use a normalized version of the real price
# to keep scale consistent with the model's price history
self.integrator.price_history.append(current_price)
else:
self.integrator.price_history.append(current_price)
# Apply extrema-based reward modifications
if len(self.integrator.price_history) > 20:
# Detect local extrema
tops_indices, bottoms_indices = self.integrator.extrema_detector.find_extrema(
self.integrator.price_history
)
# Get current price and market context
current_price = self.integrator.price_history[-1]
# Check if we're near a local extrema (top or bottom)
is_near_bottom = any(i > len(self.integrator.price_history) - 5 for i in bottoms_indices)
is_near_top = any(i > len(self.integrator.price_history) - 5 for i in tops_indices)
# Modify reward based on action and extrema
if action == 0 and is_near_bottom: # BUY near bottom
logger.info("Buying near local bottom - adding bonus reward")
reward += 0.015 # Significant bonus
elif action == 0 and is_near_top: # BUY near top
logger.info("Buying near local top - applying penalty")
reward -= 0.01 # Penalty
elif action == 1 and is_near_top: # SELL near top
logger.info("Selling near local top - adding bonus reward")
reward += 0.015 # Significant bonus
elif action == 1 and is_near_bottom: # SELL near bottom
logger.info("Selling near local bottom - applying penalty")
reward -= 0.01 # Penalty
elif action == 2: # HOLD
if is_near_bottom and self.integrator.current_position_size > 0:
# Good to hold if we have positions at bottom
reward += 0.002 # Small bonus
elif is_near_top and self.integrator.current_position_size == 0:
# Good to hold if we have no positions at top
reward += 0.002 # Small bonus
# Limit extreme rewards
reward = max(min(reward, 0.5), -0.5)
return reward, pnl return reward, pnl
@ -758,7 +857,22 @@ def _add_trade_compat(chart, price, timestamp, amount, pnl=0.0, action="BUY"):
# For SELL actions, close the position with given PnL # For SELL actions, close the position with given PnL
if action == "SELL": if action == "SELL":
# Find the most recent BUY position that hasn't been closed
entry_position = None
entry_price = price # Default if no open position found
for pos in reversed(chart.positions):
if pos.action == "BUY" and pos.is_open:
entry_position = pos
entry_price = pos.entry_price
# Mark this position as closed
pos.close(price, timestamp)
break
# Close this sell position with the right prices
position.entry_price = entry_price # Use the found entry price
position.close(price, timestamp) position.close(price, timestamp)
# Use realistic PnL values rather than the enormous ones from the model # Use realistic PnL values rather than the enormous ones from the model
# Cap PnL to reasonable values based on position size and price # Cap PnL to reasonable values based on position size and price
max_reasonable_pnl = price * amount * 0.05 # Max 5% profit per trade max_reasonable_pnl = price * amount * 0.05 # Max 5% profit per trade