#!/usr/bin/env python
"""
Integrated RL Trading with Realtime Visualization
This script combines the RL training (NN/train_rl.py) with the realtime visualization (realtime.py)
to display the actions taken by the RL agent on the realtime chart.
"""
import os
import sys
import logging
import asyncio
import threading
import time
from datetime import datetime
import signal
import numpy as np
import torch
import json
from threading import Thread
import pandas as pd
import argparse
from scipy.signal import argrelextrema
# Parse command line arguments
parser = argparse.ArgumentParser(description='Integrated RL Trading with Realtime Visualization')
parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train')
parser.add_argument('--no-train', action='store_true', help='Skip training, just visualize')
parser.add_argument('--visualize-only', action='store_true', help='Only run the visualization')
parser.add_argument('--manual-trades', action='store_true', help='Enable manual trading mode')
parser.add_argument('--log-file', type=str, help='Specify custom log filename')
args = parser.parse_args()
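# Example invocations (illustrative; only the flags defined above are used):
#   python train_rl_with_realtime.py --episodes 50
#   python train_rl_with_realtime.py --no-train
#   python train_rl_with_realtime.py --visualize-only --log-file viz_only.log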
# Configure logging
log_filename = args.log_file or f'rl_realtime_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_filename),
logging.StreamHandler()
]
)
logger = logging.getLogger('rl_realtime')
# Add the project root to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
sys.path.append(project_root)
# Global variables for coordination
realtime_chart = None
realtime_websocket_task = None
running = True
def signal_handler(sig, frame):
"""Handle CTRL+C to gracefully exit training"""
global running
logger.info("Received interrupt signal. Finishing current epoch and shutting down...")
running = False
# Register signal handler
signal.signal(signal.SIGINT, signal_handler)
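# Note: the main() loop below polls `running` once per second and exits once the flag
# is cleared; the training thread is daemonized, so it winds down with the process.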
class ExtremaDetector:
"""
Detects local extrema (tops and bottoms) in price data
"""
def __init__(self, window_size=10, order=5):
"""
Args:
window_size (int): Size of the window to look for extrema
order (int): How many points on each side to use for comparison
"""
self.window_size = window_size
self.order = order
def find_extrema(self, prices):
"""
Find the local minima and maxima in the price series
Args:
prices (array-like): Array of price values
Returns:
tuple: (max_indices, min_indices) - arrays of indices where local maxima and minima occur
"""
# Convert to numpy array if needed
price_array = np.array(prices)
# Find local maxima (tops)
local_max_indices = argrelextrema(price_array, np.greater, order=self.order)[0]
# Find local minima (bottoms)
local_min_indices = argrelextrema(price_array, np.less, order=self.order)[0]
# Filter out extrema that are too close to the edges
max_indices = local_max_indices[local_max_indices >= self.order]
max_indices = max_indices[max_indices < len(price_array) - self.order]
min_indices = local_min_indices[local_min_indices >= self.order]
min_indices = min_indices[min_indices < len(price_array) - self.order]
return max_indices, min_indices
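# Minimal usage sketch for ExtremaDetector (illustrative values, not real market data):
#   detector = ExtremaDetector(window_size=10, order=2)
#   prices = [100, 102, 105, 103, 101, 99, 100, 103, 106, 104, 102]
#   tops, bottoms = detector.find_extrema(prices)
#   # With order=2 this flags indices [2, 8] as tops and [5] as a bottom.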
class RLTrainingIntegrator:
"""
Integrates RL training with realtime chart visualization.
Acts as a bridge between the RL training process and the realtime chart.
"""
def __init__(self, chart, symbol="ETH/USDT", model_save_path="NN/models/saved/dqn_agent"):
self.chart = chart
self.symbol = symbol
self.model_save_path = model_save_path
self.episode_count = 0
self.action_history = []
self.reward_history = []
self.trade_count = 0
self.win_count = 0
# Add session-wide PnL tracking
self.session_pnl = 0.0
self.session_trades = 0
self.session_wins = 0
self.session_balance = 100.0 # Start with $100 balance
# Track current position state
self.in_position = False
self.entry_price = None
self.entry_time = None
# Extrema detector
self.extrema_detector = ExtremaDetector(window_size=10, order=5)
        # Store the agent reference
        self.agent = None
        # Session step counter and stop flag (also reset by run_training_thread);
        # initialized here so start_training() can be called directly without crashing
        self.session_step = 0
        self.stop_event = threading.Event()
def start_training(self, num_episodes=5000, max_steps=2000):
"""Start the RL training process with visualization integration"""
from NN.train_rl import train_rl, RLTradingEnvironment
logger.info(f"Starting RL training with realtime visualization for {self.symbol}")
# Define callbacks for the training process
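        # Callback contract (as used by NN.train_rl.train_rl below): on_action is invoked
        # after every environment step with (step, action, price, reward, info) and must
        # return True to continue or False to end the episode early; on_episode is invoked
        # once per completed episode with (episode, reward, info).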
def on_action(step, action, price, reward, info):
"""Called after each action in the episode"""
# Log the action
action_str = "BUY" if action == 0 else "SELL" if action == 1 else "HOLD"
action_price = price
# Update session PnL and balance
self.session_step += 1
self.session_pnl += reward
# Increase balance based on reward
self.session_balance += reward
# Handle win/loss tracking
if reward != 0: # If this was a trade with P&L
self.session_trades += 1
if reward > 0:
self.session_wins += 1
# Only log a subset of actions to avoid excessive output
if step % 100 == 0 or step < 10 or self.session_step % 100 == 0:
logger.info(f"Step {step}, Action: {action_str}, Price: {action_price:.2f}, Reward: {reward:.4f}, PnL: {self.session_pnl:.4f}, Balance: ${self.session_balance:.2f}")
# Update the chart with the action - note positions are currently tracked in env
if action == 0: # BUY
# Only add to chart for visualization if we have a chart
if self.chart and hasattr(self.chart, "add_trade"):
# Adding a BUY trade
try:
self.chart.add_trade(
price=action_price,
timestamp=datetime.now(),
amount=0.1, # Standard amount
pnl=reward,
action="BUY"
)
self.chart.last_action = "BUY"
except Exception as e:
logger.error(f"Failed to add BUY trade to chart: {str(e)}")
elif action == 1: # SELL
# Only add to chart for visualization if we have a chart
if self.chart and hasattr(self.chart, "add_trade"):
# Adding a SELL trade
try:
self.chart.add_trade(
price=action_price,
timestamp=datetime.now(),
amount=0.1, # Standard amount
pnl=reward,
action="SELL"
)
self.chart.last_action = "SELL"
except Exception as e:
logger.error(f"Failed to add SELL trade to chart: {str(e)}")
# Update the trading info display on chart
if self.chart and hasattr(self.chart, "update_trading_info"):
try:
# Update the trading info panel with latest data
self.chart.update_trading_info(
signal=action_str,
position=0.1 if action == 0 else 0,
balance=self.session_balance,
pnl=self.session_pnl
)
except Exception as e:
logger.warning(f"Failed to update trading info: {str(e)}")
# Check for manual termination
if self.stop_event.is_set():
return False # Signal to stop episode
return True # Continue episode
def on_episode(episode, reward, info):
"""Callback for each completed episode"""
self.episode_count += 1
# Log episode results
logger.info(f"Episode {episode} completed")
logger.info(f" Total reward: {reward:.4f}")
logger.info(f" PnL: {info['gain']:.4f}")
logger.info(f" Win rate: {info['win_rate']:.4f}")
logger.info(f" Trades: {info['trades']}")
# Log session-wide PnL
session_win_rate = self.session_wins / self.session_trades if self.session_trades > 0 else 0
logger.info(f" Session Balance: ${self.session_balance:.2f}")
logger.info(f" Session Total PnL: {self.session_pnl:.4f}")
logger.info(f" Session Win Rate: {session_win_rate:.4f}")
logger.info(f" Session Trades: {self.session_trades}")
# Update chart trading info with final episode information
if self.chart and hasattr(self.chart, 'update_trading_info'):
# Reset position since we're between episodes
self.chart.update_trading_info(
signal="HOLD",
position=0.0,
balance=self.session_balance,
pnl=self.session_pnl
)
# Reset position state for new episode
self.in_position = False
self.entry_price = None
self.entry_time = None
# After each episode, perform additional training for local extrema
if hasattr(self.agent, 'policy_net') and hasattr(self.agent, 'replay') and episode > 0:
self._train_on_extrema(self.agent, info['env'])
# Start the actual training with our callbacks
self.agent = train_rl(
num_episodes=num_episodes,
max_steps=max_steps,
save_path=self.model_save_path,
action_callback=on_action,
episode_callback=on_episode,
symbol=self.symbol
)
logger.info("RL training completed")
return self.agent
def _train_on_extrema(self, agent, env):
"""
Perform additional training on local extrema (tops and bottoms)
to help the model learn these important patterns faster
Args:
agent: The DQN agent
env: The trading environment
"""
if not hasattr(env, 'features_1m') or len(env.features_1m) == 0:
logger.warning("Environment doesn't have price data for extrema detection")
return
try:
# Extract close prices
prices = env.features_1m[:, -1] # Assuming close price is the last column
# Find local extrema
max_indices, min_indices = self.extrema_detector.find_extrema(prices)
if len(max_indices) == 0 or len(min_indices) == 0:
logger.warning("No extrema found in the current price data")
return
logger.info(f"Found {len(max_indices)} tops and {len(min_indices)} bottoms for additional training")
# Calculate price changes at extrema to prioritize more significant ones
max_price_changes = []
for idx in max_indices:
if idx < 5 or idx >= len(prices) - 5:
continue
# Calculate percentage price rise from previous 5 candles to the peak
min_before = min(prices[idx-5:idx])
price_change = (prices[idx] - min_before) / min_before
max_price_changes.append((idx, price_change))
min_price_changes = []
for idx in min_indices:
if idx < 5 or idx >= len(prices) - 5:
continue
# Calculate percentage price drop from previous 5 candles to the bottom
max_before = max(prices[idx-5:idx])
price_change = (max_before - prices[idx]) / max_before
min_price_changes.append((idx, price_change))
# Sort extrema by significance (larger price change is more important)
max_price_changes.sort(key=lambda x: x[1], reverse=True)
min_price_changes.sort(key=lambda x: x[1], reverse=True)
# Take top 10 most significant extrema or all if fewer
max_indices = [idx for idx, _ in max_price_changes[:10]]
min_indices = [idx for idx, _ in min_price_changes[:10]]
# Log the significance of the extrema
if max_indices:
logger.info(f"Top extrema price changes: {[round(pc*100, 2) for _, pc in max_price_changes[:5]]}%")
if min_indices:
logger.info(f"Bottom extrema price changes: {[round(pc*100, 2) for _, pc in min_price_changes[:5]]}%")
# Collect states, actions, rewards for batch training
states = []
actions = []
rewards = []
next_states = []
dones = []
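            # Labeling scheme for the extra samples below: SELL (action 1) is rewarded when
            # approaching a detected top and BUY (action 0) when approaching a detected bottom,
            # with the reward rising from 1.0 to 2.0 as the state gets closer to the extremum;
            # a few counter-examples (BUY at a top, SELL at a bottom) receive a -1.5 penalty.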
# Process tops (local maxima - should sell)
for idx in max_indices:
if idx < env.window_size + 2 or idx >= len(prices) - 2:
continue
# Create states for multiple points approaching the top
# This helps the model learn to recognize the pattern leading to the top
for offset in range(1, 4): # Look at 1, 2, and 3 candles before the top
if idx - offset < env.window_size:
continue
# State before the peak
state_idx = idx - offset
env.current_step = state_idx
state = env._get_observation()
# The next state would be closer to the peak
env.current_step = state_idx + 1
next_state = env._get_observation()
# Reward increases as we get closer to the peak
# Stronger rewards for being right at the peak
reward = 1.0 if offset > 1 else 2.0
# Add to memory
action = 1 # Sell
agent.remember(state, action, reward, next_state, False, is_extrema=True)
# Add to batch
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(False)
# Process bottoms (local minima - should buy)
for idx in min_indices:
if idx < env.window_size + 2 or idx >= len(prices) - 2:
continue
# Create states for multiple points approaching the bottom
for offset in range(1, 4): # Look at 1, 2, and 3 candles before the bottom
if idx - offset < env.window_size:
continue
# State before the bottom
state_idx = idx - offset
env.current_step = state_idx
state = env._get_observation()
# The next state would be closer to the bottom
env.current_step = state_idx + 1
next_state = env._get_observation()
# Reward increases as we get closer to the bottom
reward = 1.0 if offset > 1 else 2.0
# Add to memory
action = 0 # Buy
agent.remember(state, action, reward, next_state, False, is_extrema=True)
# Add to batch
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(False)
# Add some negative examples - don't buy at tops, don't sell at bottoms
for idx in max_indices[:5]: # Use a few top peaks
if idx < env.window_size + 1 or idx >= len(prices) - 1:
continue
# State at the peak
env.current_step = idx
state = env._get_observation()
# Next state
env.current_step = idx + 1
next_state = env._get_observation()
# Strong negative reward for buying at a peak
reward = -1.5
# Add negative example of buying at a peak
action = 0 # Buy (wrong action)
agent.remember(state, action, reward, next_state, False, is_extrema=True)
# Add to batch
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(False)
for idx in min_indices[:5]: # Use a few bottom troughs
if idx < env.window_size + 1 or idx >= len(prices) - 1:
continue
# State at the bottom
env.current_step = idx
state = env._get_observation()
# Next state
env.current_step = idx + 1
next_state = env._get_observation()
# Strong negative reward for selling at a bottom
reward = -1.5
# Add negative example of selling at a bottom
action = 1 # Sell (wrong action)
agent.remember(state, action, reward, next_state, False, is_extrema=True)
# Add to batch
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(False)
# Train on the collected extrema samples
if len(states) > 0:
logger.info(f"Performing additional training on {len(states)} extrema patterns")
loss = agent.train_on_extrema(states, actions, rewards, next_states, dones)
logger.info(f"Extrema training loss: {loss:.4f}")
# Additional replay passes with extrema samples included
for _ in range(5):
loss = agent.replay(use_extrema=True)
logger.info(f"Mixed replay with extrema - loss: {loss:.4f}")
except Exception as e:
logger.error(f"Error during extrema training: {str(e)}")
import traceback
logger.error(traceback.format_exc())
async def start_realtime_chart(symbol="BTC/USDT", port=8050):
"""Start the realtime trading chart in a separate thread
Returns:
tuple: (RealTimeChart instance, websocket task)
"""
from realtime import RealTimeChart
try:
logger.info(f"Initializing RealTimeChart for {symbol}")
# Create the chart with the new parameter interface
chart = RealTimeChart(symbol, data_path=None, historical_data=None)
# Give the server a moment to start (the app is started automatically in __init__ now)
await asyncio.sleep(2)
logger.info(f"Started realtime chart for {symbol} on port {port}")
logger.info(f"You can view the chart at http://localhost:{port}/")
# Return the chart and a dummy websocket task (the real one is running in a thread)
return chart, asyncio.create_task(asyncio.sleep(0))
except Exception as e:
logger.error(f"Error starting realtime chart: {str(e)}")
import traceback
logger.error(traceback.format_exc())
raise
def run_training_thread(chart, num_episodes=5000, skip_training=False):
"""Start the RL training in a separate thread"""
integrator = RLTrainingIntegrator(chart)
def training_thread_func():
try:
# Create stop event
integrator.stop_event = threading.Event()
# Initialize session tracking
integrator.session_step = 0
if skip_training:
logger.info("Skipping training as requested (--no-train flag)")
# Just sleep for a bit to keep the thread alive
time.sleep(10)
else:
                # Run training for the number of episodes requested on the command line
integrator.start_training(num_episodes=num_episodes, max_steps=2000)
except Exception as e:
logger.error(f"Error in training thread: {str(e)}")
import traceback
logger.error(traceback.format_exc())
thread = threading.Thread(target=training_thread_func)
thread.daemon = True
thread.start()
logger.info("Started RL training thread")
return thread, integrator
def test_signals(chart):
"""Add test signals and trades to the chart to verify functionality"""
from datetime import datetime
logger.info("Adding test trades to chart")
# Add test trades
if hasattr(chart, 'add_trade'):
# Add a BUY trade
chart.add_trade(
price=83000.0,
timestamp=datetime.now(),
amount=0.1,
pnl=0.05,
action="BUY"
)
# Wait briefly
time.sleep(1)
# Add a SELL trade
chart.add_trade(
price=83050.0,
timestamp=datetime.now(),
amount=0.1,
pnl=0.2,
action="SELL"
)
logger.info("Test trades added successfully")
else:
logger.warning("RealTimeChart has no add_trade method - skipping test trades")
async def main():
"""Main function that coordinates the realtime chart and RL training"""
global realtime_chart, realtime_websocket_task, running
logger.info("Starting integrated RL training with realtime visualization")
logger.info(f"Using log file: {log_filename}")
# Start the realtime chart
realtime_chart, realtime_websocket_task = await start_realtime_chart()
# Wait a bit for the chart to initialize
await asyncio.sleep(5)
# Test signals first
test_signals(realtime_chart)
# If visualize-only is set, don't start the training thread
    if not args.visualize_only:
# Start the training in a separate thread
num_episodes = args.episodes if not args.no_train else 1
training_thread, integrator = run_training_thread(realtime_chart, num_episodes=num_episodes,
skip_training=args.no_train)
else:
# Create a dummy integrator for the final stats
integrator = RLTrainingIntegrator(realtime_chart)
integrator.session_pnl = 0.0
integrator.session_trades = 0
integrator.session_wins = 0
integrator.session_balance = 100.0
training_thread = None
try:
# Keep the main task running until interrupted
while running and (training_thread is None or training_thread.is_alive()):
await asyncio.sleep(1)
except KeyboardInterrupt:
logger.info("Shutting down...")
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
finally:
# Log final PnL summary
if hasattr(integrator, 'session_pnl'):
session_win_rate = integrator.session_wins / integrator.session_trades if integrator.session_trades > 0 else 0
logger.info("=" * 50)
logger.info("FINAL SESSION SUMMARY")
logger.info("=" * 50)
logger.info(f"Final Session Balance: ${integrator.session_balance:.2f}")
logger.info(f"Total Session PnL: {integrator.session_pnl:.4f}")
logger.info(f"Total Session Win Rate: {session_win_rate:.4f} ({integrator.session_wins}/{integrator.session_trades})")
logger.info(f"Total Session Trades: {integrator.session_trades}")
logger.info("=" * 50)
# Clean up
if realtime_websocket_task:
realtime_websocket_task.cancel()
try:
await realtime_websocket_task
except asyncio.CancelledError:
pass
logger.info("Application terminated")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Application terminated by user")