realtime data in main
This commit is contained in:
parent
a8954bbf99
commit
1a1c410922
360
main.py
360
main.py
@ -38,6 +38,7 @@ from dash.dependencies import Input, Output, State
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
from threading import Thread
|
||||
import socket
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@ -971,12 +972,12 @@ class TradingEnvironment:
|
||||
|
||||
|
||||
def calculate_reward(self, action):
|
||||
"""Calculate reward for the given action with improved penalties for losing trades"""
|
||||
"""Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals"""
|
||||
reward = 0
|
||||
|
||||
# Base reward for actions
|
||||
if action == 0: # HOLD
|
||||
reward = -0.01 # Small penalty for doing nothing
|
||||
reward = -0.05 # Increased penalty for doing nothing to encourage more trading
|
||||
|
||||
elif action == 1: # BUY/LONG
|
||||
if self.position == 'flat':
|
||||
@ -990,13 +991,26 @@ class TradingEnvironment:
|
||||
# Check if this is an optimal buy point (bottom)
|
||||
current_idx = len(self.features['price']) - 1
|
||||
if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms:
|
||||
reward += 2.0 # Bonus for buying at a bottom
|
||||
reward += 3.0 # Increased bonus for buying at a bottom
|
||||
|
||||
# Check for volume spike (indicating potential big movement)
|
||||
if len(self.features['volume']) > 5:
|
||||
avg_volume = np.mean(self.features['volume'][-5:-1])
|
||||
current_volume = self.features['volume'][-1]
|
||||
if current_volume > avg_volume * 1.5:
|
||||
reward += 2.0 # Bonus for entering during high volume
|
||||
|
||||
# Check for price action signals
|
||||
if self.features['rsi'][-1] < 30: # Oversold condition
|
||||
reward += 1.5 # Bonus for buying at oversold levels
|
||||
|
||||
# Check if we're buying in a clear uptrend (good)
|
||||
if self.is_uptrend():
|
||||
reward += 1.0 # Bonus for buying in uptrend
|
||||
elif self.is_downtrend():
|
||||
reward -= 0.25 # Reduced penalty for buying in downtrend
|
||||
else:
|
||||
# Check if we're buying in a downtrend (bad)
|
||||
if self.is_downtrend():
|
||||
reward -= 0.5 # Penalty for buying in downtrend
|
||||
else:
|
||||
reward += 0.1 # Small reward for opening a position
|
||||
reward += 0.2 # Small reward for opening a position
|
||||
|
||||
logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
|
||||
|
||||
@ -1066,9 +1080,26 @@ class TradingEnvironment:
|
||||
# Check if this is an optimal sell point (top)
|
||||
current_idx = len(self.features['price']) - 1
|
||||
if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
|
||||
reward += 2.0 # Bonus for selling at a top
|
||||
reward += 3.0 # Increased bonus for selling at a top
|
||||
|
||||
# Check for volume spike
|
||||
if len(self.features['volume']) > 5:
|
||||
avg_volume = np.mean(self.features['volume'][-5:-1])
|
||||
current_volume = self.features['volume'][-1]
|
||||
if current_volume > avg_volume * 1.5:
|
||||
reward += 2.0 # Bonus for entering during high volume
|
||||
|
||||
# Check for price action signals
|
||||
if self.features['rsi'][-1] > 70: # Overbought condition
|
||||
reward += 1.5 # Bonus for selling at overbought levels
|
||||
|
||||
# Check if we're selling in a clear downtrend (good)
|
||||
if self.is_downtrend():
|
||||
reward += 1.0 # Bonus for selling in downtrend
|
||||
elif self.is_uptrend():
|
||||
reward -= 0.25 # Reduced penalty for selling in uptrend
|
||||
else:
|
||||
reward += 0.1 # Small reward for opening a position
|
||||
reward += 0.2 # Small reward for opening a position
|
||||
|
||||
logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
|
||||
|
||||
@ -1714,15 +1745,32 @@ class Agent:
|
||||
sample = random.random()
|
||||
|
||||
if training:
|
||||
# Epsilon decay
|
||||
# More aggressive epsilon decay for faster exploitation
|
||||
self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \
|
||||
np.exp(-1. * self.steps_done / EPSILON_DECAY)
|
||||
np.exp(-1.5 * self.steps_done / EPSILON_DECAY) # Increased decay factor
|
||||
self.steps_done += 1
|
||||
|
||||
# Lower threshold for exploration, especially in live trading
|
||||
if not training:
|
||||
# In live trading, be much more aggressive with exploitation
|
||||
self.epsilon = max(EPSILON_END, self.epsilon * 0.95)
|
||||
|
||||
if sample > self.epsilon or not training:
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.FloatTensor(state).to(self.device)
|
||||
action_values = self.policy_net(state_tensor)
|
||||
|
||||
# Add temperature-based sampling for more aggressive actions
|
||||
# when the model is confident (higher action differences)
|
||||
if not training: # More aggressive in live trading
|
||||
values = action_values.cpu().numpy()
|
||||
max_value = np.max(values)
|
||||
value_diff = max_value - np.mean(values)
|
||||
|
||||
# If there's a clear best action, always take it
|
||||
if value_diff > 0.5:
|
||||
return action_values.max(1)[1].item()
|
||||
|
||||
return action_values.max(1)[1].item()
|
||||
else:
|
||||
return random.randrange(self.action_size)
|
||||
@ -2761,7 +2809,7 @@ async def process_websocket_ticks(websocket, env, agent=None, demo=True, timefra
|
||||
continue
|
||||
|
||||
# Convert timestamp to datetime
|
||||
tick_time = datetime.fromtimestamp(timestamp / 1000)
|
||||
tick_time = datetime.datetime.fromtimestamp(timestamp / 1000)
|
||||
|
||||
# For 1-minute candles, track the minute
|
||||
if timeframe == "1m":
|
||||
@ -2856,6 +2904,8 @@ async def main():
|
||||
help='Path to model file for evaluation or live trading')
|
||||
parser.add_argument('--use-websocket', action='store_true',
|
||||
help='Use Binance WebSocket for real-time data instead of CCXT (for live mode)')
|
||||
parser.add_argument('--dashboard', action='store_true',
|
||||
help='Enable Dash dashboard visualization for real-time trading')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -2876,7 +2926,7 @@ async def main():
|
||||
|
||||
if args.mode == 'train':
|
||||
# Fetch initial data for training
|
||||
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
|
||||
await env.fetch_initial_data(exchange, args.symbol,args.timeframe, 1000)
|
||||
|
||||
# Create agent with consistent parameters
|
||||
# Note: Using STATE_SIZE and action_size=4 for consistency
|
||||
@ -2942,7 +2992,8 @@ async def main():
|
||||
symbol=args.symbol,
|
||||
timeframe=args.timeframe,
|
||||
demo=demo_mode,
|
||||
leverage=args.leverage
|
||||
leverage=args.leverage,
|
||||
use_dashboard=args.dashboard
|
||||
)
|
||||
else:
|
||||
logger.info("Using CCXT for real-time data")
|
||||
@ -3048,7 +3099,7 @@ def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
|
||||
logger.error(f"Error creating chart: {str(e)}")
|
||||
return None
|
||||
|
||||
async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
|
||||
async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50, use_dashboard=False):
|
||||
"""Run the trading bot in live mode using Binance WebSocket for real-time data
|
||||
|
||||
Args:
|
||||
@ -3058,6 +3109,7 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
|
||||
timeframe: The candlestick timeframe (e.g., "1m")
|
||||
demo: Whether to run in demo mode (paper trading)
|
||||
leverage: The leverage to use for trading
|
||||
use_dashboard: Whether to display the real-time dashboard
|
||||
|
||||
Returns:
|
||||
None
|
||||
@ -3075,9 +3127,25 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
|
||||
# Initialize TensorBoard for monitoring
|
||||
if not hasattr(agent, 'writer') or agent.writer is None:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
agent.writer = SummaryWriter(f'runs/live_ws_{symbol.replace("/", "_")}_{current_time}')
|
||||
|
||||
# Initialize Dash dashboard if enabled
|
||||
dashboard = None
|
||||
if use_dashboard:
|
||||
try:
|
||||
dashboard = TradingDashboard(symbol)
|
||||
dashboard_started = dashboard.start() # Start the dashboard in a separate thread
|
||||
if dashboard_started:
|
||||
logger.info(f"Trading dashboard enabled at http://localhost:8060")
|
||||
else:
|
||||
logger.warning("Failed to start trading dashboard, continuing without visualization")
|
||||
dashboard = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing dashboard: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
dashboard = None
|
||||
|
||||
# Track performance metrics
|
||||
trades_count = 0
|
||||
winning_trades = 0
|
||||
@ -3088,7 +3156,7 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
|
||||
|
||||
# Create directory for trade logs
|
||||
os.makedirs('trade_logs', exist_ok=True)
|
||||
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
trade_log_path = f'trade_logs/trades_ws_{current_time}.csv'
|
||||
with open(trade_log_path, 'w') as f:
|
||||
f.write("timestamp,action,price,position_size,balance,pnl\n")
|
||||
@ -3109,6 +3177,10 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
|
||||
# Reset environment with historical data
|
||||
env.reset()
|
||||
|
||||
# Update dashboard with initial data if enabled
|
||||
if dashboard:
|
||||
dashboard.update_data(env=env, candles=env.data, trade_signals=env.trade_signals)
|
||||
|
||||
# Initialize futures trading if not in demo mode
|
||||
exchange = None
|
||||
if not demo:
|
||||
@ -3144,7 +3216,7 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
|
||||
total_profit += env.last_trade_profit
|
||||
|
||||
# Log trade details
|
||||
current_time = datetime.now().isoformat()
|
||||
current_time = datetime.datetime.now().isoformat()
|
||||
action_name = "HOLD" if getattr(env, 'last_action', 0) == 0 else "BUY" if getattr(env, 'last_action', 0) == 1 else "SELL" if getattr(env, 'last_action', 0) == 2 else "CLOSE"
|
||||
with open(trade_log_path, 'a') as f:
|
||||
f.write(f"{current_time},{action_name},{env.current_price},{env.position_size},{env.balance},{getattr(env, 'last_trade_profit', 0)}\n")
|
||||
@ -3184,6 +3256,10 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
|
||||
"""
|
||||
agent.writer.add_text('Performance', performance_text, step_counter)
|
||||
|
||||
# Update the dashboard with latest data if enabled
|
||||
if dashboard:
|
||||
dashboard.update_data(env=env, candles=env.data, trade_signals=env.trade_signals)
|
||||
|
||||
prev_position = env.position
|
||||
|
||||
# Sleep for a short time to prevent CPU hogging
|
||||
@ -3231,6 +3307,252 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
|
||||
if 'exchange' in locals() and exchange:
|
||||
await exchange.close()
|
||||
|
||||
def ensure_pytorch_compatibility():
|
||||
"""Check and fix common PyTorch compatibility issues"""
|
||||
try:
|
||||
import torch.serialization
|
||||
import pickle
|
||||
|
||||
# Register safe pickles to handle the numpy scalar warning
|
||||
if hasattr(torch.serialization, 'add_safe_globals'):
|
||||
torch.serialization.add_safe_globals([('numpy._core.multiarray.scalar', np.ndarray)])
|
||||
torch.serialization.add_safe_globals([('numpy.core.multiarray.scalar', np.ndarray)])
|
||||
torch.serialization.add_safe_globals(['numpy._core.multiarray.scalar'])
|
||||
torch.serialization.add_safe_globals(['numpy.core.multiarray.scalar'])
|
||||
|
||||
logger.info("PyTorch safe globals registered for compatibility")
|
||||
else:
|
||||
logger.warning("PyTorch serialization module doesn't have add_safe_globals method")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"PyTorch compatibility check failed: {e}")
|
||||
|
||||
|
||||
class TradingDashboard:
|
||||
"""Dashboard for visualizing trading activity with Dash"""
|
||||
|
||||
def __init__(self, symbol="ETH/USDT"):
|
||||
self.symbol = symbol
|
||||
self.env = None
|
||||
self.candles = []
|
||||
self.trade_signals = []
|
||||
|
||||
# Create Dash app
|
||||
self.app = dash.Dash(__name__, suppress_callback_exceptions=True)
|
||||
|
||||
# Create basic layout
|
||||
self.app.layout = html.Div([
|
||||
# Store components for data
|
||||
html.Div(id='candle-store', style={'display': 'none'}),
|
||||
html.Div(id='signal-store', style={'display': 'none'}),
|
||||
|
||||
# Header
|
||||
html.H1(f"Trading Dashboard - {symbol}", style={'textAlign': 'center'}),
|
||||
|
||||
# Main content
|
||||
html.Div([
|
||||
# Chart
|
||||
html.Div([
|
||||
dcc.Graph(id='candlestick-chart', style={'height': '70vh'}),
|
||||
dcc.Interval(id='interval-component', interval=5*1000, n_intervals=0)
|
||||
], style={'width': '70%', 'display': 'inline-block'}),
|
||||
|
||||
# Trading info
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H3("Account Info"),
|
||||
html.Div(id='account-info')
|
||||
]),
|
||||
html.Div([
|
||||
html.H3("Recent Trades"),
|
||||
html.Div(id='recent-trades')
|
||||
])
|
||||
], style={'width': '30%', 'display': 'inline-block', 'verticalAlign': 'top'})
|
||||
])
|
||||
])
|
||||
|
||||
# Setup callbacks
|
||||
self._setup_callbacks()
|
||||
|
||||
# Thread for running the server
|
||||
self.thread = None
|
||||
self.is_running = False
|
||||
|
||||
def _setup_callbacks(self):
|
||||
@self.app.callback(
|
||||
Output('candlestick-chart', 'figure'),
|
||||
[Input('interval-component', 'n_intervals'),
|
||||
Input('candle-store', 'children'),
|
||||
Input('signal-store', 'children')]
|
||||
)
|
||||
def update_chart(n, candles_json, signals_json):
|
||||
# Parse JSON data
|
||||
candles = json.loads(candles_json) if candles_json else []
|
||||
signals = json.loads(signals_json) if signals_json else []
|
||||
|
||||
# Create figure with subplots
|
||||
fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
|
||||
vertical_spacing=0.1, row_heights=[0.7, 0.3])
|
||||
|
||||
if candles:
|
||||
# Convert to dataframe
|
||||
df = pd.DataFrame(candles[-100:]) # Show last 100 candles
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
|
||||
|
||||
# Add candlestick trace
|
||||
fig.add_trace(
|
||||
go.Candlestick(
|
||||
x=df['timestamp'],
|
||||
open=df['open'],
|
||||
high=df['high'],
|
||||
low=df['low'],
|
||||
close=df['close'],
|
||||
name='Price'
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Add volume trace
|
||||
fig.add_trace(
|
||||
go.Bar(
|
||||
x=df['timestamp'],
|
||||
y=df['volume'],
|
||||
name='Volume'
|
||||
),
|
||||
row=2, col=1
|
||||
)
|
||||
|
||||
# Add trade signals
|
||||
for signal in signals:
|
||||
if signal['timestamp'] >= df['timestamp'].iloc[0].timestamp() * 1000:
|
||||
signal_time = pd.to_datetime(signal['timestamp'], unit='ms')
|
||||
marker_color = 'green' if signal['type'] == 'buy' else 'red' if signal['type'] == 'sell' else 'orange'
|
||||
marker_symbol = 'triangle-up' if signal['type'] == 'buy' else 'triangle-down' if signal['type'] == 'sell' else 'circle'
|
||||
|
||||
# Add marker for signal
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[signal_time],
|
||||
y=[signal['price']],
|
||||
mode='markers',
|
||||
marker=dict(
|
||||
color=marker_color,
|
||||
size=12,
|
||||
symbol=marker_symbol
|
||||
),
|
||||
name=signal['type'].capitalize(),
|
||||
showlegend=False
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Update layout
|
||||
fig.update_layout(
|
||||
title=f'{self.symbol} Trading Chart',
|
||||
xaxis_rangeslider_visible=False,
|
||||
template='plotly_dark'
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
@self.app.callback(
|
||||
[Output('account-info', 'children'),
|
||||
Output('recent-trades', 'children')],
|
||||
[Input('interval-component', 'n_intervals')]
|
||||
)
|
||||
def update_account_info(n):
|
||||
if not self.env:
|
||||
return "No data available", "No trades available"
|
||||
|
||||
# Account info
|
||||
account_info = html.Div([
|
||||
html.P(f"Balance: ${self.env.balance:.2f}"),
|
||||
html.P(f"PnL: ${self.env.total_pnl:.2f}",
|
||||
style={'color': 'green' if self.env.total_pnl > 0 else 'red' if self.env.total_pnl < 0 else 'white'}),
|
||||
html.P(f"Position: {self.env.position.upper()}")
|
||||
])
|
||||
|
||||
# Recent trades
|
||||
if hasattr(self.env, 'trades') and self.env.trades:
|
||||
# Get last 5 trades
|
||||
recent_trades = []
|
||||
for trade in reversed(self.env.trades[-5:]):
|
||||
trade_card = html.Div([
|
||||
html.P(f"{trade['action'].upper()} at ${trade['price']:.2f}"),
|
||||
html.P(f"PnL: ${trade['pnl']:.2f}",
|
||||
style={'color': 'green' if trade['pnl'] > 0 else 'red' if trade['pnl'] < 0 else 'white'})
|
||||
], style={'border': '1px solid #ddd', 'padding': '10px', 'margin-bottom': '5px'})
|
||||
recent_trades.append(trade_card)
|
||||
else:
|
||||
recent_trades = [html.P("No trades yet")]
|
||||
|
||||
return account_info, recent_trades
|
||||
|
||||
def update_data(self, env=None, candles=None, trade_signals=None):
|
||||
"""Update dashboard data"""
|
||||
if env:
|
||||
self.env = env
|
||||
|
||||
if candles:
|
||||
self.candles = candles
|
||||
|
||||
if trade_signals:
|
||||
self.trade_signals = trade_signals
|
||||
|
||||
# Update store components
|
||||
if hasattr(self.app, 'layout'):
|
||||
self.app.layout.children[0].children = json.dumps(self.candles)
|
||||
self.app.layout.children[1].children = json.dumps(self.trade_signals)
|
||||
|
||||
def start(self, host='localhost', port=8060):
|
||||
"""Start the dashboard server in a separate thread"""
|
||||
if not self.is_running:
|
||||
# First check if the port is already in use
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
port_available = False
|
||||
|
||||
# Try the initial port and a few alternatives if needed
|
||||
for attempt_port in range(port, port + 10):
|
||||
try:
|
||||
sock.bind((host, attempt_port))
|
||||
port_available = True
|
||||
port = attempt_port
|
||||
break
|
||||
except socket.error:
|
||||
logger.warning(f"Port {attempt_port} is already in use")
|
||||
sock.close()
|
||||
|
||||
if not port_available:
|
||||
logger.error("Could not find an available port for dashboard")
|
||||
return False
|
||||
|
||||
# Create and start the thread
|
||||
self.thread = Thread(target=self._run_server, args=(host, port))
|
||||
self.thread.daemon = True # This ensures the thread will exit when the main program does
|
||||
self.thread.start()
|
||||
self.is_running = True
|
||||
logger.info(f"Trading dashboard started at http://{host}:{port}")
|
||||
|
||||
# Verify the thread actually started
|
||||
if not self.thread.is_alive():
|
||||
logger.error("Dashboard thread failed to start")
|
||||
return False
|
||||
|
||||
# Wait a short time to let the server initialize
|
||||
time.sleep(1.0)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _run_server(self, host, port):
|
||||
"""Run the Dash server"""
|
||||
try:
|
||||
logger.info(f"Starting Dash server on {host}:{port}")
|
||||
self.app.run_server(debug=False, host=host, port=port, use_reloader=False, threaded=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error running dashboard server: {e}")
|
||||
self.is_running = False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
|
Loading…
x
Reference in New Issue
Block a user