From 938eef8bc913ced3a89bd279f111af199c10e051 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Tue, 1 Apr 2025 21:37:08 +0300 Subject: [PATCH] charts more or less OK --- NN/train_rl.py | 32 ++- realtime.py | 451 +++++++++++++++++++++++--------------- train_rl_with_realtime.py | 208 ++++++++++++++---- 3 files changed, 465 insertions(+), 226 deletions(-) diff --git a/NN/train_rl.py b/NN/train_rl.py index d23bfca..fe71068 100644 --- a/NN/train_rl.py +++ b/NN/train_rl.py @@ -34,7 +34,7 @@ class RLTradingEnvironment(gym.Env): Reinforcement Learning environment for trading with technical indicators from multiple timeframes """ - def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.001): + def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.0025, min_trade_interval=15): super().__init__() # Initialize attributes before parent class @@ -50,7 +50,8 @@ class RLTradingEnvironment(gym.Env): # Trading parameters self.initial_balance = 1.0 - self.trading_fee = trading_fee + self.trading_fee = trading_fee # Increased from 0.001 to 0.0025 (0.25%) + self.min_trade_interval = min_trade_interval # Minimum steps between trades # Define action and observation spaces self.action_space = gym.spaces.Discrete(3) # 0: Buy, 1: Sell, 2: Hold @@ -76,6 +77,7 @@ class RLTradingEnvironment(gym.Env): self.wins = 0 self.losses = 0 self.trade_history = [] + self.last_trade_step = -self.min_trade_interval # Initialize to allow immediate first trade # Get initial observation observation = self._get_observation() @@ -150,24 +152,40 @@ class RLTradingEnvironment(gym.Env): done = False profit_pct = None # Initialize profit_pct variable + # Check if enough time has passed since last trade + trade_interval = self.current_step - self.last_trade_step + trade_interval_penalty = 0 + # Execute action if action == 0: # BUY if self.position == 0: # Only buy if not already in position + # Apply extra penalty for trading too frequently + if trade_interval < self.min_trade_interval: + trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval) + # Still allow the trade but with penalty + self.position = self.balance * (1 - self.trading_fee) self.balance = 0 self.trades += 1 - reward = 0 # Neutral reward for entering position + reward = -0.001 + trade_interval_penalty # Small cost for transaction + potential penalty self.trade_entry_price = current_price + self.last_trade_step = self.current_step elif action == 1: # SELL if self.position > 0: # Only sell if in position + # Apply extra penalty for trading too frequently + if trade_interval < self.min_trade_interval: + trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval) + # Still allow the trade but with penalty + # Calculate position value at current price position_value = self.position * (1 + price_change) self.balance = position_value * (1 - self.trading_fee) # Calculate profit/loss from trade profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price - reward = profit_pct * 10 # Scale reward by profit percentage + # Scale reward by profit percentage and apply trade interval penalty + reward = (profit_pct * 10) + trade_interval_penalty # Update win/loss count if profit_pct > 0: @@ -179,11 +197,13 @@ class RLTradingEnvironment(gym.Env): self.trade_history.append({ 'entry_price': self.trade_entry_price, 'exit_price': next_price, - 'profit_pct': profit_pct + 'profit_pct': profit_pct, + 'trade_interval': trade_interval }) - # Reset position + # Reset position and update last trade step self.position = 0 + self.last_trade_step = self.current_step # else: (action == 2 - HOLD) - no position change diff --git a/realtime.py b/realtime.py index 8848c9f..d3ec3cb 100644 --- a/realtime.py +++ b/realtime.py @@ -33,6 +33,8 @@ class BinanceHistoricalData: self.cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cache') if not os.path.exists(self.cache_dir): os.makedirs(self.cache_dir) + # Timestamp of last data update + self.last_update = None def get_historical_candles(self, symbol, interval_seconds=3600, limit=1000): """ @@ -61,15 +63,18 @@ class BinanceHistoricalData: interval = interval_map.get(interval_seconds, "1h") # Format symbol for Binance API (remove slash) - formatted_symbol = symbol.replace("/", "") + formatted_symbol = symbol.replace("/", "").lower() # Check if we have cached data first cache_file = self._get_cache_filename(formatted_symbol, interval) cached_data = self._load_from_cache(formatted_symbol, interval) + # If we have cached data that's recent enough, use it if cached_data is not None and len(cached_data) >= limit: - logger.info(f"Using cached historical data for {symbol} ({interval})") - return cached_data + cache_age_minutes = (datetime.now() - self.last_update).total_seconds() / 60 if self.last_update else 60 + if cache_age_minutes < 15: # Only use cache if it's less than 15 minutes old + logger.info(f"Using cached historical data for {symbol} ({interval})") + return cached_data try: # Build URL for klines endpoint @@ -106,6 +111,7 @@ class BinanceHistoricalData: # Save to cache for future use self._save_to_cache(df, formatted_symbol, interval) + self.last_update = datetime.now() logger.info(f"Fetched {len(df)} candles for {symbol} ({interval})") return df @@ -594,11 +600,13 @@ class TickStorage: logger.error(f"Error loading ticks from file: {str(e)}") def load_historical_data(self, historical_data, symbol): - """Load historical data""" + """Load historical data for all timeframes""" try: # Load data for different timeframes timeframes = [ - (1, '1s'), # 1 second + (1, '1s'), # 1 second - limit to 20 minutes (1200 seconds) + (5, '5s'), # 5 seconds + (15, '15s'), # 15 seconds (60, '1m'), # 1 minute (300, '5m'), # 5 minutes (900, '15m'), # 15 minutes @@ -611,9 +619,9 @@ class TickStorage: # Set appropriate limits based on timeframe limit = 1000 # Default if interval_seconds == 1: - limit = 500 # 1s is too much data, limit to 500 + limit = 1200 # 1s data - limit to 20 minutes as requested elif interval_seconds < 60: - limit = 750 # For seconds-level data + limit = 500 # For seconds-level data elif interval_seconds < 300: limit = 1000 # 1m elif interval_seconds < 900: @@ -623,35 +631,87 @@ class TickStorage: else: limit = 200 # hourly/daily data - df = historical_data.get_historical_candles(symbol, interval_seconds, limit) - if df is not None and not df.empty: - logger.info(f"Loaded {len(df)} historical candles for {symbol} ({interval_key})") - - # Convert to our candle format and store - for _, row in df.iterrows(): - candle = { - 'timestamp': row['timestamp'], - 'open': row['open'], - 'high': row['high'], - 'low': row['low'], - 'close': row['close'], - 'volume': row['volume'] - } - self.candles[interval_key].append(candle) - - # For 1m and above, also use the close price to simulate ticks - # but don't do this for seconds-level data as it creates too many ticks - if interval_seconds >= 60 and interval_key == '1m': - self.add_tick( - price=row['close'], - volume=row['volume'], - timestamp=row['timestamp'] - ) - - # Update latest price from most recent candle - if len(df) > 0: - self.latest_price = df.iloc[-1]['close'] - + try: + # For 1s data, we might need to generate it from 1m data + if interval_seconds == 1: + # Get 1m data first + df_1m = historical_data.get_historical_candles(symbol, 60, 60) # Get 60 minutes of 1m data + if df_1m is not None and not df_1m.empty: + # Create simulated 1s data from 1m data + simulated_1s = [] + for _, row in df_1m.iterrows(): + # For each 1m candle, create 60 1s candles + start_time = row['timestamp'] + for i in range(60): + # Calculate second-level timestamp + second_time = start_time + timedelta(seconds=i) + + # Create candle with random price movement around close price + close_price = row['close'] + price_range = (row['high'] - row['low']) / 60 # Reduced range + + # Interpolate price - gradual movement from open to close + progress = i / 60 + interp_price = row['open'] + (row['close'] - row['open']) * progress + + # Add some small random movement + random_factor = np.random.normal(0, price_range * 0.5) + s_price = max(0, interp_price + random_factor) + + # Create 1s candle + s_candle = { + 'timestamp': second_time, + 'open': s_price, + 'high': s_price * 1.0001, # Tiny movement + 'low': s_price * 0.9999, # Tiny movement + 'close': s_price, + 'volume': row['volume'] / 60 # Distribute volume + } + simulated_1s.append(s_candle) + + # Add the simulated 1s candles to our candles storage + self.candles['1s'] = simulated_1s + logger.info(f"Generated {len(simulated_1s)} simulated 1s candles for {symbol}") + else: + # Load normal historical data + df = historical_data.get_historical_candles(symbol, interval_seconds, limit) + if df is not None and not df.empty: + logger.info(f"Loaded {len(df)} historical candles for {symbol} ({interval_key})") + + # Convert to our candle format and store + candles = [] + for _, row in df.iterrows(): + candle = { + 'timestamp': row['timestamp'], + 'open': row['open'], + 'high': row['high'], + 'low': row['low'], + 'close': row['close'], + 'volume': row['volume'] + } + candles.append(candle) + + # Set the candles for this timeframe + self.candles[interval_key] = candles + + # For 1m data, also use it to generate tick data + if interval_key == '1m': + for candle in candles[-20:]: # Use only the last 20 candles for tick data + self.add_tick( + price=candle['close'], + volume=candle['volume'] / 10, # Distribute volume + timestamp=candle['timestamp'] + ) + + # Set latest price from most recent candle + if candles: + self.latest_price = candles[-1]['close'] + logger.info(f"Set latest price to ${self.latest_price:.2f} from historical data") + + except Exception as e: + logger.error(f"Error loading {interval_key} data: {e}") + continue + logger.info(f"Completed loading historical data for {symbol}") except Exception as e: @@ -926,125 +986,137 @@ class RealTimeChart: return {}, {}, [], "Error", "$0.00", "$0.00" def _update_main_chart(self, interval=1): - """Update the main chart with the selected timeframe""" + """Update the main chart with OHLC data""" try: - # Get candle data for the selected interval - candles = self.get_candles(interval_seconds=interval) + # Get candles for the interval + interval_key = self._get_interval_key(interval) - if not candles or len(candles) == 0: - # Return empty chart if no data + # Make sure we have data for this interval + if interval_key not in self.tick_storage.candles or not self.tick_storage.candles[interval_key]: + logger.warning(f"No candle data available for {interval_key}") + # Return empty figure with a message + fig = go.Figure() + fig.add_annotation( + text=f"No data available for {interval_key}", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False + ) + fig.update_layout(title=f"{self.symbol} - {interval_key}") + return fig + + # For rendering, limit to the last 500 candles for performance + candles = self.tick_storage.candles[interval_key][-500:] + + # Ensure we have at least 1 candle + if not candles: + logger.warning(f"No historical candles available for {interval_key}") return go.Figure() - # Create the candlestick chart - fig = go.Figure() + # Extract OHLC values + timestamps = [candle['timestamp'] for candle in candles] + opens = [candle['open'] for candle in candles] + highs = [candle['high'] for candle in candles] + lows = [candle['low'] for candle in candles] + closes = [candle['close'] for candle in candles] + volumes = [candle['volume'] for candle in candles] + + # Create figure + fig = make_subplots(rows=2, cols=1, shared_xaxes=True, + vertical_spacing=0.02, + row_heights=[0.8, 0.2], + specs=[[{"type": "candlestick"}], + [{"type": "bar"}]]) # Add candlestick trace fig.add_trace(go.Candlestick( - x=[c['timestamp'] for c in candles], - open=[c['open'] for c in candles], - high=[c['high'] for c in candles], - low=[c['low'] for c in candles], - close=[c['close'] for c in candles], - name='Price' - )) + x=timestamps, + open=opens, + high=highs, + low=lows, + close=closes, + name='OHLC', + increasing_line_color='rgba(0, 180, 0, 0.7)', + decreasing_line_color='rgba(255, 0, 0, 0.7)', + ), row=1, col=1) - # Add volume as a bar chart below + # Add volume bars fig.add_trace(go.Bar( - x=[c['timestamp'] for c in candles], - y=[c['volume'] for c in candles], + x=timestamps, + y=volumes, name='Volume', - marker=dict( - color='rgba(0, 0, 255, 0.5)', - ), - opacity=0.5, - yaxis='y2' - )) + marker_color='rgba(100, 100, 255, 0.5)' + ), row=2, col=1) - # Add buy/sell markers for trades + # Add trading markers if available if hasattr(self, 'positions') and self.positions: - buy_times = [] + # Get last 100 positions for display (to avoid too many markers) + positions = self.positions[-100:] + + buy_timestamps = [] buy_prices = [] - sell_times = [] + sell_timestamps = [] sell_prices = [] - # Use all positions for chart display - # Filter recent ones based on visible time range - now = datetime.now() - time_limit = now - timedelta(hours=24) # Show at most 24h of trades - - for position in self.positions: - if position.entry_timestamp > time_limit: - if position.action == "BUY": - buy_times.append(position.entry_timestamp) - buy_prices.append(position.entry_price) - elif position.action == "SELL" and position.exit_timestamp: - sell_times.append(position.exit_timestamp) - sell_prices.append(position.exit_price) - - # Add buy markers (green triangles pointing up) - if buy_times: + for pos in positions: + if pos.action == 'BUY': + buy_timestamps.append(pos.entry_timestamp) + buy_prices.append(pos.entry_price) + elif pos.action == 'SELL': + sell_timestamps.append(pos.entry_timestamp) # Using entry_time for consistency + sell_prices.append(pos.entry_price) # Using entry_price for consistency + + # Add buy markers + if buy_timestamps: fig.add_trace(go.Scatter( - x=buy_times, + x=buy_timestamps, y=buy_prices, mode='markers', name='Buy', marker=dict( symbol='triangle-up', - size=10, - color='green', - line=dict(width=1, color='black') + size=15, + color='rgba(0, 180, 0, 0.8)', + line=dict(width=1, color='rgba(0, 180, 0, 1)') ) - )) + ), row=1, col=1) - # Add sell markers (red triangles pointing down) - if sell_times: + # Add sell markers + if sell_timestamps: fig.add_trace(go.Scatter( - x=sell_times, + x=sell_timestamps, y=sell_prices, mode='markers', name='Sell', marker=dict( symbol='triangle-down', - size=10, - color='red', - line=dict(width=1, color='black') + size=15, + color='rgba(255, 0, 0, 0.8)', + line=dict(width=1, color='rgba(255, 0, 0, 1)') ) - )) + ), row=1, col=1) # Update layout - timeframe_label = f"{interval}s" if interval < 60 else f"{interval//60}m" if interval < 3600 else f"{interval//3600}h" - fig.update_layout( - title=f'{self.symbol} Price ({timeframe_label})', - xaxis_title='Time', - yaxis_title='Price', - template='plotly_dark', - xaxis_rangeslider_visible=False, + title=f"{self.symbol} - {interval_key}", + xaxis_title="Time", + yaxis_title="Price", height=600, - hovermode='x unified', - legend=dict( - orientation="h", - yanchor="bottom", - y=1.02, - xanchor="right", - x=1 - ), - yaxis=dict( - domain=[0.25, 1] - ), - yaxis2=dict( - domain=[0, 0.2], - title='Volume' - ), + template="plotly_dark", + showlegend=True, + margin=dict(l=0, r=0, t=50, b=20), + legend=dict(orientation="h", y=1.02, x=0.5, xanchor="center"), + uirevision='true' # To maintain zoom level on updates ) - # Add timestamp to show when chart was last updated - fig.add_annotation( - text=f"Last updated: {datetime.now().strftime('%H:%M:%S')}", - xref="paper", yref="paper", - x=0.98, y=0.01, - showarrow=False, - font=dict(size=10, color="gray") + # Format Y-axis with enough decimal places for cryptocurrency + fig.update_yaxes(tickformat=".2f") + + # Format X-axis with date/time + fig.update_xaxes( + rangeslider_visible=False, + rangebreaks=[ + dict(bounds=["sat", "mon"]) # hide weekends + ] ) return fig @@ -1053,71 +1125,90 @@ class RealTimeChart: logger.error(f"Error updating main chart: {str(e)}") import traceback logger.error(traceback.format_exc()) - return go.Figure() # Return empty figure on error + # Return empty figure on error + return go.Figure() def _update_secondary_charts(self): - """Create secondary charts with multiple timeframes (1m, 1h, 1d)""" + """Update the secondary charts for other timeframes""" try: - # Create subplot with 3 rows + # For each timeframe, create a small chart + secondary_timeframes = ['1m', '5m', '15m', '1h'] + + if not all(tf in self.tick_storage.candles for tf in secondary_timeframes): + logger.warning("Not all secondary timeframes available") + # Return empty figure with a message + fig = make_subplots(rows=1, cols=4) + for i, tf in enumerate(secondary_timeframes, 1): + fig.add_annotation( + text=f"No data for {tf}", + xref=f"x{i}", yref=f"y{i}", + x=0.5, y=0.5, showarrow=False + ) + return fig + + # Create subplots for each timeframe fig = make_subplots( - rows=3, cols=1, - shared_xaxes=False, - vertical_spacing=0.05, - subplot_titles=('1 Minute', '1 Hour', '1 Day') + rows=1, cols=4, + subplot_titles=secondary_timeframes, + shared_yaxes=True ) - # Get data for each timeframe - candles_1m = self.get_candles(interval_seconds=60) - candles_1h = self.get_candles(interval_seconds=3600) - candles_1d = self.get_candles(interval_seconds=86400) - - # 1-minute chart (row 1) - if candles_1m and len(candles_1m) > 0: - fig.add_trace(go.Candlestick( - x=[c['timestamp'] for c in candles_1m], - open=[c['open'] for c in candles_1m], - high=[c['high'] for c in candles_1m], - low=[c['low'] for c in candles_1m], - close=[c['close'] for c in candles_1m], - name='1m Price', - showlegend=False - ), row=1, col=1) - - # 1-hour chart (row 2) - if candles_1h and len(candles_1h) > 0: - fig.add_trace(go.Candlestick( - x=[c['timestamp'] for c in candles_1h], - open=[c['open'] for c in candles_1h], - high=[c['high'] for c in candles_1h], - low=[c['low'] for c in candles_1h], - close=[c['close'] for c in candles_1h], - name='1h Price', - showlegend=False - ), row=2, col=1) - - # 1-day chart (row 3) - if candles_1d and len(candles_1d) > 0: - fig.add_trace(go.Candlestick( - x=[c['timestamp'] for c in candles_1d], - open=[c['open'] for c in candles_1d], - high=[c['high'] for c in candles_1d], - low=[c['low'] for c in candles_1d], - close=[c['close'] for c in candles_1d], - name='1d Price', - showlegend=False - ), row=3, col=1) + # Loop through each timeframe + for i, timeframe in enumerate(secondary_timeframes, 1): + interval_key = timeframe + + # Get candles for this timeframe + if interval_key in self.tick_storage.candles and self.tick_storage.candles[interval_key]: + # For rendering, limit to the last 100 candles for performance + candles = self.tick_storage.candles[interval_key][-100:] + + if candles: + # Extract OHLC values + timestamps = [candle['timestamp'] for candle in candles] + opens = [candle['open'] for candle in candles] + highs = [candle['high'] for candle in candles] + lows = [candle['low'] for candle in candles] + closes = [candle['close'] for candle in candles] + + # Add candlestick trace + fig.add_trace(go.Candlestick( + x=timestamps, + open=opens, + high=highs, + low=lows, + close=closes, + name=interval_key, + increasing_line_color='rgba(0, 180, 0, 0.7)', + decreasing_line_color='rgba(255, 0, 0, 0.7)', + showlegend=False + ), row=1, col=i) + else: + # Add empty annotation if no data + fig.add_annotation( + text=f"No data for {interval_key}", + xref=f"x{i}", yref=f"y{i}", + x=0.5, y=0.5, showarrow=False + ) # Update layout fig.update_layout( - height=500, - template='plotly_dark', - margin=dict(l=50, r=50, t=30, b=30), + height=250, + template="plotly_dark", showlegend=False, - hovermode='x unified' + margin=dict(l=0, r=0, t=30, b=0), ) - # Disable rangesliders for cleaner look - fig.update_xaxes(rangeslider_visible=False) + # Format Y-axis with 2 decimal places + fig.update_yaxes(tickformat=".2f") + + # Format X-axis to show only the date (no time) + for i in range(1, 5): + fig.update_xaxes( + row=1, col=i, + rangeslider_visible=False, + rangebreaks=[dict(bounds=["sat", "mon"])], # hide weekends + tickformat="%m-%d" # Show month-day only + ) return fig @@ -1125,7 +1216,8 @@ class RealTimeChart: logger.error(f"Error updating secondary charts: {str(e)}") import traceback logger.error(traceback.format_exc()) - return go.Figure() # Return empty figure on error + # Return empty figure on error + return make_subplots(rows=1, cols=4) def _get_position_list_rows(self): """Generate HTML for the positions list (last 10 positions only)""" @@ -1259,6 +1351,19 @@ class RealTimeChart: async def start_websocket(self): """Start the websocket connection for real-time data""" try: + # Load historical data first to ensure we have candles for all timeframes + logger.info(f"Loading historical data for {self.symbol}") + + # Initialize a BinanceHistoricalData instance + historical_data = BinanceHistoricalData() + + # Load historical data for display + self.tick_storage.load_historical_data(historical_data, self.symbol) + + # Make sure we update the charts once with historical data before websocket starts + # Update all the charts with the initial historical data + self._update_chart_and_positions() + # Initialize websocket self.websocket = ExchangeWebSocket(self.symbol) await self.websocket.connect() diff --git a/train_rl_with_realtime.py b/train_rl_with_realtime.py index 40c6560..9268de0 100644 --- a/train_rl_with_realtime.py +++ b/train_rl_with_realtime.py @@ -299,14 +299,30 @@ class RLTrainingIntegrator: # Create a custom environment class that includes our reward function modification class EnhancedRLTradingEnvironment(RLTradingEnvironment): - def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.001): - """Initialize with normalization parameters""" - super().__init__(features_1m, features_5m, features_15m, window_size, trading_fee) - # Initialize integrator and chart references - self.integrator = None # Will be set after initialization - self.chart = None # Will be set after initialization - # Make writer accessible to integrator callbacks - self.writer = None # Will be set by train_rl + def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.0025, min_trade_interval=15): + super().__init__(features_1m, features_5m, features_15m, window_size, trading_fee, min_trade_interval) + + # Reference to integrator for tracking + self.integrator = None + + # Store the original data for extrema analysis + self.original_data = None + + # RNN signal integration + self.signal_interpreter = None + self.last_rnn_signals = [] + self.rnn_signal_weight = 0.3 # Weight for RNN signals in decision making + + # TensorBoard writer + self.writer = None + + def set_integrator(self, integrator): + """Set reference to integrator for callbacks""" + self.integrator = integrator + + def set_signal_interpreter(self, signal_interpreter): + """Set reference to signal interpreter for RNN signal integration""" + self.signal_interpreter = signal_interpreter def set_tensorboard_writer(self, writer): """Set the TensorBoard writer""" @@ -314,51 +330,134 @@ class RLTrainingIntegrator: def _calculate_reward(self, action): """Override the reward calculation with our enhanced version""" - # Get the original reward calculation result - reward, pnl = super()._calculate_reward(action) - - # Get current price (normalized from training data) + # Get current and next price current_price = self.features_1m[self.current_step, -1] + next_price = self.features_1m[self.current_step + 1, -1] - # Get real market price if available + # Default values + pnl = 0.0 + reward = -0.0001 # Small negative reward to discourage excessive actions + + # Get real market price if available (from integrator) real_market_price = None - if hasattr(self, 'chart') and self.chart and hasattr(self.chart, 'latest_price'): - real_market_price = self.chart.latest_price + if self.integrator and hasattr(self.integrator, 'chart') and self.integrator.chart: + if hasattr(self.integrator.chart, 'tick_storage'): + real_market_price = self.integrator.chart.tick_storage.get_latest_price() - # Pass through the integrator's reward modifier - if hasattr(self, 'integrator') and self.integrator is not None: - # Add price to history - use real market price if available - if real_market_price is not None: - # For extrema detection, use a normalized version of the real price - # to keep scale consistent with the model's price history - self.integrator.price_history.append(current_price) + # Calculate base reward based on position and price change + if action == 0: # BUY + # Apply fee directly as negative reward to discourage excessive trading + reward -= self.trading_fee + + # Check if we already have a position + if self.integrator and self.integrator.current_position_size > 0: + reward -= 0.002 # Additional penalty for trying to buy when already in position + + # If RNN signal available, incorporate it + if self.signal_interpreter and len(self.last_rnn_signals) > 0: + last_signal = self.last_rnn_signals[-1] + if last_signal['action'] == 'BUY': + # RNN also suggests BUY - boost reward + reward += 0.003 * self.rnn_signal_weight * last_signal.get('confidence', 1.0) + elif last_signal['action'] == 'SELL': + # RNN suggests opposite - reduce reward + reward -= 0.003 * self.rnn_signal_weight * last_signal.get('confidence', 1.0) + + elif action == 1: # SELL + if self.integrator and self.integrator.current_position_size > 0: + # Calculate potential profit/loss + if self.integrator.entry_price: + price_to_use = real_market_price if real_market_price else current_price + pnl = (price_to_use - self.integrator.entry_price) / self.integrator.entry_price + + # Base reward on actual PnL + reward = pnl * 10 + + # Apply fee as negative component + reward -= self.trading_fee + + # If RNN signal available, incorporate it + if self.signal_interpreter and len(self.last_rnn_signals) > 0: + last_signal = self.last_rnn_signals[-1] + if last_signal['action'] == 'SELL': + # RNN also suggests SELL - boost reward + reward += 0.003 * self.rnn_signal_weight * last_signal.get('confidence', 1.0) + elif last_signal['action'] == 'BUY': + # RNN suggests opposite - reduce reward + reward -= 0.003 * self.rnn_signal_weight * last_signal.get('confidence', 1.0) else: - self.integrator.price_history.append(current_price) + # No position to sell - penalize + reward = -0.005 + + elif action == 2: # HOLD + # Check if we're holding a profitable position + if self.integrator and self.integrator.current_position_size > 0 and self.integrator.entry_price: + price_to_use = real_market_price if real_market_price else current_price + pnl = (price_to_use - self.integrator.entry_price) / self.integrator.entry_price + + # Encourage holding profitable positions + if pnl > 0: + reward = 0.0001 * pnl * 5 # Small positive reward for holding winner + + # If position is very profitable, increase hold reward + if pnl > 0.01: # Over 1% profit + reward *= 2 + else: + # Small negative reward for holding losing position + reward = -0.0001 * abs(pnl) * 2 + + # If RNN signal suggests HOLD, add small reward + if self.signal_interpreter and len(self.last_rnn_signals) > 0: + last_signal = self.last_rnn_signals[-1] + if last_signal['action'] == 'HOLD': + reward += 0.0001 * self.rnn_signal_weight + + # Add price to history - use real market price if available + if real_market_price is not None: + # For extrema detection, use a normalized version of the real price + # to keep scale consistent with the model's price history + self.integrator.price_history.append(current_price) + else: + self.integrator.price_history.append(current_price) + + # Apply extrema-based reward modifications + if len(self.integrator.price_history) > 20: + # Detect local extrema + tops_indices, bottoms_indices = self.integrator.extrema_detector.find_extrema( + self.integrator.price_history + ) + + # Get current price and market context + current_price = self.integrator.price_history[-1] + + # Check if we're near a local extrema (top or bottom) + is_near_bottom = any(i > len(self.integrator.price_history) - 5 for i in bottoms_indices) + is_near_top = any(i > len(self.integrator.price_history) - 5 for i in tops_indices) + + # Modify reward based on action and extrema + if action == 0 and is_near_bottom: # BUY near bottom + logger.info("Buying near local bottom - adding bonus reward") + reward += 0.015 # Significant bonus + elif action == 0 and is_near_top: # BUY near top + logger.info("Buying near local top - applying penalty") + reward -= 0.01 # Penalty + elif action == 1 and is_near_top: # SELL near top + logger.info("Selling near local top - adding bonus reward") + reward += 0.015 # Significant bonus + elif action == 1 and is_near_bottom: # SELL near bottom + logger.info("Selling near local bottom - applying penalty") + reward -= 0.01 # Penalty + elif action == 2: # HOLD + if is_near_bottom and self.integrator.current_position_size > 0: + # Good to hold if we have positions at bottom + reward += 0.002 # Small bonus + elif is_near_top and self.integrator.current_position_size == 0: + # Good to hold if we have no positions at top + reward += 0.002 # Small bonus + + # Limit extreme rewards + reward = max(min(reward, 0.5), -0.5) - # Apply extrema-based reward modifications - if len(self.integrator.price_history) > 20: - # Detect local extrema - tops_indices, bottoms_indices = self.integrator.extrema_detector.find_extrema( - self.integrator.price_history - ) - - # Calculate additional rewards based on extrema - if action == 0 and bottoms_indices and bottoms_indices[-1] > len(self.integrator.price_history) - 5: - # Bonus for buying near bottoms - reward += 0.01 - if self.integrator.session_step % 50 == 0: # Log less frequently - # Display the real market price if available - display_price = real_market_price if real_market_price is not None else current_price - logger.info(f"BUY signal near bottom detected at price {display_price:.2f}! Adding bonus reward.") - - elif action == 1 and tops_indices and tops_indices[-1] > len(self.integrator.price_history) - 5: - # Bonus for selling near tops - reward += 0.01 - if self.integrator.session_step % 50 == 0: # Log less frequently - # Display the real market price if available - display_price = real_market_price if real_market_price is not None else current_price - logger.info(f"SELL signal near top detected at price {display_price:.2f}! Adding bonus reward.") - return reward, pnl # Create a custom environment class factory @@ -758,7 +857,22 @@ def _add_trade_compat(chart, price, timestamp, amount, pnl=0.0, action="BUY"): # For SELL actions, close the position with given PnL if action == "SELL": + # Find the most recent BUY position that hasn't been closed + entry_position = None + entry_price = price # Default if no open position found + + for pos in reversed(chart.positions): + if pos.action == "BUY" and pos.is_open: + entry_position = pos + entry_price = pos.entry_price + # Mark this position as closed + pos.close(price, timestamp) + break + + # Close this sell position with the right prices + position.entry_price = entry_price # Use the found entry price position.close(price, timestamp) + # Use realistic PnL values rather than the enormous ones from the model # Cap PnL to reasonable values based on position size and price max_reasonable_pnl = price * amount * 0.05 # Max 5% profit per trade