977 lines
42 KiB
Python
977 lines
42 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
Integrated RL Trading with Realtime Visualization
|
|
|
|
This script combines the RL training (NN/train_rl.py) with the realtime visualization (realtime.py)
|
|
to display the actions taken by the RL agent on the realtime chart.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import logging
|
|
import asyncio
|
|
import threading
|
|
import time
|
|
from datetime import datetime
|
|
import signal
|
|
import numpy as np
|
|
import torch
|
|
import json
|
|
from threading import Thread
|
|
import pandas as pd
|
|
import argparse
|
|
from scipy.signal import argrelextrema
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
# Configure logging
|
|
logger = logging.getLogger('rl_realtime')
|
|
|
|
# Add the project root to path if needed
|
|
project_root = os.path.dirname(os.path.abspath(__file__))
|
|
if project_root not in sys.path:
|
|
sys.path.append(project_root)
|
|
|
|
# Global variables for coordination
|
|
realtime_chart = None
|
|
realtime_websocket_task = None
|
|
running = True
|
|
chart_instance = None # Global reference to the chart instance
|
|
|
|
def signal_handler(sig, frame):
|
|
"""Handle CTRL+C to gracefully exit training"""
|
|
global running
|
|
logger.info("Received interrupt signal. Finishing current epoch and shutting down...")
|
|
running = False
|
|
|
|
# Register signal handler
|
|
signal.signal(signal.SIGINT, signal_handler)
|
|
|
|
class ExtremaDetector:
|
|
"""
|
|
Detects local extrema (tops and bottoms) in price data
|
|
"""
|
|
def __init__(self, window_size=10, order=5):
|
|
"""
|
|
Args:
|
|
window_size (int): Size of the window to look for extrema
|
|
order (int): How many points on each side to use for comparison
|
|
"""
|
|
self.window_size = window_size
|
|
self.order = order
|
|
|
|
def find_extrema(self, prices):
|
|
"""
|
|
Find the local minima and maxima in the price series
|
|
|
|
Args:
|
|
prices (array-like): Array of price values
|
|
|
|
Returns:
|
|
tuple: (max_indices, min_indices) - arrays of indices where local maxima and minima occur
|
|
"""
|
|
# Convert to numpy array if needed
|
|
price_array = np.array(prices)
|
|
|
|
# Find local maxima (tops)
|
|
local_max_indices = argrelextrema(price_array, np.greater, order=self.order)[0]
|
|
|
|
# Find local minima (bottoms)
|
|
local_min_indices = argrelextrema(price_array, np.less, order=self.order)[0]
|
|
|
|
# Filter out extrema that are too close to the edges
|
|
max_indices = local_max_indices[local_max_indices >= self.order]
|
|
max_indices = max_indices[max_indices < len(price_array) - self.order]
|
|
|
|
min_indices = local_min_indices[local_min_indices >= self.order]
|
|
min_indices = min_indices[min_indices < len(price_array) - self.order]
|
|
|
|
return max_indices, min_indices
|
|
|
|
class RLTrainingIntegrator:
|
|
"""
|
|
Integrates RL training with realtime chart visualization.
|
|
Acts as a bridge between the RL training process and the realtime chart.
|
|
"""
|
|
def __init__(self, chart, symbol="ETH/USDT", model_save_path="NN/models/saved/dqn_agent", max_position=1.0):
|
|
self.chart = chart
|
|
self.symbol = symbol
|
|
self.model_save_path = model_save_path
|
|
self.episode_count = 0
|
|
self.action_history = []
|
|
self.reward_history = []
|
|
self.trade_count = 0
|
|
self.win_count = 0
|
|
|
|
# Maximum position size
|
|
self.max_position = max_position
|
|
|
|
# Add session-wide PnL tracking
|
|
self.session_pnl = 0.0
|
|
self.session_trades = 0
|
|
self.session_wins = 0
|
|
self.session_balance = 100.0 # Start with $100 balance
|
|
|
|
# Track current position state
|
|
self.current_position_size = 0.0
|
|
self.entry_price = None
|
|
self.entry_time = None
|
|
|
|
# Extrema detector
|
|
self.extrema_detector = ExtremaDetector(window_size=20, order=10)
|
|
|
|
# Store the agent reference
|
|
self.agent = None
|
|
|
|
# Price history for extrema detection
|
|
self.price_history = []
|
|
self.price_history_max_len = 100 # Store last 100 prices
|
|
|
|
# TensorBoard writer
|
|
self.tensorboard_writer = None
|
|
|
|
def _train_on_extrema(self, agent, env):
|
|
"""Train the agent specifically on local extrema points"""
|
|
if not hasattr(env, 'data') or not hasattr(env, 'original_data'):
|
|
logger.warning("Environment doesn't have required data attributes for extrema training")
|
|
return
|
|
|
|
# Extract price data
|
|
try:
|
|
prices = env.original_data['close'].values
|
|
|
|
# Find local extrema in the price series
|
|
max_indices, min_indices = self.extrema_detector.find_extrema(prices)
|
|
|
|
# Create training examples for extrema points
|
|
states = []
|
|
actions = []
|
|
rewards = []
|
|
next_states = []
|
|
dones = []
|
|
|
|
# For each bottom, create a BUY example
|
|
for idx in min_indices:
|
|
if idx < env.window_size or idx >= len(prices) - 2:
|
|
continue # Skip if too close to edges
|
|
|
|
# Set up the environment state at this point
|
|
env.current_step = idx
|
|
state = env._get_observation()
|
|
|
|
# The action should be BUY at bottoms
|
|
action = 0 # BUY
|
|
|
|
# Execute step to get next state and reward
|
|
env.position = 0 # Ensure no position before buying
|
|
env.current_step = idx # Reset position
|
|
next_state, reward, done, _ = env.step(action)
|
|
|
|
# Store this example
|
|
states.append(state)
|
|
actions.append(action)
|
|
rewards.append(1.0) # Override with higher reward
|
|
next_states.append(next_state)
|
|
dones.append(done)
|
|
|
|
# Also add a HOLD example for already having a position at bottom
|
|
env.current_step = idx
|
|
env.position = 1 # Already have a position
|
|
state = env._get_observation()
|
|
action = 2 # HOLD
|
|
next_state, reward, done, _ = env.step(action)
|
|
|
|
states.append(state)
|
|
actions.append(action)
|
|
rewards.append(0.5) # Good to hold at bottom with a position
|
|
next_states.append(next_state)
|
|
dones.append(done)
|
|
|
|
# For each top, create a SELL example
|
|
for idx in max_indices:
|
|
if idx < env.window_size or idx >= len(prices) - 2:
|
|
continue # Skip if too close to edges
|
|
|
|
# Set up the environment state at this point
|
|
env.current_step = idx
|
|
|
|
# The action should be SELL at tops (if we have a position)
|
|
env.position = 1 # Set position to 1 (we have a long position)
|
|
env.entry_price = prices[idx-5] # Pretend we bought a bit earlier
|
|
state = env._get_observation()
|
|
action = 1 # SELL
|
|
|
|
# Execute step to get next state and reward
|
|
next_state, reward, done, _ = env.step(action)
|
|
|
|
# Store this example
|
|
states.append(state)
|
|
actions.append(action)
|
|
rewards.append(1.0) # Override with higher reward
|
|
next_states.append(next_state)
|
|
dones.append(done)
|
|
|
|
# Also add a HOLD example for not having a position at top
|
|
env.current_step = idx
|
|
env.position = 0 # No position
|
|
state = env._get_observation()
|
|
action = 2 # HOLD
|
|
next_state, reward, done, _ = env.step(action)
|
|
|
|
states.append(state)
|
|
actions.append(action)
|
|
rewards.append(0.5) # Good to hold at top with no position
|
|
next_states.append(next_state)
|
|
dones.append(done)
|
|
|
|
# Check if we have any extrema examples
|
|
if states:
|
|
logger.info(f"Training on {len(states)} extrema examples: {len(min_indices)} bottoms, {len(max_indices)} tops")
|
|
# Convert to numpy arrays
|
|
states = np.array(states)
|
|
actions = np.array(actions)
|
|
rewards = np.array(rewards)
|
|
next_states = np.array(next_states)
|
|
dones = np.array(dones)
|
|
|
|
# Train the agent on these examples
|
|
loss = agent.train_on_extrema(states, actions, rewards, next_states, dones)
|
|
logger.info(f"Extrema training loss: {loss:.4f}")
|
|
else:
|
|
logger.info("No valid extrema examples found for training")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error during extrema training: {str(e)}")
|
|
import traceback
|
|
logger.error(traceback.format_exc())
|
|
|
|
def run_training(self, episodes=100, max_steps=2000):
|
|
"""Run the training process with our integrations"""
|
|
from NN.train_rl import train_rl, RLTradingEnvironment
|
|
import time
|
|
|
|
# Create a stop event for training interruption
|
|
self.stop_event = threading.Event()
|
|
|
|
# Reset session metrics
|
|
self.session_pnl = 0.0
|
|
self.session_trades = 0
|
|
self.session_wins = 0
|
|
self.session_balance = 100.0
|
|
self.session_step = 0
|
|
self.current_position_size = 0.0
|
|
|
|
# Reset price history
|
|
self.price_history = []
|
|
|
|
# Reset chart-related state if it exists
|
|
if self.chart:
|
|
# Reset positions list to empty
|
|
if hasattr(self.chart, 'positions'):
|
|
self.chart.positions = []
|
|
|
|
# Reset accumulated PnL and balance display
|
|
if hasattr(self.chart, 'accumulative_pnl'):
|
|
self.chart.accumulative_pnl = 0.0
|
|
|
|
if hasattr(self.chart, 'current_balance'):
|
|
self.chart.current_balance = 100.0
|
|
|
|
# Update trading info if method exists
|
|
if hasattr(self.chart, 'update_trading_info'):
|
|
self.chart.update_trading_info(
|
|
signal="READY",
|
|
position=0.0,
|
|
balance=self.session_balance,
|
|
pnl=0.0
|
|
)
|
|
|
|
# Initialize TensorBoard writer
|
|
try:
|
|
log_dir = f'runs/rl_realtime_{int(time.time())}'
|
|
self.tensorboard_writer = SummaryWriter(log_dir=log_dir)
|
|
logger.info(f"TensorBoard logging enabled at {log_dir}")
|
|
except Exception as e:
|
|
logger.error(f"Failed to initialize TensorBoard writer: {str(e)}")
|
|
self.tensorboard_writer = None
|
|
|
|
try:
|
|
logger.info(f"Starting training for {episodes} episodes (max {max_steps} steps per episode)")
|
|
|
|
# Create a custom environment class that includes our reward function modification
|
|
class EnhancedRLTradingEnvironment(RLTradingEnvironment):
|
|
def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.001):
|
|
"""Initialize with normalization parameters"""
|
|
super().__init__(features_1m, features_5m, features_15m, window_size, trading_fee)
|
|
# Initialize integrator and chart references
|
|
self.integrator = None # Will be set after initialization
|
|
self.chart = None # Will be set after initialization
|
|
# Make writer accessible to integrator callbacks
|
|
self.writer = None # Will be set by train_rl
|
|
|
|
def set_tensorboard_writer(self, writer):
|
|
"""Set the TensorBoard writer"""
|
|
self.writer = writer
|
|
|
|
def _calculate_reward(self, action):
|
|
"""Override the reward calculation with our enhanced version"""
|
|
# Get the original reward calculation result
|
|
reward, pnl = super()._calculate_reward(action)
|
|
|
|
# Get current price (normalized from training data)
|
|
current_price = self.features_1m[self.current_step, -1]
|
|
|
|
# Get real market price if available
|
|
real_market_price = None
|
|
if hasattr(self, 'chart') and self.chart and hasattr(self.chart, 'latest_price'):
|
|
real_market_price = self.chart.latest_price
|
|
|
|
# Pass through the integrator's reward modifier
|
|
if hasattr(self, 'integrator') and self.integrator is not None:
|
|
# Add price to history - use real market price if available
|
|
if real_market_price is not None:
|
|
# For extrema detection, use a normalized version of the real price
|
|
# to keep scale consistent with the model's price history
|
|
self.integrator.price_history.append(current_price)
|
|
else:
|
|
self.integrator.price_history.append(current_price)
|
|
|
|
# Apply extrema-based reward modifications
|
|
if len(self.integrator.price_history) > 20:
|
|
# Detect local extrema
|
|
tops_indices, bottoms_indices = self.integrator.extrema_detector.find_extrema(
|
|
self.integrator.price_history
|
|
)
|
|
|
|
# Calculate additional rewards based on extrema
|
|
if action == 0 and bottoms_indices and bottoms_indices[-1] > len(self.integrator.price_history) - 5:
|
|
# Bonus for buying near bottoms
|
|
reward += 0.01
|
|
if self.integrator.session_step % 50 == 0: # Log less frequently
|
|
# Display the real market price if available
|
|
display_price = real_market_price if real_market_price is not None else current_price
|
|
logger.info(f"BUY signal near bottom detected at price {display_price:.2f}! Adding bonus reward.")
|
|
|
|
elif action == 1 and tops_indices and tops_indices[-1] > len(self.integrator.price_history) - 5:
|
|
# Bonus for selling near tops
|
|
reward += 0.01
|
|
if self.integrator.session_step % 50 == 0: # Log less frequently
|
|
# Display the real market price if available
|
|
display_price = real_market_price if real_market_price is not None else current_price
|
|
logger.info(f"SELL signal near top detected at price {display_price:.2f}! Adding bonus reward.")
|
|
|
|
return reward, pnl
|
|
|
|
# Create a custom environment class factory
|
|
def create_enhanced_env(features_1m, features_5m, features_15m):
|
|
env = EnhancedRLTradingEnvironment(features_1m, features_5m, features_15m)
|
|
# Set the integrator after creation
|
|
env.integrator = self
|
|
# Set the chart from the integrator
|
|
env.chart = self.chart
|
|
# Pass our TensorBoard writer to the environment
|
|
if self.tensorboard_writer:
|
|
env.set_tensorboard_writer(self.tensorboard_writer)
|
|
return env
|
|
|
|
# Run the training with callbacks
|
|
agent, env = train_rl(
|
|
symbol=self.symbol,
|
|
num_episodes=episodes,
|
|
max_steps=max_steps,
|
|
action_callback=self.on_action,
|
|
episode_callback=self.on_episode,
|
|
save_path=self.model_save_path,
|
|
env_class=create_enhanced_env # Use our enhanced environment
|
|
)
|
|
|
|
rewards = [] # Empty rewards since train_rl doesn't return them
|
|
info = {} # Empty info since train_rl doesn't return it
|
|
|
|
self.agent = agent
|
|
|
|
# Log final training results
|
|
logger.info("Training completed.")
|
|
logger.info(f"Final session balance: ${self.session_balance:.2f}")
|
|
logger.info(f"Final session PnL: {self.session_pnl:.4f}")
|
|
logger.info(f"Final win rate: {self.session_wins/max(1, self.session_trades):.4f}")
|
|
|
|
# Return the trained agent and environment
|
|
return agent, env
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error during training: {str(e)}")
|
|
import traceback
|
|
logger.error(traceback.format_exc())
|
|
|
|
finally:
|
|
# Close TensorBoard writer if it exists
|
|
if self.tensorboard_writer:
|
|
try:
|
|
self.tensorboard_writer.close()
|
|
except:
|
|
pass
|
|
self.tensorboard_writer = None
|
|
|
|
# Clear the stop event
|
|
self.stop_event.clear()
|
|
|
|
return None, None
|
|
|
|
def modify_reward_function(self, env):
|
|
"""Modify the reward function to emphasize finding bottoms and tops"""
|
|
# Store the original calculate_reward method
|
|
original_calculate_reward = env._calculate_reward
|
|
|
|
def enhanced_calculate_reward(action):
|
|
"""Enhanced reward function that rewards finding bottoms and tops"""
|
|
# Call the original reward function to get baseline reward
|
|
reward, pnl = original_calculate_reward(action)
|
|
|
|
# Check if we have enough price history for extrema detection
|
|
if len(self.price_history) > 20:
|
|
# Detect local extrema
|
|
tops_indices, bottoms_indices = self.extrema_detector.find_extrema(self.price_history)
|
|
|
|
# Get current price
|
|
current_price = self.price_history[-1]
|
|
|
|
# Calculate average price movement
|
|
avg_price_move = np.std(self.price_history)
|
|
|
|
# Check if current position is near a local extrema
|
|
is_near_bottom = False
|
|
is_near_top = False
|
|
|
|
# Find nearest bottom
|
|
if len(bottoms_indices) > 0:
|
|
nearest_bottom_idx = bottoms_indices[-1]
|
|
if nearest_bottom_idx > len(self.price_history) - 5: # Bottom detected in last 5 ticks
|
|
nearest_bottom_price = self.price_history[nearest_bottom_idx]
|
|
# Check if price is within 0.3% of the bottom
|
|
if abs(current_price - nearest_bottom_price) / nearest_bottom_price < 0.003:
|
|
is_near_bottom = True
|
|
|
|
# Find nearest top
|
|
if len(tops_indices) > 0:
|
|
nearest_top_idx = tops_indices[-1]
|
|
if nearest_top_idx > len(self.price_history) - 5: # Top detected in last 5 ticks
|
|
nearest_top_price = self.price_history[nearest_top_idx]
|
|
# Check if price is within 0.3% of the top
|
|
if abs(current_price - nearest_top_price) / nearest_top_price < 0.003:
|
|
is_near_top = True
|
|
|
|
# Apply bonus rewards for finding extrema
|
|
if action == 0: # BUY
|
|
if is_near_bottom:
|
|
# Big bonus for buying near bottom
|
|
logger.info(f"BUY signal near bottom detected! Adding bonus reward.")
|
|
reward += 0.01 # Significant bonus
|
|
elif is_near_top:
|
|
# Penalty for buying near top
|
|
logger.info(f"BUY signal near top detected! Adding penalty.")
|
|
reward -= 0.01 # Significant penalty
|
|
elif action == 1: # SELL
|
|
if is_near_top:
|
|
# Big bonus for selling near top
|
|
logger.info(f"SELL signal near top detected! Adding bonus reward.")
|
|
reward += 0.01 # Significant bonus
|
|
elif is_near_bottom:
|
|
# Penalty for selling near bottom
|
|
logger.info(f"SELL signal near bottom detected! Adding penalty.")
|
|
reward -= 0.01 # Significant penalty
|
|
|
|
# Add bonus for holding during appropriate times
|
|
if action == 2: # HOLD
|
|
if (is_near_bottom and self.current_position_size > 0) or \
|
|
(is_near_top and self.current_position_size == 0):
|
|
# Good to hold if we have positions at bottom or no positions at top
|
|
reward += 0.001 # Small bonus for correct holding
|
|
|
|
return reward, pnl
|
|
|
|
# Replace the reward function with our enhanced version
|
|
env._calculate_reward = enhanced_calculate_reward
|
|
|
|
return env
|
|
|
|
def on_action(self, step, action, price, reward, info):
|
|
"""Called after each action in the episode"""
|
|
|
|
# Log the action
|
|
action_str = "BUY" if action == 0 else "SELL" if action == 1 else "HOLD"
|
|
|
|
# Get real market price from chart if available, otherwise use the model price
|
|
display_price = price
|
|
if self.chart and hasattr(self.chart, 'latest_price') and self.chart.latest_price is not None:
|
|
display_price = self.chart.latest_price
|
|
elif abs(price) < 0.1: # If price is likely normalized (very small)
|
|
# Fallback to approximate price if no real market data
|
|
display_price = 1920.0 * (1 + price * 0.10)
|
|
|
|
# Store the original price for model-related calculations
|
|
model_price = price
|
|
|
|
# Update price history for extrema detection (using model price)
|
|
self.price_history.append(model_price)
|
|
if len(self.price_history) > self.price_history_max_len:
|
|
self.price_history = self.price_history[-self.price_history_max_len:]
|
|
|
|
# Normalize rewards to be realistic for crypto trading (smaller values)
|
|
normalized_reward = reward * 0.1 # Scale down rewards
|
|
if abs(normalized_reward) > 5.0: # Cap maximum reward value
|
|
normalized_reward = 5.0 if normalized_reward > 0 else -5.0
|
|
|
|
# Update session PnL and balance
|
|
self.session_step += 1
|
|
self.session_pnl += normalized_reward
|
|
|
|
# Increase balance based on reward - cap to reasonable values
|
|
self.session_balance += normalized_reward
|
|
self.session_balance = min(self.session_balance, 1000.0) # Cap maximum balance
|
|
self.session_balance = max(self.session_balance, 0.0) # Prevent negative balance
|
|
|
|
# Update chart's accumulativePnL and balance if available
|
|
if self.chart:
|
|
if hasattr(self.chart, 'accumulative_pnl'):
|
|
self.chart.accumulative_pnl = self.session_pnl
|
|
# Cap accumulated PnL to reasonable values
|
|
self.chart.accumulative_pnl = min(self.chart.accumulative_pnl, 500.0)
|
|
self.chart.accumulative_pnl = max(self.chart.accumulative_pnl, -100.0)
|
|
|
|
if hasattr(self.chart, 'current_balance'):
|
|
self.chart.current_balance = self.session_balance
|
|
|
|
# Handle win/loss tracking
|
|
if reward != 0: # If this was a trade with P&L
|
|
self.session_trades += 1
|
|
if reward > 0:
|
|
self.session_wins += 1
|
|
|
|
# Log to TensorBoard if writer is available
|
|
if self.tensorboard_writer:
|
|
self.tensorboard_writer.add_scalar('Action/Type', action, self.session_step)
|
|
self.tensorboard_writer.add_scalar('Action/Price', display_price, self.session_step)
|
|
self.tensorboard_writer.add_scalar('Session/Balance', self.session_balance, self.session_step)
|
|
self.tensorboard_writer.add_scalar('Session/PnL', self.session_pnl, self.session_step)
|
|
self.tensorboard_writer.add_scalar('Session/Position', self.current_position_size, self.session_step)
|
|
|
|
# Track win rate
|
|
if self.session_trades > 0:
|
|
win_rate = self.session_wins / self.session_trades
|
|
self.tensorboard_writer.add_scalar('Session/WinRate', win_rate, self.session_step)
|
|
|
|
# Only log a subset of actions to avoid excessive output
|
|
if step % 100 == 0 or step < 10 or self.session_step % 100 == 0:
|
|
logger.info(f"Step {step}, Action: {action_str}, Price: {display_price:.2f}, Reward: {reward:.4f}, PnL: {self.session_pnl:.4f}, Balance: ${self.session_balance:.2f}, Position: {self.current_position_size:.2f}")
|
|
|
|
# Update chart with the action
|
|
if action == 0: # BUY
|
|
# Check if we've reached maximum position size
|
|
if self.current_position_size >= self.max_position:
|
|
logger.warning(f"Maximum position size reached ({self.max_position}). Ignoring BUY signal.")
|
|
# Don't add trade to chart, but keep session tracking consistent
|
|
else:
|
|
# Update position tracking
|
|
new_position = min(self.current_position_size + 0.1, self.max_position)
|
|
actual_buy_amount = new_position - self.current_position_size
|
|
self.current_position_size = new_position
|
|
|
|
# Only add to chart for visualization if we have a chart
|
|
if self.chart and hasattr(self.chart, "add_trade"):
|
|
# Adding a BUY trade
|
|
try:
|
|
self.chart.add_trade(
|
|
price=display_price, # Use denormalized price for display
|
|
timestamp=datetime.now(),
|
|
amount=actual_buy_amount, # Use actual amount bought
|
|
pnl=reward,
|
|
action="BUY"
|
|
)
|
|
self.chart.last_action = "BUY"
|
|
except Exception as e:
|
|
logger.error(f"Failed to add BUY trade to chart: {str(e)}")
|
|
|
|
# Log buy action to TensorBoard
|
|
if self.tensorboard_writer:
|
|
self.tensorboard_writer.add_scalar('Trade/Buy', display_price, self.session_step)
|
|
|
|
elif action == 1: # SELL
|
|
# Update position tracking
|
|
if self.current_position_size > 0:
|
|
# Calculate sell amount (all current position)
|
|
sell_amount = self.current_position_size
|
|
self.current_position_size = 0
|
|
|
|
# Only add to chart for visualization if we have a chart
|
|
if self.chart and hasattr(self.chart, "add_trade"):
|
|
# Adding a SELL trade
|
|
try:
|
|
self.chart.add_trade(
|
|
price=display_price, # Use denormalized price for display
|
|
timestamp=datetime.now(),
|
|
amount=sell_amount, # Sell all current position
|
|
pnl=reward,
|
|
action="SELL"
|
|
)
|
|
self.chart.last_action = "SELL"
|
|
except Exception as e:
|
|
logger.error(f"Failed to add SELL trade to chart: {str(e)}")
|
|
|
|
# Log sell action to TensorBoard
|
|
if self.tensorboard_writer:
|
|
self.tensorboard_writer.add_scalar('Trade/Sell', display_price, self.session_step)
|
|
self.tensorboard_writer.add_scalar('Trade/PnL', reward, self.session_step)
|
|
else:
|
|
logger.warning("No position to sell. Ignoring SELL signal.")
|
|
|
|
# Update the trading info display on chart
|
|
if self.chart and hasattr(self.chart, "update_trading_info"):
|
|
try:
|
|
# Update the trading info panel with latest data
|
|
self.chart.update_trading_info(
|
|
signal=action_str,
|
|
position=self.current_position_size,
|
|
balance=self.session_balance,
|
|
pnl=self.session_pnl
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to update trading info: {str(e)}")
|
|
|
|
# Check for manual termination
|
|
if self.stop_event.is_set():
|
|
return False # Signal to stop episode
|
|
|
|
return True # Continue episode
|
|
|
|
def on_episode(self, episode, reward, info):
|
|
"""Callback for each completed episode"""
|
|
self.episode_count += 1
|
|
|
|
# Log episode results
|
|
logger.info(f"Episode {episode} completed")
|
|
logger.info(f" Total reward: {reward:.4f}")
|
|
logger.info(f" PnL: {info['gain']:.4f}")
|
|
logger.info(f" Win rate: {info['win_rate']:.4f}")
|
|
logger.info(f" Trades: {info['trades']}")
|
|
|
|
# Log session-wide PnL
|
|
session_win_rate = self.session_wins / self.session_trades if self.session_trades > 0 else 0
|
|
logger.info(f" Session Balance: ${self.session_balance:.2f}")
|
|
logger.info(f" Session Total PnL: {self.session_pnl:.4f}")
|
|
logger.info(f" Session Win Rate: {session_win_rate:.4f}")
|
|
logger.info(f" Session Trades: {self.session_trades}")
|
|
|
|
# Update TensorBoard logging if we have access to the writer
|
|
if 'env' in info and hasattr(info['env'], 'writer'):
|
|
writer = info['env'].writer
|
|
writer.add_scalar('Session/Balance', self.session_balance, episode)
|
|
writer.add_scalar('Session/PnL', self.session_pnl, episode)
|
|
writer.add_scalar('Session/WinRate', session_win_rate, episode)
|
|
writer.add_scalar('Session/Trades', self.session_trades, episode)
|
|
writer.add_scalar('Session/Position', self.current_position_size, episode)
|
|
|
|
# Update chart trading info with final episode information
|
|
if self.chart and hasattr(self.chart, 'update_trading_info'):
|
|
# Reset position since we're between episodes
|
|
self.chart.update_trading_info(
|
|
signal="HOLD",
|
|
position=self.current_position_size,
|
|
balance=self.session_balance,
|
|
pnl=self.session_pnl
|
|
)
|
|
|
|
# Reset position state for new episode
|
|
self.current_position_size = 0.0
|
|
self.entry_price = None
|
|
self.entry_time = None
|
|
|
|
# Reset position list in the chart if it exists
|
|
if self.chart and hasattr(self.chart, 'positions'):
|
|
# Keep only the last 10 positions if we have more
|
|
if len(self.chart.positions) > 10:
|
|
self.chart.positions = self.chart.positions[-10:]
|
|
|
|
return True # Continue training
|
|
|
|
async def start_realtime_chart(symbol="BTC/USDT", port=8050, manual_mode=False):
|
|
"""Start the realtime chart
|
|
|
|
Args:
|
|
symbol (str): Trading symbol
|
|
port (int): Port to run the server on
|
|
manual_mode (bool): Enable manual trading mode
|
|
|
|
Returns:
|
|
tuple: (RealTimeChart instance, websocket task)
|
|
"""
|
|
from realtime import RealTimeChart
|
|
|
|
try:
|
|
logger.info(f"Initializing RealTimeChart for {symbol}")
|
|
# Create the chart with the simplified constructor
|
|
chart = RealTimeChart(symbol)
|
|
|
|
# Add backward compatibility methods
|
|
chart.add_trade = lambda price, timestamp, amount, pnl=0.0, action="BUY": _add_trade_compat(chart, price, timestamp, amount, pnl, action)
|
|
|
|
# Start the Dash server in a separate thread
|
|
dash_thread = Thread(target=lambda: chart.run(port=port))
|
|
dash_thread.daemon = True
|
|
dash_thread.start()
|
|
logger.info(f"Started Dash server thread on port {port}")
|
|
|
|
# Give the server a moment to start
|
|
await asyncio.sleep(2)
|
|
|
|
# Enable manual trading mode if requested
|
|
if manual_mode:
|
|
logger.info("Enabling manual trading mode")
|
|
logger.warning("Manual trading mode not supported by this simplified chart implementation")
|
|
|
|
logger.info(f"Started realtime chart for {symbol} on port {port}")
|
|
logger.info(f"You can view the chart at http://localhost:{port}/")
|
|
|
|
# Start websocket in the background
|
|
websocket_task = asyncio.create_task(chart.start_websocket())
|
|
|
|
# Return the chart and websocket task
|
|
return chart, websocket_task
|
|
except Exception as e:
|
|
logger.error(f"Error starting realtime chart: {str(e)}")
|
|
import traceback
|
|
logger.error(traceback.format_exc())
|
|
raise
|
|
|
|
def _add_trade_compat(chart, price, timestamp, amount, pnl=0.0, action="BUY"):
|
|
"""Compatibility function for adding trades to the chart"""
|
|
from realtime import Position
|
|
|
|
try:
|
|
# Create a new position
|
|
position = Position(
|
|
action=action,
|
|
entry_price=price,
|
|
amount=amount,
|
|
timestamp=timestamp,
|
|
fee_rate=0.001 # 0.1% fee rate
|
|
)
|
|
|
|
# For SELL actions, close the position with given PnL
|
|
if action == "SELL":
|
|
position.close(price, timestamp)
|
|
# Use realistic PnL values rather than the enormous ones from the model
|
|
# Cap PnL to reasonable values based on position size and price
|
|
max_reasonable_pnl = price * amount * 0.05 # Max 5% profit per trade
|
|
if abs(pnl) > max_reasonable_pnl:
|
|
if pnl > 0:
|
|
pnl = max_reasonable_pnl * 0.8 # Positive but reasonable
|
|
else:
|
|
pnl = -max_reasonable_pnl * 0.8 # Negative but reasonable
|
|
position.pnl = pnl
|
|
|
|
# Update chart's accumulated PnL if available
|
|
if hasattr(chart, 'accumulative_pnl'):
|
|
chart.accumulative_pnl += pnl
|
|
# Cap accumulated PnL to reasonable values
|
|
chart.accumulative_pnl = min(chart.accumulative_pnl, 500.0)
|
|
chart.accumulative_pnl = max(chart.accumulative_pnl, -100.0)
|
|
|
|
# Add to positions list, keeping only the last 200 for chart display
|
|
chart.positions.append(position)
|
|
if len(chart.positions) > 200:
|
|
chart.positions = chart.positions[-200:]
|
|
|
|
logger.info(f"Added {action} trade: price={price:.2f}, amount={amount}, pnl={pnl:.2f}")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error adding trade: {str(e)}")
|
|
return False
|
|
|
|
def run_training_thread(chart, num_episodes=5000, skip_training=False, max_position=1.0):
|
|
"""Run the training thread with the chart integration"""
|
|
|
|
def training_thread_func():
|
|
"""Training thread function"""
|
|
try:
|
|
# Create the integrator object
|
|
integrator = RLTrainingIntegrator(
|
|
chart=chart,
|
|
symbol=chart.symbol if hasattr(chart, 'symbol') else "ETH/USDT",
|
|
max_position=max_position
|
|
)
|
|
|
|
# Attach it to the chart for manual access
|
|
if chart:
|
|
chart.integrator = integrator
|
|
|
|
# Wait for a bit to ensure chart is initialized
|
|
time.sleep(2)
|
|
|
|
# Run the training loop based on args
|
|
if skip_training:
|
|
logger.info("Skipping training as requested")
|
|
# Just load the model and test it
|
|
from NN.train_rl import RLTradingEnvironment, load_agent
|
|
agent = load_agent(integrator.model_save_path)
|
|
if agent:
|
|
logger.info("Loaded pre-trained agent")
|
|
integrator.agent = agent
|
|
else:
|
|
logger.warning("No pre-trained agent found")
|
|
else:
|
|
# Use a small number of episodes to test termination handling
|
|
logger.info(f"Starting training with {num_episodes} episodes and max_position={max_position}")
|
|
integrator.run_training(episodes=num_episodes, max_steps=2000)
|
|
except Exception as e:
|
|
logger.error(f"Error in training thread: {str(e)}")
|
|
import traceback
|
|
logger.error(traceback.format_exc())
|
|
|
|
# Create and start the thread
|
|
thread = threading.Thread(target=training_thread_func, daemon=True)
|
|
thread.start()
|
|
logger.info("Training thread started")
|
|
return thread
|
|
|
|
def test_signals(chart):
|
|
"""Add test signals and trades to the chart to verify functionality"""
|
|
from datetime import datetime
|
|
|
|
logger.info("Adding test trades to chart")
|
|
|
|
# Add test trades
|
|
if hasattr(chart, 'add_trade'):
|
|
# Get the real market price if available
|
|
base_price = 1920.0 # Default fallback price if real data is not available
|
|
|
|
if hasattr(chart, 'latest_price') and chart.latest_price is not None:
|
|
base_price = chart.latest_price
|
|
logger.info(f"Using real market price for test trades: ${base_price:.2f}")
|
|
else:
|
|
logger.warning(f"No real market price available, using fallback price: ${base_price:.2f}")
|
|
|
|
# Use slightly adjusted prices for buy/sell
|
|
buy_price = base_price * 0.995 # Slightly below market price
|
|
buy_amount = 0.1 # Standard amount for ETH
|
|
|
|
chart.add_trade(
|
|
price=buy_price,
|
|
timestamp=datetime.now(),
|
|
amount=buy_amount,
|
|
pnl=0.0, # No PnL for entry
|
|
action="BUY"
|
|
)
|
|
|
|
# Wait briefly
|
|
time.sleep(1)
|
|
|
|
# Add a SELL trade at a slightly higher price (profit)
|
|
sell_price = base_price * 1.005 # Slightly above market price
|
|
|
|
# Calculate PnL based on price difference
|
|
price_diff = sell_price - buy_price
|
|
pnl = price_diff * buy_amount
|
|
|
|
chart.add_trade(
|
|
price=sell_price,
|
|
timestamp=datetime.now(),
|
|
amount=buy_amount,
|
|
pnl=pnl,
|
|
action="SELL"
|
|
)
|
|
|
|
logger.info(f"Test trades added successfully: BUY at {buy_price:.2f}, SELL at {sell_price:.2f}, PnL: ${pnl:.2f}")
|
|
else:
|
|
logger.warning("RealTimeChart has no add_trade method - skipping test trades")
|
|
|
|
async def main():
|
|
"""Main function to run the integrated RL training with visualization"""
|
|
global chart_instance, realtime_chart
|
|
|
|
try:
|
|
# Start the realtime chart
|
|
logger.info(f"Starting realtime chart with {'manual mode' if args.manual_trades else 'auto mode'}")
|
|
chart, websocket_task = await start_realtime_chart(
|
|
symbol="ETH/USDT",
|
|
manual_mode=args.manual_trades
|
|
)
|
|
|
|
# Store references
|
|
chart_instance = chart
|
|
realtime_chart = chart
|
|
|
|
# Only run the visualization if requested
|
|
if args.visualize_only:
|
|
logger.info("Running visualization only")
|
|
# Test with random signals if not in manual mode
|
|
if not args.manual_trades:
|
|
test_signals(chart)
|
|
|
|
# Keep main thread running
|
|
while running:
|
|
await asyncio.sleep(1)
|
|
return
|
|
|
|
# Regular training mode
|
|
logger.info("Starting integrated RL training with visualization")
|
|
|
|
# Start the training thread
|
|
training_thread = run_training_thread(
|
|
chart=chart,
|
|
num_episodes=args.episodes,
|
|
skip_training=args.no_train,
|
|
max_position=args.max_position
|
|
)
|
|
|
|
# Keep main thread running
|
|
while training_thread.is_alive() and running:
|
|
await asyncio.sleep(1)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in main function: {str(e)}")
|
|
import traceback
|
|
logger.error(traceback.format_exc())
|
|
finally:
|
|
logger.info("Main function exiting")
|
|
|
|
if __name__ == "__main__":
|
|
# Set up argument parsing
|
|
parser = argparse.ArgumentParser(description='Train RL agent with real-time visualization')
|
|
parser.add_argument('--episodes', type=int, default=5000, help='Number of episodes to train')
|
|
parser.add_argument('--no-train', action='store_true', help='Skip training and just visualize')
|
|
parser.add_argument('--visualize-only', action='store_true', help='Only run visualization')
|
|
parser.add_argument('--manual-trades', action='store_true', help='Enable manual trading mode')
|
|
parser.add_argument('--log-file', type=str, default='rl_training.log', help='Log file name')
|
|
parser.add_argument('--max-position', type=float, default=1.0, help='Maximum position size')
|
|
|
|
# Parse the arguments
|
|
args = parser.parse_args()
|
|
|
|
# Set up logging
|
|
logging.basicConfig(
|
|
filename=args.log_file,
|
|
filemode='a',
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
level=logging.INFO
|
|
)
|
|
# Add console output handler
|
|
console = logging.StreamHandler()
|
|
console.setLevel(logging.INFO)
|
|
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
console.setFormatter(formatter)
|
|
logging.getLogger('').addHandler(console)
|
|
|
|
logger.info("Starting RL training with real-time visualization")
|
|
logger.info(f"Episodes: {args.episodes}")
|
|
logger.info(f"No-train: {args.no_train}")
|
|
logger.info(f"Manual-trades: {args.manual_trades}")
|
|
logger.info(f"Max position size: {args.max_position}")
|
|
|
|
try:
|
|
asyncio.run(main())
|
|
except KeyboardInterrupt:
|
|
logger.info("Application terminated by user")
|
|
except Exception as e:
|
|
logger.error(f"Application error: {str(e)}")
|
|
import traceback
|
|
logger.error(traceback.format_exc()) |