#!/usr/bin/env python """ Integrated RL Trading with Realtime Visualization This script combines the RL training (NN/train_rl.py) with the realtime visualization (realtime.py) to display the actions taken by the RL agent on the realtime chart. """ import os import sys import logging import asyncio import threading import time from datetime import datetime import signal import numpy as np import torch import json from threading import Thread import pandas as pd from scipy.signal import argrelextrema # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(f'rl_realtime_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'), logging.StreamHandler() ] ) logger = logging.getLogger('rl_realtime') # Add the project root to path if needed project_root = os.path.dirname(os.path.abspath(__file__)) if project_root not in sys.path: sys.path.append(project_root) # Global variables for coordination realtime_chart = None realtime_websocket_task = None running = True def signal_handler(sig, frame): """Handle CTRL+C to gracefully exit training""" global running logger.info("Received interrupt signal. Finishing current epoch and shutting down...") running = False # Register signal handler signal.signal(signal.SIGINT, signal_handler) class ExtremaDetector: """ Detects local extrema (tops and bottoms) in price data """ def __init__(self, window_size=10, order=5): """ Args: window_size (int): Size of the window to look for extrema order (int): How many points on each side to use for comparison """ self.window_size = window_size self.order = order def find_extrema(self, prices): """ Find the local minima and maxima in the price series Args: prices (array-like): Array of price values Returns: tuple: (max_indices, min_indices) - arrays of indices where local maxima and minima occur """ # Convert to numpy array if needed price_array = np.array(prices) # Find local maxima (tops) local_max_indices = argrelextrema(price_array, np.greater, order=self.order)[0] # Find local minima (bottoms) local_min_indices = argrelextrema(price_array, np.less, order=self.order)[0] # Filter out extrema that are too close to the edges max_indices = local_max_indices[local_max_indices >= self.order] max_indices = max_indices[max_indices < len(price_array) - self.order] min_indices = local_min_indices[local_min_indices >= self.order] min_indices = min_indices[min_indices < len(price_array) - self.order] return max_indices, min_indices class RLTrainingIntegrator: """ Integrates RL training with realtime chart visualization. Acts as a bridge between the RL training process and the realtime chart. """ def __init__(self, chart, symbol="ETH/USDT", model_save_path="NN/models/saved/dqn_agent"): self.chart = chart self.symbol = symbol self.model_save_path = model_save_path self.episode_count = 0 self.action_history = [] self.reward_history = [] self.trade_count = 0 self.win_count = 0 # Add session-wide PnL tracking self.session_pnl = 0.0 self.session_trades = 0 self.session_wins = 0 self.session_balance = 100.0 # Start with $100 balance # Track current position state self.in_position = False self.entry_price = None self.entry_time = None # Extrema detector self.extrema_detector = ExtremaDetector(window_size=10, order=5) # Store the agent reference self.agent = None def start_training(self, num_episodes=5000, max_steps=2000): """Start the RL training process with visualization integration""" from NN.train_rl import train_rl, RLTradingEnvironment logger.info(f"Starting RL training with realtime visualization for {self.symbol}") # Define callbacks for the training process def on_action(step, action, price, reward, info): """Callback for each action taken by the agent""" # Only visualize non-hold actions if action != 2: # 0=Buy, 1=Sell, 2=Hold # Convert to string action action_str = "BUY" if action == 0 else "SELL" # Get timestamp - we'll use current time as a proxy timestamp = datetime.now() # Track position state if action == 0 and not self.in_position: # Buy and not already in position self.in_position = True self.entry_price = price self.entry_time = timestamp # Send to chart - visualize buy signal if self.chart and hasattr(self.chart, 'add_nn_signal'): self.chart.add_nn_signal(action_str, timestamp, probability=abs(reward)) elif action == 1 and self.in_position: # Sell and in position (complete trade) self.in_position = False # Calculate profit if we have entry data pnl = None if self.entry_price is not None: # Calculate percentage change pnl_pct = (price - self.entry_price) / self.entry_price # Cap extreme PnL values to more realistic levels (-90% to +100%) pnl_pct = max(min(pnl_pct, 1.0), -0.9) # Apply to current balance trade_amount = self.session_balance * 0.1 # Use 10% of balance per trade trade_profit = trade_amount * pnl_pct self.session_balance += trade_profit # Ensure session balance doesn't go below $1 self.session_balance = max(self.session_balance, 1.0) # For normalized display in charts and logs pnl = pnl_pct # Update session-wide PnL self.session_pnl += pnl self.session_trades += 1 if pnl > 0: self.session_wins += 1 # Log the complete trade on the chart if self.chart: # Show sell signal if hasattr(self.chart, 'add_nn_signal'): self.chart.add_nn_signal(action_str, timestamp, probability=abs(reward)) # Record the trade with PnL if hasattr(self.chart, 'add_trade'): self.chart.add_trade( price=price, timestamp=timestamp, pnl=pnl, amount=0.1, action=action_str, type=action_str # Add explicit type ) # Update trade counts self.trade_count += 1 if pnl is not None and pnl > 0: self.win_count += 1 # Reset entry data self.entry_price = None self.entry_time = None # Track all actions self.action_history.append({ 'step': step, 'action': action_str, 'price': price, 'reward': reward, 'timestamp': timestamp.isoformat() }) else: # Hold action action_str = "HOLD" timestamp = datetime.now() # Update chart trading info if self.chart and hasattr(self.chart, 'update_trading_info'): # Determine current position size (0.1 if in position, 0 if not) position_size = 0.1 if self.in_position else 0.0 self.chart.update_trading_info( signal=action_str, position=position_size, balance=self.session_balance, pnl=self.session_pnl ) # Track reward for all actions (including hold) self.reward_history.append(reward) # Log periodically if len(self.reward_history) % 100 == 0: avg_reward = sum(self.reward_history[-100:]) / 100 logger.info(f"Step {step}: Avg reward (last 100): {avg_reward:.4f}, Actions: {len(self.action_history)}, Trades: {self.trade_count}") def on_episode(episode, reward, info): """Callback for each completed episode""" self.episode_count += 1 # Log episode results logger.info(f"Episode {episode} completed") logger.info(f" Total reward: {reward:.4f}") logger.info(f" PnL: {info['gain']:.4f}") logger.info(f" Win rate: {info['win_rate']:.4f}") logger.info(f" Trades: {info['trades']}") # Log session-wide PnL session_win_rate = self.session_wins / self.session_trades if self.session_trades > 0 else 0 logger.info(f" Session Balance: ${self.session_balance:.2f}") logger.info(f" Session Total PnL: {self.session_pnl:.4f}") logger.info(f" Session Win Rate: {session_win_rate:.4f}") logger.info(f" Session Trades: {self.session_trades}") # Update chart trading info with final episode information if self.chart and hasattr(self.chart, 'update_trading_info'): # Reset position since we're between episodes self.chart.update_trading_info( signal="HOLD", position=0.0, balance=self.session_balance, pnl=self.session_pnl ) # Reset position state for new episode self.in_position = False self.entry_price = None self.entry_time = None # After each episode, perform additional training for local extrema if hasattr(self.agent, 'policy_net') and hasattr(self.agent, 'replay') and episode > 0: self._train_on_extrema(self.agent, info['env']) # Start the actual training with our callbacks self.agent = train_rl( num_episodes=num_episodes, max_steps=max_steps, save_path=self.model_save_path, action_callback=on_action, episode_callback=on_episode, symbol=self.symbol ) logger.info("RL training completed") return self.agent def _train_on_extrema(self, agent, env): """ Perform additional training on local extrema (tops and bottoms) to help the model learn these important patterns faster Args: agent: The DQN agent env: The trading environment """ if not hasattr(env, 'features_1m') or len(env.features_1m) == 0: logger.warning("Environment doesn't have price data for extrema detection") return try: # Extract close prices prices = env.features_1m[:, -1] # Assuming close price is the last column # Find local extrema max_indices, min_indices = self.extrema_detector.find_extrema(prices) if len(max_indices) == 0 or len(min_indices) == 0: logger.warning("No extrema found in the current price data") return logger.info(f"Found {len(max_indices)} tops and {len(min_indices)} bottoms for additional training") # Calculate price changes at extrema to prioritize more significant ones max_price_changes = [] for idx in max_indices: if idx < 5 or idx >= len(prices) - 5: continue # Calculate percentage price rise from previous 5 candles to the peak min_before = min(prices[idx-5:idx]) price_change = (prices[idx] - min_before) / min_before max_price_changes.append((idx, price_change)) min_price_changes = [] for idx in min_indices: if idx < 5 or idx >= len(prices) - 5: continue # Calculate percentage price drop from previous 5 candles to the bottom max_before = max(prices[idx-5:idx]) price_change = (max_before - prices[idx]) / max_before min_price_changes.append((idx, price_change)) # Sort extrema by significance (larger price change is more important) max_price_changes.sort(key=lambda x: x[1], reverse=True) min_price_changes.sort(key=lambda x: x[1], reverse=True) # Take top 10 most significant extrema or all if fewer max_indices = [idx for idx, _ in max_price_changes[:10]] min_indices = [idx for idx, _ in min_price_changes[:10]] # Log the significance of the extrema if max_indices: logger.info(f"Top extrema price changes: {[round(pc*100, 2) for _, pc in max_price_changes[:5]]}%") if min_indices: logger.info(f"Bottom extrema price changes: {[round(pc*100, 2) for _, pc in min_price_changes[:5]]}%") # Collect states, actions, rewards for batch training states = [] actions = [] rewards = [] next_states = [] dones = [] # Process tops (local maxima - should sell) for idx in max_indices: if idx < env.window_size + 2 or idx >= len(prices) - 2: continue # Create states for multiple points approaching the top # This helps the model learn to recognize the pattern leading to the top for offset in range(1, 4): # Look at 1, 2, and 3 candles before the top if idx - offset < env.window_size: continue # State before the peak state_idx = idx - offset env.current_step = state_idx state = env._get_observation() # The next state would be closer to the peak env.current_step = state_idx + 1 next_state = env._get_observation() # Reward increases as we get closer to the peak # Stronger rewards for being right at the peak reward = 1.0 if offset > 1 else 2.0 # Add to memory action = 1 # Sell agent.remember(state, action, reward, next_state, False, is_extrema=True) # Add to batch states.append(state) actions.append(action) rewards.append(reward) next_states.append(next_state) dones.append(False) # Process bottoms (local minima - should buy) for idx in min_indices: if idx < env.window_size + 2 or idx >= len(prices) - 2: continue # Create states for multiple points approaching the bottom for offset in range(1, 4): # Look at 1, 2, and 3 candles before the bottom if idx - offset < env.window_size: continue # State before the bottom state_idx = idx - offset env.current_step = state_idx state = env._get_observation() # The next state would be closer to the bottom env.current_step = state_idx + 1 next_state = env._get_observation() # Reward increases as we get closer to the bottom reward = 1.0 if offset > 1 else 2.0 # Add to memory action = 0 # Buy agent.remember(state, action, reward, next_state, False, is_extrema=True) # Add to batch states.append(state) actions.append(action) rewards.append(reward) next_states.append(next_state) dones.append(False) # Add some negative examples - don't buy at tops, don't sell at bottoms for idx in max_indices[:5]: # Use a few top peaks if idx < env.window_size + 1 or idx >= len(prices) - 1: continue # State at the peak env.current_step = idx state = env._get_observation() # Next state env.current_step = idx + 1 next_state = env._get_observation() # Strong negative reward for buying at a peak reward = -1.5 # Add negative example of buying at a peak action = 0 # Buy (wrong action) agent.remember(state, action, reward, next_state, False, is_extrema=True) # Add to batch states.append(state) actions.append(action) rewards.append(reward) next_states.append(next_state) dones.append(False) for idx in min_indices[:5]: # Use a few bottom troughs if idx < env.window_size + 1 or idx >= len(prices) - 1: continue # State at the bottom env.current_step = idx state = env._get_observation() # Next state env.current_step = idx + 1 next_state = env._get_observation() # Strong negative reward for selling at a bottom reward = -1.5 # Add negative example of selling at a bottom action = 1 # Sell (wrong action) agent.remember(state, action, reward, next_state, False, is_extrema=True) # Add to batch states.append(state) actions.append(action) rewards.append(reward) next_states.append(next_state) dones.append(False) # Train on the collected extrema samples if len(states) > 0: logger.info(f"Performing additional training on {len(states)} extrema patterns") loss = agent.train_on_extrema(states, actions, rewards, next_states, dones) logger.info(f"Extrema training loss: {loss:.4f}") # Additional replay passes with extrema samples included for _ in range(5): loss = agent.replay(use_extrema=True) logger.info(f"Mixed replay with extrema - loss: {loss:.4f}") except Exception as e: logger.error(f"Error during extrema training: {str(e)}") import traceback logger.error(traceback.format_exc()) async def start_realtime_chart(symbol="BTC/USDT", port=8050): """ Start the realtime chart display in a separate thread Returns: tuple: (chart, websocket_task) """ from realtime import RealTimeChart try: logger.info(f"Initializing RealTimeChart for {symbol}") # Create the chart with sample data enabled and no-ticks warnings disabled chart = RealTimeChart(symbol, use_sample_data=True, log_no_ticks_warning=False) # Start the WebSocket connection in a separate thread # The _start_websocket_thread method already handles this correctly # Run the Dash server in a separate thread thread = Thread(target=lambda c=chart, p=port: c.run(host='localhost', port=p)) thread.daemon = True thread.start() # Give the server a moment to start await asyncio.sleep(2) logger.info(f"Started realtime chart for {symbol} on port {port}") logger.info(f"You can view the chart at http://localhost:{port}/") # Return the chart and a dummy websocket task (the real one is running in a thread) return chart, asyncio.create_task(asyncio.sleep(0)) except Exception as e: logger.error(f"Error starting realtime chart: {str(e)}") import traceback logger.error(traceback.format_exc()) raise def run_training_thread(chart): """Start the RL training in a separate thread""" integrator = RLTrainingIntegrator(chart) def training_thread_func(): try: # Use a small number of episodes to test termination handling integrator.start_training(num_episodes=100, max_steps=2000) except Exception as e: logger.error(f"Error in training thread: {str(e)}") thread = threading.Thread(target=training_thread_func) thread.daemon = True thread.start() logger.info("Started RL training thread") return thread, integrator def test_signals(chart): """Add test signals to the chart to verify functionality""" from datetime import datetime logger.info("Adding test signals to chart") # Add a test BUY signal chart.add_nn_signal("BUY", datetime.now(), 0.95) # Sleep briefly time.sleep(1) # Add a test SELL signal chart.add_nn_signal("SELL", datetime.now(), 0.85) # Add a test trade if the method exists if hasattr(chart, 'add_trade'): chart.add_trade( price=83000.0, timestamp=datetime.now(), pnl=0.05, action="BUY", type="BUY" # Add explicit type ) else: logger.warning("RealTimeChart has no add_trade method - skipping test trade") async def main(): """Main function that coordinates the realtime chart and RL training""" global realtime_chart, realtime_websocket_task, running logger.info("Starting integrated RL training with realtime visualization") # Start the realtime chart realtime_chart, realtime_websocket_task = await start_realtime_chart() # Wait a bit for the chart to initialize await asyncio.sleep(5) # Test signals first test_signals(realtime_chart) # Start the training in a separate thread training_thread, integrator = run_training_thread(realtime_chart) try: # Keep the main task running until interrupted while running and training_thread.is_alive(): await asyncio.sleep(1) except KeyboardInterrupt: logger.info("Shutting down...") except Exception as e: logger.error(f"Unexpected error: {str(e)}") finally: # Log final PnL summary if hasattr(integrator, 'session_pnl'): session_win_rate = integrator.session_wins / integrator.session_trades if integrator.session_trades > 0 else 0 logger.info("=" * 50) logger.info("FINAL SESSION SUMMARY") logger.info("=" * 50) logger.info(f"Final Session Balance: ${integrator.session_balance:.2f}") logger.info(f"Total Session PnL: {integrator.session_pnl:.4f}") logger.info(f"Total Session Win Rate: {session_win_rate:.4f} ({integrator.session_wins}/{integrator.session_trades})") logger.info(f"Total Session Trades: {integrator.session_trades}") logger.info("=" * 50) # Clean up if realtime_websocket_task: realtime_websocket_task.cancel() try: await realtime_websocket_task except asyncio.CancelledError: pass logger.info("Application terminated") if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: logger.info("Application terminated by user")