#!/usr/bin/env python
"""
Integrated RL Trading with Realtime Visualization
This script combines the RL training (NN/train_rl.py) with the realtime visualization (realtime.py)
to display the actions taken by the RL agent on the realtime chart.
"""
import os
import sys
import logging
import asyncio
import threading
import time
from datetime import datetime
import signal
import numpy as np
import torch
import json
from threading import Thread
import pandas as pd
from scipy.signal import argrelextrema
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(f'rl_realtime_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'),
logging.StreamHandler()
]
)
logger = logging.getLogger('rl_realtime')
# Add the project root to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
sys.path.append(project_root)
# Global variables for coordination
realtime_chart = None
realtime_websocket_task = None
running = True
def signal_handler(sig, frame):
"""Handle CTRL+C to gracefully exit training"""
global running
logger.info("Received interrupt signal. Finishing current epoch and shutting down...")
running = False
# Register signal handler
signal.signal(signal.SIGINT, signal_handler)
class ExtremaDetector:
"""
Detects local extrema (tops and bottoms) in price data
"""
def __init__(self, window_size=10, order=5):
"""
Args:
            window_size (int): Size of the window to look for extrema
                (stored for configuration; find_extrema itself relies only on ``order``)
            order (int): How many points on each side to use for comparison
"""
self.window_size = window_size
self.order = order
def find_extrema(self, prices):
"""
Find the local minima and maxima in the price series
Args:
prices (array-like): Array of price values
Returns:
tuple: (max_indices, min_indices) - arrays of indices where local maxima and minima occur
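        Example (illustrative sketch on a synthetic triangle-shaped series;
        the returned indices depend entirely on the input data):
            >>> import numpy as np
            >>> prices = np.concatenate([np.arange(0, 11),
            ...                          np.arange(9, -1, -1),
            ...                          np.arange(1, 11)])
            >>> ExtremaDetector(order=5).find_extrema(prices)
            (array([10]), array([20]))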
"""
# Convert to numpy array if needed
price_array = np.array(prices)
# Find local maxima (tops)
local_max_indices = argrelextrema(price_array, np.greater, order=self.order)[0]
# Find local minima (bottoms)
local_min_indices = argrelextrema(price_array, np.less, order=self.order)[0]
# Filter out extrema that are too close to the edges
max_indices = local_max_indices[local_max_indices >= self.order]
max_indices = max_indices[max_indices < len(price_array) - self.order]
min_indices = local_min_indices[local_min_indices >= self.order]
min_indices = min_indices[min_indices < len(price_array) - self.order]
return max_indices, min_indices
class RLTrainingIntegrator:
"""
Integrates RL training with realtime chart visualization.
Acts as a bridge between the RL training process and the realtime chart.
"""
def __init__(self, chart, symbol="ETH/USDT", model_save_path="NN/models/saved/dqn_agent"):
self.chart = chart
self.symbol = symbol
self.model_save_path = model_save_path
self.episode_count = 0
self.action_history = []
self.reward_history = []
self.trade_count = 0
self.win_count = 0
# Add session-wide PnL tracking
self.session_pnl = 0.0
self.session_trades = 0
self.session_wins = 0
self.session_balance = 100.0 # Start with $100 balance
# Track current position state
self.in_position = False
self.entry_price = None
self.entry_time = None
# Extrema detector
self.extrema_detector = ExtremaDetector(window_size=10, order=5)
# Store the agent reference
self.agent = None
def start_training(self, num_episodes=5000, max_steps=2000):
"""Start the RL training process with visualization integration"""
from NN.train_rl import train_rl, RLTradingEnvironment
logger.info(f"Starting RL training with realtime visualization for {self.symbol}")
# Define callbacks for the training process
def on_action(step, action, price, reward, info):
"""Callback for each action taken by the agent"""
# Only visualize non-hold actions
if action != 2: # 0=Buy, 1=Sell, 2=Hold
# Convert to string action
action_str = "BUY" if action == 0 else "SELL"
# Get timestamp - we'll use current time as a proxy
timestamp = datetime.now()
# Track position state
if action == 0 and not self.in_position: # Buy and not already in position
self.in_position = True
self.entry_price = price
self.entry_time = timestamp
# Send to chart - visualize buy signal
if self.chart and hasattr(self.chart, 'add_nn_signal'):
self.chart.add_nn_signal(action_str, timestamp, probability=abs(reward))
elif action == 1 and self.in_position: # Sell and in position (complete trade)
self.in_position = False
# Calculate profit if we have entry data
pnl = None
if self.entry_price is not None:
# Calculate percentage change
pnl_pct = (price - self.entry_price) / self.entry_price
# Cap extreme PnL values to more realistic levels (-90% to +100%)
pnl_pct = max(min(pnl_pct, 1.0), -0.9)
# Apply to current balance
trade_amount = self.session_balance * 0.1 # Use 10% of balance per trade
trade_profit = trade_amount * pnl_pct
self.session_balance += trade_profit
# Ensure session balance doesn't go below $1
self.session_balance = max(self.session_balance, 1.0)
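                        # Worked example (hypothetical prices): entry at 2000, exit at
                        # 2100 -> pnl_pct = 0.05; with a $100 balance, trade_amount = $10
                        # and trade_profit = $0.50, leaving session_balance at $100.50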
# For normalized display in charts and logs
pnl = pnl_pct
# Update session-wide PnL
self.session_pnl += pnl
self.session_trades += 1
if pnl > 0:
self.session_wins += 1
# Log the complete trade on the chart
if self.chart:
# Show sell signal
if hasattr(self.chart, 'add_nn_signal'):
self.chart.add_nn_signal(action_str, timestamp, probability=abs(reward))
# Record the trade with PnL
if hasattr(self.chart, 'add_trade'):
self.chart.add_trade(
price=price,
timestamp=timestamp,
pnl=pnl,
amount=0.1,
action=action_str,
type=action_str # Add explicit type
)
# Update trade counts
self.trade_count += 1
if pnl is not None and pnl > 0:
self.win_count += 1
# Reset entry data
self.entry_price = None
self.entry_time = None
# Track all actions
self.action_history.append({
'step': step,
'action': action_str,
'price': price,
'reward': reward,
'timestamp': timestamp.isoformat()
})
else:
# Hold action
action_str = "HOLD"
timestamp = datetime.now()
# Update chart trading info
if self.chart and hasattr(self.chart, 'update_trading_info'):
# Determine current position size (0.1 if in position, 0 if not)
position_size = 0.1 if self.in_position else 0.0
self.chart.update_trading_info(
signal=action_str,
position=position_size,
balance=self.session_balance,
pnl=self.session_pnl
)
# Track reward for all actions (including hold)
self.reward_history.append(reward)
# Log periodically
if len(self.reward_history) % 100 == 0:
avg_reward = sum(self.reward_history[-100:]) / 100
logger.info(f"Step {step}: Avg reward (last 100): {avg_reward:.4f}, Actions: {len(self.action_history)}, Trades: {self.trade_count}")
def on_episode(episode, reward, info):
"""Callback for each completed episode"""
self.episode_count += 1
# Log episode results
logger.info(f"Episode {episode} completed")
logger.info(f" Total reward: {reward:.4f}")
logger.info(f" PnL: {info['gain']:.4f}")
logger.info(f" Win rate: {info['win_rate']:.4f}")
logger.info(f" Trades: {info['trades']}")
# Log session-wide PnL
session_win_rate = self.session_wins / self.session_trades if self.session_trades > 0 else 0
logger.info(f" Session Balance: ${self.session_balance:.2f}")
logger.info(f" Session Total PnL: {self.session_pnl:.4f}")
logger.info(f" Session Win Rate: {session_win_rate:.4f}")
logger.info(f" Session Trades: {self.session_trades}")
# Update chart trading info with final episode information
if self.chart and hasattr(self.chart, 'update_trading_info'):
# Reset position since we're between episodes
self.chart.update_trading_info(
signal="HOLD",
position=0.0,
balance=self.session_balance,
pnl=self.session_pnl
)
# Reset position state for new episode
self.in_position = False
self.entry_price = None
self.entry_time = None
            # After each episode, perform additional training on local extrema.
            # Note: self.agent is only assigned once train_rl() returns, so this
            # extra pass is skipped unless the agent reference is already set.
            if self.agent is not None and hasattr(self.agent, 'policy_net') and hasattr(self.agent, 'replay') and episode > 0:
self._train_on_extrema(self.agent, info['env'])
# Start the actual training with our callbacks
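        # Assumed callback contract for NN.train_rl.train_rl (see that module for
        # the authoritative signature): on_action(step, action, price, reward, info)
        # is invoked per step with action 0=Buy, 1=Sell, 2=Hold, and
        # on_episode(episode, reward, info) per episode with info carrying the
        # 'gain', 'win_rate', 'trades' and 'env' keys used above.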
self.agent = train_rl(
num_episodes=num_episodes,
max_steps=max_steps,
save_path=self.model_save_path,
action_callback=on_action,
episode_callback=on_episode,
symbol=self.symbol
)
logger.info("RL training completed")
return self.agent
def _train_on_extrema(self, agent, env):
"""
Perform additional training on local extrema (tops and bottoms)
to help the model learn these important patterns faster
Args:
agent: The DQN agent
env: The trading environment
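        Note: this assumes (based on how the agent is used below, not on a
        verified API) that the agent exposes remember(..., is_extrema=...),
        train_on_extrema(states, actions, rewards, next_states, dones) -> loss,
        and replay(use_extrema=...) -> loss.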
"""
if not hasattr(env, 'features_1m') or len(env.features_1m) == 0:
logger.warning("Environment doesn't have price data for extrema detection")
return
try:
# Extract close prices
prices = env.features_1m[:, -1] # Assuming close price is the last column
# Find local extrema
max_indices, min_indices = self.extrema_detector.find_extrema(prices)
if len(max_indices) == 0 or len(min_indices) == 0:
logger.warning("No extrema found in the current price data")
return
logger.info(f"Found {len(max_indices)} tops and {len(min_indices)} bottoms for additional training")
# Calculate price changes at extrema to prioritize more significant ones
max_price_changes = []
for idx in max_indices:
if idx < 5 or idx >= len(prices) - 5:
continue
# Calculate percentage price rise from previous 5 candles to the peak
min_before = min(prices[idx-5:idx])
price_change = (prices[idx] - min_before) / min_before
max_price_changes.append((idx, price_change))
min_price_changes = []
for idx in min_indices:
if idx < 5 or idx >= len(prices) - 5:
continue
# Calculate percentage price drop from previous 5 candles to the bottom
max_before = max(prices[idx-5:idx])
price_change = (max_before - prices[idx]) / max_before
min_price_changes.append((idx, price_change))
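            # Worked example (hypothetical numbers): for a top where
            # prices[idx-5:idx] = [100, 99, 98, 101, 103] and prices[idx] = 105,
            # min_before = 98 and price_change = (105 - 98) / 98 ≈ 0.0714, i.e. a
            # ~7.1% rise into the peak; larger changes rank as more significant below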
# Sort extrema by significance (larger price change is more important)
max_price_changes.sort(key=lambda x: x[1], reverse=True)
min_price_changes.sort(key=lambda x: x[1], reverse=True)
# Take top 10 most significant extrema or all if fewer
max_indices = [idx for idx, _ in max_price_changes[:10]]
min_indices = [idx for idx, _ in min_price_changes[:10]]
# Log the significance of the extrema
if max_indices:
logger.info(f"Top extrema price changes: {[round(pc*100, 2) for _, pc in max_price_changes[:5]]}%")
if min_indices:
logger.info(f"Bottom extrema price changes: {[round(pc*100, 2) for _, pc in min_price_changes[:5]]}%")
# Collect states, actions, rewards for batch training
states = []
actions = []
rewards = []
next_states = []
dones = []
# Process tops (local maxima - should sell)
for idx in max_indices:
if idx < env.window_size + 2 or idx >= len(prices) - 2:
continue
# Create states for multiple points approaching the top
# This helps the model learn to recognize the pattern leading to the top
for offset in range(1, 4): # Look at 1, 2, and 3 candles before the top
if idx - offset < env.window_size:
continue
# State before the peak
state_idx = idx - offset
env.current_step = state_idx
state = env._get_observation()
# The next state would be closer to the peak
env.current_step = state_idx + 1
next_state = env._get_observation()
                    # Reward selling as price approaches the top: the candle
                    # immediately before the peak (offset 1) gets 2.0, earlier
                    # candles (offsets 2-3) get 1.0
                    reward = 1.0 if offset > 1 else 2.0
# Add to memory
action = 1 # Sell
agent.remember(state, action, reward, next_state, False, is_extrema=True)
# Add to batch
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(False)
# Process bottoms (local minima - should buy)
for idx in min_indices:
if idx < env.window_size + 2 or idx >= len(prices) - 2:
continue
# Create states for multiple points approaching the bottom
for offset in range(1, 4): # Look at 1, 2, and 3 candles before the bottom
if idx - offset < env.window_size:
continue
# State before the bottom
state_idx = idx - offset
env.current_step = state_idx
state = env._get_observation()
# The next state would be closer to the bottom
env.current_step = state_idx + 1
next_state = env._get_observation()
                    # Reward buying as price approaches the bottom: the candle
                    # immediately before the trough (offset 1) gets 2.0, earlier
                    # candles (offsets 2-3) get 1.0
                    reward = 1.0 if offset > 1 else 2.0
# Add to memory
action = 0 # Buy
agent.remember(state, action, reward, next_state, False, is_extrema=True)
# Add to batch
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(False)
# Add some negative examples - don't buy at tops, don't sell at bottoms
for idx in max_indices[:5]: # Use a few top peaks
if idx < env.window_size + 1 or idx >= len(prices) - 1:
continue
# State at the peak
env.current_step = idx
state = env._get_observation()
# Next state
env.current_step = idx + 1
next_state = env._get_observation()
# Strong negative reward for buying at a peak
reward = -1.5
# Add negative example of buying at a peak
action = 0 # Buy (wrong action)
agent.remember(state, action, reward, next_state, False, is_extrema=True)
# Add to batch
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(False)
for idx in min_indices[:5]: # Use a few bottom troughs
if idx < env.window_size + 1 or idx >= len(prices) - 1:
continue
# State at the bottom
env.current_step = idx
state = env._get_observation()
# Next state
env.current_step = idx + 1
next_state = env._get_observation()
# Strong negative reward for selling at a bottom
reward = -1.5
# Add negative example of selling at a bottom
action = 1 # Sell (wrong action)
agent.remember(state, action, reward, next_state, False, is_extrema=True)
# Add to batch
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(False)
# Train on the collected extrema samples
if len(states) > 0:
logger.info(f"Performing additional training on {len(states)} extrema patterns")
loss = agent.train_on_extrema(states, actions, rewards, next_states, dones)
logger.info(f"Extrema training loss: {loss:.4f}")
# Additional replay passes with extrema samples included
for _ in range(5):
loss = agent.replay(use_extrema=True)
logger.info(f"Mixed replay with extrema - loss: {loss:.4f}")
except Exception as e:
logger.error(f"Error during extrema training: {str(e)}")
import traceback
logger.error(traceback.format_exc())
async def start_realtime_chart(symbol="BTC/USDT", port=8050):
"""
Start the realtime chart display in a separate thread
Returns:
tuple: (chart, websocket_task)
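    Note: assumes realtime.RealTimeChart accepts the use_sample_data and
    log_no_ticks_warning keyword arguments and exposes run(host, port); the
    signal/trade hooks used elsewhere (add_nn_signal, add_trade,
    update_trading_info) are probed with hasattr before being called.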
"""
from realtime import RealTimeChart
try:
logger.info(f"Initializing RealTimeChart for {symbol}")
# Create the chart with sample data enabled and no-ticks warnings disabled
chart = RealTimeChart(symbol, use_sample_data=True, log_no_ticks_warning=False)
        # The chart's _start_websocket_thread handles the WebSocket connection
        # internally, so only the Dash server needs its own thread here
        # Run the Dash server in a separate thread
thread = Thread(target=lambda c=chart, p=port: c.run(host='localhost', port=p))
thread.daemon = True
thread.start()
# Give the server a moment to start
await asyncio.sleep(2)
logger.info(f"Started realtime chart for {symbol} on port {port}")
logger.info(f"You can view the chart at http://localhost:{port}/")
# Return the chart and a dummy websocket task (the real one is running in a thread)
return chart, asyncio.create_task(asyncio.sleep(0))
except Exception as e:
logger.error(f"Error starting realtime chart: {str(e)}")
import traceback
logger.error(traceback.format_exc())
raise
def run_training_thread(chart):
"""Start the RL training in a separate thread"""
integrator = RLTrainingIntegrator(chart)
def training_thread_func():
try:
# Use a small number of episodes to test termination handling
integrator.start_training(num_episodes=2, max_steps=500)
except Exception as e:
logger.error(f"Error in training thread: {str(e)}")
thread = threading.Thread(target=training_thread_func)
thread.daemon = True
thread.start()
logger.info("Started RL training thread")
return thread, integrator
def test_signals(chart):
"""Add test signals to the chart to verify functionality"""
from datetime import datetime
logger.info("Adding test signals to chart")
# Add a test BUY signal
chart.add_nn_signal("BUY", datetime.now(), 0.95)
# Sleep briefly
time.sleep(1)
# Add a test SELL signal
chart.add_nn_signal("SELL", datetime.now(), 0.85)
# Add a test trade if the method exists
if hasattr(chart, 'add_trade'):
chart.add_trade(
price=83000.0,
timestamp=datetime.now(),
pnl=0.05,
action="BUY",
type="BUY" # Add explicit type
)
else:
logger.warning("RealTimeChart has no add_trade method - skipping test trade")
async def main():
"""Main function that coordinates the realtime chart and RL training"""
global realtime_chart, realtime_websocket_task, running
logger.info("Starting integrated RL training with realtime visualization")
# Start the realtime chart
realtime_chart, realtime_websocket_task = await start_realtime_chart()
# Wait a bit for the chart to initialize
await asyncio.sleep(5)
# Test signals first
test_signals(realtime_chart)
# Start the training in a separate thread
training_thread, integrator = run_training_thread(realtime_chart)
try:
# Keep the main task running until interrupted
while running and training_thread.is_alive():
await asyncio.sleep(1)
except KeyboardInterrupt:
logger.info("Shutting down...")
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
finally:
# Log final PnL summary
if hasattr(integrator, 'session_pnl'):
session_win_rate = integrator.session_wins / integrator.session_trades if integrator.session_trades > 0 else 0
logger.info("=" * 50)
logger.info("FINAL SESSION SUMMARY")
logger.info("=" * 50)
logger.info(f"Final Session Balance: ${integrator.session_balance:.2f}")
logger.info(f"Total Session PnL: {integrator.session_pnl:.4f}")
logger.info(f"Total Session Win Rate: {session_win_rate:.4f} ({integrator.session_wins}/{integrator.session_trades})")
logger.info(f"Total Session Trades: {integrator.session_trades}")
logger.info("=" * 50)
# Clean up
if realtime_websocket_task:
realtime_websocket_task.cancel()
try:
await realtime_websocket_task
except asyncio.CancelledError:
pass
logger.info("Application terminated")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Application terminated by user")