#!/usr/bin/env python
"""
Integrated RL Trading with Realtime Visualization

This script combines the RL training (NN/train_rl.py) with the realtime
visualization (realtime.py) to display the actions taken by the RL agent on
the realtime chart.
"""
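
# Typical usage (assumed invocation; the file name below is only illustrative):
#   python train_rl_realtime.py
# The script serves the realtime Dash chart at http://localhost:8050/ (see
# start_realtime_chart below) and then runs RL training in a background thread.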

import os
import sys
import logging
import asyncio
import threading
import time
from datetime import datetime
import signal
import numpy as np
import torch
import json
from threading import Thread
import pandas as pd
from scipy.signal import argrelextrema

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f'rl_realtime_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('rl_realtime')

# Add the project root to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
    sys.path.append(project_root)

# Global variables for coordination
realtime_chart = None
realtime_websocket_task = None
running = True


def signal_handler(sig, frame):
    """Handle CTRL+C to gracefully exit training"""
    global running
    logger.info("Received interrupt signal. Finishing current epoch and shutting down...")
    running = False


# Register signal handler
signal.signal(signal.SIGINT, signal_handler)


class ExtremaDetector:
    """
    Detects local extrema (tops and bottoms) in price data
    """
    def __init__(self, window_size=10, order=5):
        """
        Args:
            window_size (int): Size of the window to look for extrema
            order (int): How many points on each side to use for comparison
        """
        self.window_size = window_size
        self.order = order

    def find_extrema(self, prices):
        """
        Find the local minima and maxima in the price series

        Args:
            prices (array-like): Array of price values

        Returns:
            tuple: (max_indices, min_indices) - arrays of indices where local maxima and minima occur
        """
        # Convert to numpy array if needed
        price_array = np.array(prices)

        # Find local maxima (tops)
        local_max_indices = argrelextrema(price_array, np.greater, order=self.order)[0]

        # Find local minima (bottoms)
        local_min_indices = argrelextrema(price_array, np.less, order=self.order)[0]

        # Filter out extrema that are too close to the edges
        max_indices = local_max_indices[local_max_indices >= self.order]
        max_indices = max_indices[max_indices < len(price_array) - self.order]

        min_indices = local_min_indices[local_min_indices >= self.order]
        min_indices = min_indices[min_indices < len(price_array) - self.order]

        return max_indices, min_indices
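

# Illustrative sketch, not part of the original training pipeline: shows how
# ExtremaDetector is meant to be used, on a synthetic sine-wave price series.
# The helper name and the toy data are hypothetical and exist only for demonstration.
def _demo_extrema_detection():
    """Detect tops and bottoms on a toy price series and log how many of each were found."""
    demo_prices = 100 + 5 * np.sin(np.linspace(0, 6 * np.pi, 120))
    detector = ExtremaDetector(window_size=10, order=5)
    tops, bottoms = detector.find_extrema(demo_prices)
    logger.info(f"Demo extrema: {len(tops)} tops, {len(bottoms)} bottoms")
    return tops, bottoms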


class RLTrainingIntegrator:
    """
    Integrates RL training with realtime chart visualization.
    Acts as a bridge between the RL training process and the realtime chart.
    """
    def __init__(self, chart, symbol="ETH/USDT", model_save_path="NN/models/saved/dqn_agent"):
        self.chart = chart
        self.symbol = symbol
        self.model_save_path = model_save_path
        self.episode_count = 0
        self.action_history = []
        self.reward_history = []
        self.trade_count = 0
        self.win_count = 0

        # Session-wide PnL tracking
        self.session_pnl = 0.0
        self.session_trades = 0
        self.session_wins = 0
        self.session_balance = 100.0  # Start with $100 balance

        # Track current position state
        self.in_position = False
        self.entry_price = None
        self.entry_time = None

        # Extrema detector
        self.extrema_detector = ExtremaDetector(window_size=10, order=5)

        # Store the agent reference
        self.agent = None

    def start_training(self, num_episodes=5000, max_steps=2000):
        """Start the RL training process with visualization integration"""
        from NN.train_rl import train_rl, RLTradingEnvironment

        logger.info(f"Starting RL training with realtime visualization for {self.symbol}")

        # Define callbacks for the training process
        def on_action(step, action, price, reward, info):
            """Callback for each action taken by the agent"""
            # Only visualize non-hold actions
            if action != 2:  # 0=Buy, 1=Sell, 2=Hold
                # Convert to a string action
                action_str = "BUY" if action == 0 else "SELL"

                # Get a timestamp - use the current time as a proxy
                timestamp = datetime.now()

                # Track position state
                if action == 0 and not self.in_position:  # Buy and not already in position
                    self.in_position = True
                    self.entry_price = price
                    self.entry_time = timestamp

                    # Send to chart - visualize the buy signal
                    if self.chart and hasattr(self.chart, 'add_nn_signal'):
                        self.chart.add_nn_signal(action_str, timestamp, probability=abs(reward))

                elif action == 1 and self.in_position:  # Sell and in position (completes a trade)
                    self.in_position = False

                    # Calculate profit if we have entry data
                    pnl = None
                    if self.entry_price is not None:
                        # Calculate percentage change from entry to exit
                        pnl_pct = (price - self.entry_price) / self.entry_price

                        # Cap extreme PnL values to more realistic levels (-90% to +100%)
                        pnl_pct = max(min(pnl_pct, 1.0), -0.9)

                        # Apply to current balance - use 10% of the balance per trade
                        # (e.g. with a $100 balance, a +5% move stakes $10 and earns $0.50)
                        trade_amount = self.session_balance * 0.1
                        trade_profit = trade_amount * pnl_pct
                        self.session_balance += trade_profit

                        # Ensure the session balance doesn't go below $1
                        self.session_balance = max(self.session_balance, 1.0)

                        # For normalized display in charts and logs
                        pnl = pnl_pct

                        # Update session-wide PnL
                        self.session_pnl += pnl
                        self.session_trades += 1
                        if pnl > 0:
                            self.session_wins += 1

                    # Log the complete trade on the chart
                    if self.chart:
                        # Show the sell signal
                        if hasattr(self.chart, 'add_nn_signal'):
                            self.chart.add_nn_signal(action_str, timestamp, probability=abs(reward))

                        # Record the trade with PnL
                        if hasattr(self.chart, 'add_trade'):
                            self.chart.add_trade(
                                price=price,
                                timestamp=timestamp,
                                pnl=pnl,
                                amount=0.1,
                                action=action_str,
                                type=action_str  # Add explicit type
                            )

                    # Update trade counts
                    self.trade_count += 1
                    if pnl is not None and pnl > 0:
                        self.win_count += 1

                    # Reset entry data
                    self.entry_price = None
                    self.entry_time = None

                # Track all non-hold actions
                self.action_history.append({
                    'step': step,
                    'action': action_str,
                    'price': price,
                    'reward': reward,
                    'timestamp': timestamp.isoformat()
                })
            else:
                # Hold action
                action_str = "HOLD"
                timestamp = datetime.now()

            # Update chart trading info
            if self.chart and hasattr(self.chart, 'update_trading_info'):
                # Determine the current position size (0.1 if in position, 0 if not)
                position_size = 0.1 if self.in_position else 0.0
                self.chart.update_trading_info(
                    signal=action_str,
                    position=position_size,
                    balance=self.session_balance,
                    pnl=self.session_pnl
                )

            # Track reward for all actions (including hold)
            self.reward_history.append(reward)

            # Log periodically
            if len(self.reward_history) % 100 == 0:
                avg_reward = sum(self.reward_history[-100:]) / 100
                logger.info(f"Step {step}: Avg reward (last 100): {avg_reward:.4f}, "
                            f"Actions: {len(self.action_history)}, Trades: {self.trade_count}")

        def on_episode(episode, reward, info):
            """Callback for each completed episode"""
            self.episode_count += 1

            # Log episode results
            logger.info(f"Episode {episode} completed")
            logger.info(f"  Total reward: {reward:.4f}")
            logger.info(f"  PnL: {info['gain']:.4f}")
            logger.info(f"  Win rate: {info['win_rate']:.4f}")
            logger.info(f"  Trades: {info['trades']}")

            # Log session-wide PnL
            session_win_rate = self.session_wins / self.session_trades if self.session_trades > 0 else 0
            logger.info(f"  Session Balance: ${self.session_balance:.2f}")
            logger.info(f"  Session Total PnL: {self.session_pnl:.4f}")
            logger.info(f"  Session Win Rate: {session_win_rate:.4f}")
            logger.info(f"  Session Trades: {self.session_trades}")

            # Update chart trading info with final episode information
            if self.chart and hasattr(self.chart, 'update_trading_info'):
                # Reset position since we're between episodes
                self.chart.update_trading_info(
                    signal="HOLD",
                    position=0.0,
                    balance=self.session_balance,
                    pnl=self.session_pnl
                )

            # Reset position state for the new episode
            self.in_position = False
            self.entry_price = None
            self.entry_time = None

            # After each episode, perform additional training for local extrema
            if hasattr(self.agent, 'policy_net') and hasattr(self.agent, 'replay') and episode > 0:
                self._train_on_extrema(self.agent, info['env'])

        # Start the actual training with our callbacks
        self.agent = train_rl(
            num_episodes=num_episodes,
            max_steps=max_steps,
            save_path=self.model_save_path,
            action_callback=on_action,
            episode_callback=on_episode,
            symbol=self.symbol
        )

        logger.info("RL training completed")
        return self.agent

    def _train_on_extrema(self, agent, env):
        """
        Perform additional training on local extrema (tops and bottoms)
        to help the model learn these important patterns faster

        Args:
            agent: The DQN agent
            env: The trading environment
        """
        if not hasattr(env, 'features_1m') or len(env.features_1m) == 0:
            logger.warning("Environment doesn't have price data for extrema detection")
            return

        try:
            # Extract close prices
            prices = env.features_1m[:, -1]  # Assuming close price is the last column

            # Find local extrema
            max_indices, min_indices = self.extrema_detector.find_extrema(prices)

            if len(max_indices) == 0 or len(min_indices) == 0:
                logger.warning("No extrema found in the current price data")
                return

            logger.info(f"Found {len(max_indices)} tops and {len(min_indices)} bottoms for additional training")

            # Calculate price changes at extrema to prioritize more significant ones
            max_price_changes = []
            for idx in max_indices:
                if idx < 5 or idx >= len(prices) - 5:
                    continue
                # Calculate percentage price rise from previous 5 candles to the peak
                min_before = min(prices[idx-5:idx])
                price_change = (prices[idx] - min_before) / min_before
                max_price_changes.append((idx, price_change))

            min_price_changes = []
            for idx in min_indices:
                if idx < 5 or idx >= len(prices) - 5:
                    continue
                # Calculate percentage price drop from previous 5 candles to the bottom
                max_before = max(prices[idx-5:idx])
                price_change = (max_before - prices[idx]) / max_before
                min_price_changes.append((idx, price_change))

            # Sort extrema by significance (larger price change is more important)
            max_price_changes.sort(key=lambda x: x[1], reverse=True)
            min_price_changes.sort(key=lambda x: x[1], reverse=True)

            # Take the top 10 most significant extrema, or all if fewer
            max_indices = [idx for idx, _ in max_price_changes[:10]]
            min_indices = [idx for idx, _ in min_price_changes[:10]]

            # Log the significance of the extrema
            if max_indices:
                logger.info(f"Top extrema price changes: {[round(pc*100, 2) for _, pc in max_price_changes[:5]]}%")
            if min_indices:
                logger.info(f"Bottom extrema price changes: {[round(pc*100, 2) for _, pc in min_price_changes[:5]]}%")

            # Collect states, actions, rewards for batch training
            states = []
            actions = []
            rewards = []
            next_states = []
            dones = []

            # Process tops (local maxima - should sell)
            for idx in max_indices:
                if idx < env.window_size + 2 or idx >= len(prices) - 2:
                    continue

                # Create states for multiple points approaching the top.
                # This helps the model learn to recognize the pattern leading to the top.
                for offset in range(1, 4):  # Look at 1, 2, and 3 candles before the top
                    if idx - offset < env.window_size:
                        continue

                    # State before the peak
                    state_idx = idx - offset
                    env.current_step = state_idx
                    state = env._get_observation()

                    # The next state would be closer to the peak
                    env.current_step = state_idx + 1
                    next_state = env._get_observation()

                    # Reward increases as we get closer to the peak;
                    # stronger rewards for being right at the peak
                    reward = 1.0 if offset > 1 else 2.0

                    # Add to memory
                    action = 1  # Sell
                    agent.remember(state, action, reward, next_state, False, is_extrema=True)

                    # Add to batch
                    states.append(state)
                    actions.append(action)
                    rewards.append(reward)
                    next_states.append(next_state)
                    dones.append(False)

            # Process bottoms (local minima - should buy)
            for idx in min_indices:
                if idx < env.window_size + 2 or idx >= len(prices) - 2:
                    continue

                # Create states for multiple points approaching the bottom
                for offset in range(1, 4):  # Look at 1, 2, and 3 candles before the bottom
                    if idx - offset < env.window_size:
                        continue

                    # State before the bottom
                    state_idx = idx - offset
                    env.current_step = state_idx
                    state = env._get_observation()

                    # The next state would be closer to the bottom
                    env.current_step = state_idx + 1
                    next_state = env._get_observation()

                    # Reward increases as we get closer to the bottom
                    reward = 1.0 if offset > 1 else 2.0

                    # Add to memory
                    action = 0  # Buy
                    agent.remember(state, action, reward, next_state, False, is_extrema=True)

                    # Add to batch
                    states.append(state)
                    actions.append(action)
                    rewards.append(reward)
                    next_states.append(next_state)
                    dones.append(False)

            # Add some negative examples - don't buy at tops, don't sell at bottoms
            for idx in max_indices[:5]:  # Use a few top peaks
                if idx < env.window_size + 1 or idx >= len(prices) - 1:
                    continue

                # State at the peak
                env.current_step = idx
                state = env._get_observation()

                # Next state
                env.current_step = idx + 1
                next_state = env._get_observation()

                # Strong negative reward for buying at a peak
                reward = -1.5

                # Add negative example of buying at a peak
                action = 0  # Buy (wrong action)
                agent.remember(state, action, reward, next_state, False, is_extrema=True)

                # Add to batch
                states.append(state)
                actions.append(action)
                rewards.append(reward)
                next_states.append(next_state)
                dones.append(False)

            for idx in min_indices[:5]:  # Use a few bottom troughs
                if idx < env.window_size + 1 or idx >= len(prices) - 1:
                    continue

                # State at the bottom
                env.current_step = idx
                state = env._get_observation()

                # Next state
                env.current_step = idx + 1
                next_state = env._get_observation()

                # Strong negative reward for selling at a bottom
                reward = -1.5

                # Add negative example of selling at a bottom
                action = 1  # Sell (wrong action)
                agent.remember(state, action, reward, next_state, False, is_extrema=True)

                # Add to batch
                states.append(state)
                actions.append(action)
                rewards.append(reward)
                next_states.append(next_state)
                dones.append(False)

            # Train on the collected extrema samples
            if len(states) > 0:
                logger.info(f"Performing additional training on {len(states)} extrema patterns")
                loss = agent.train_on_extrema(states, actions, rewards, next_states, dones)
                logger.info(f"Extrema training loss: {loss:.4f}")

                # Additional replay passes with extrema samples included
                for _ in range(5):
                    loss = agent.replay(use_extrema=True)
                    logger.info(f"Mixed replay with extrema - loss: {loss:.4f}")

        except Exception as e:
            logger.error(f"Error during extrema training: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())


async def start_realtime_chart(symbol="BTC/USDT", port=8050):
    """
    Start the realtime chart display in a separate thread

    Returns:
        tuple: (chart, websocket_task)
    """
    from realtime import RealTimeChart

    try:
        logger.info(f"Initializing RealTimeChart for {symbol}")
        # Create the chart with sample data enabled and no-ticks warnings disabled
        chart = RealTimeChart(symbol, use_sample_data=True, log_no_ticks_warning=False)

        # The WebSocket connection is started in a separate thread;
        # the chart's _start_websocket_thread method already handles this.

        # Run the Dash server in a separate thread
        thread = Thread(target=lambda c=chart, p=port: c.run(host='localhost', port=p))
        thread.daemon = True
        thread.start()

        # Give the server a moment to start
        await asyncio.sleep(2)

        logger.info(f"Started realtime chart for {symbol} on port {port}")
        logger.info(f"You can view the chart at http://localhost:{port}/")

        # Return the chart and a dummy websocket task (the real one is running in a thread)
        return chart, asyncio.create_task(asyncio.sleep(0))
    except Exception as e:
        logger.error(f"Error starting realtime chart: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        raise


def run_training_thread(chart):
    """Start the RL training in a separate thread"""
    integrator = RLTrainingIntegrator(chart)

    def training_thread_func():
        try:
            # Use a small number of episodes to test termination handling
            integrator.start_training(num_episodes=2, max_steps=500)
        except Exception as e:
            logger.error(f"Error in training thread: {str(e)}")

    thread = threading.Thread(target=training_thread_func)
    thread.daemon = True
    thread.start()
    logger.info("Started RL training thread")
    return thread, integrator


def test_signals(chart):
    """Add test signals to the chart to verify functionality"""
    logger.info("Adding test signals to chart")

    # Add a test BUY signal
    chart.add_nn_signal("BUY", datetime.now(), 0.95)

    # Sleep briefly
    time.sleep(1)

    # Add a test SELL signal
    chart.add_nn_signal("SELL", datetime.now(), 0.85)

    # Add a test trade if the method exists
    if hasattr(chart, 'add_trade'):
        chart.add_trade(
            price=83000.0,
            timestamp=datetime.now(),
            pnl=0.05,
            action="BUY",
            type="BUY"  # Add explicit type
        )
    else:
        logger.warning("RealTimeChart has no add_trade method - skipping test trade")


async def main():
    """Main function that coordinates the realtime chart and RL training"""
    global realtime_chart, realtime_websocket_task, running

    logger.info("Starting integrated RL training with realtime visualization")

    # Start the realtime chart
    realtime_chart, realtime_websocket_task = await start_realtime_chart()

    # Wait a bit for the chart to initialize
    await asyncio.sleep(5)

    # Test signals first
    test_signals(realtime_chart)

    # Start the training in a separate thread
    training_thread, integrator = run_training_thread(realtime_chart)

    try:
        # Keep the main task running until interrupted
        while running and training_thread.is_alive():
            await asyncio.sleep(1)
    except KeyboardInterrupt:
        logger.info("Shutting down...")
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
    finally:
        # Log final PnL summary
        if hasattr(integrator, 'session_pnl'):
            session_win_rate = integrator.session_wins / integrator.session_trades if integrator.session_trades > 0 else 0
            logger.info("=" * 50)
            logger.info("FINAL SESSION SUMMARY")
            logger.info("=" * 50)
            logger.info(f"Final Session Balance: ${integrator.session_balance:.2f}")
            logger.info(f"Total Session PnL: {integrator.session_pnl:.4f}")
            logger.info(f"Total Session Win Rate: {session_win_rate:.4f} ({integrator.session_wins}/{integrator.session_trades})")
            logger.info(f"Total Session Trades: {integrator.session_trades}")
            logger.info("=" * 50)

        # Clean up
        if realtime_websocket_task:
            realtime_websocket_task.cancel()
            try:
                await realtime_websocket_task
            except asyncio.CancelledError:
                pass

        logger.info("Application terminated")


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Application terminated by user")