From 1a1c410922c5f76b690222a8e5619cf60d849618 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Wed, 19 Mar 2025 04:18:55 +0200 Subject: [PATCH] realtime data in main --- main.py | 360 +++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 341 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index 77af50b..56cbbc9 100644 --- a/main.py +++ b/main.py @@ -38,6 +38,7 @@ from dash.dependencies import Input, Output, State import plotly.graph_objects as go from plotly.subplots import make_subplots from threading import Thread +import socket # Configure logging logging.basicConfig( @@ -971,12 +972,12 @@ class TradingEnvironment: def calculate_reward(self, action): - """Calculate reward for the given action with improved penalties for losing trades""" + """Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals""" reward = 0 # Base reward for actions if action == 0: # HOLD - reward = -0.01 # Small penalty for doing nothing + reward = -0.05 # Increased penalty for doing nothing to encourage more trading elif action == 1: # BUY/LONG if self.position == 'flat': @@ -990,13 +991,26 @@ class TradingEnvironment: # Check if this is an optimal buy point (bottom) current_idx = len(self.features['price']) - 1 if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms: - reward += 2.0 # Bonus for buying at a bottom + reward += 3.0 # Increased bonus for buying at a bottom + + # Check for volume spike (indicating potential big movement) + if len(self.features['volume']) > 5: + avg_volume = np.mean(self.features['volume'][-5:-1]) + current_volume = self.features['volume'][-1] + if current_volume > avg_volume * 1.5: + reward += 2.0 # Bonus for entering during high volume + + # Check for price action signals + if self.features['rsi'][-1] < 30: # Oversold condition + reward += 1.5 # Bonus for buying at oversold levels + + # Check if we're buying in a clear uptrend (good) + if self.is_uptrend(): + reward += 1.0 # Bonus for buying in uptrend + elif self.is_downtrend(): + reward -= 0.25 # Reduced penalty for buying in downtrend else: - # Check if we're buying in a downtrend (bad) - if self.is_downtrend(): - reward -= 0.5 # Penalty for buying in downtrend - else: - reward += 0.1 # Small reward for opening a position + reward += 0.2 # Small reward for opening a position logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") @@ -1066,9 +1080,26 @@ class TradingEnvironment: # Check if this is an optimal sell point (top) current_idx = len(self.features['price']) - 1 if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops: - reward += 2.0 # Bonus for selling at a top + reward += 3.0 # Increased bonus for selling at a top + + # Check for volume spike + if len(self.features['volume']) > 5: + avg_volume = np.mean(self.features['volume'][-5:-1]) + current_volume = self.features['volume'][-1] + if current_volume > avg_volume * 1.5: + reward += 2.0 # Bonus for entering during high volume + + # Check for price action signals + if self.features['rsi'][-1] > 70: # Overbought condition + reward += 1.5 # Bonus for selling at overbought levels + + # Check if we're selling in a clear downtrend (good) + if self.is_downtrend(): + reward += 1.0 # Bonus for selling in downtrend + elif self.is_uptrend(): + reward -= 0.25 # Reduced penalty for selling in uptrend else: - reward += 0.1 # Small reward for opening a position + reward += 0.2 # Small reward for opening a position logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") @@ -1714,15 +1745,32 @@ class Agent: sample = random.random() if training: - # Epsilon decay + # More aggressive epsilon decay for faster exploitation self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \ - np.exp(-1. * self.steps_done / EPSILON_DECAY) + np.exp(-1.5 * self.steps_done / EPSILON_DECAY) # Increased decay factor self.steps_done += 1 + # Lower threshold for exploration, especially in live trading + if not training: + # In live trading, be much more aggressive with exploitation + self.epsilon = max(EPSILON_END, self.epsilon * 0.95) + if sample > self.epsilon or not training: with torch.no_grad(): state_tensor = torch.FloatTensor(state).to(self.device) action_values = self.policy_net(state_tensor) + + # Add temperature-based sampling for more aggressive actions + # when the model is confident (higher action differences) + if not training: # More aggressive in live trading + values = action_values.cpu().numpy() + max_value = np.max(values) + value_diff = max_value - np.mean(values) + + # If there's a clear best action, always take it + if value_diff > 0.5: + return action_values.max(1)[1].item() + return action_values.max(1)[1].item() else: return random.randrange(self.action_size) @@ -2761,7 +2809,7 @@ async def process_websocket_ticks(websocket, env, agent=None, demo=True, timefra continue # Convert timestamp to datetime - tick_time = datetime.fromtimestamp(timestamp / 1000) + tick_time = datetime.datetime.fromtimestamp(timestamp / 1000) # For 1-minute candles, track the minute if timeframe == "1m": @@ -2856,6 +2904,8 @@ async def main(): help='Path to model file for evaluation or live trading') parser.add_argument('--use-websocket', action='store_true', help='Use Binance WebSocket for real-time data instead of CCXT (for live mode)') + parser.add_argument('--dashboard', action='store_true', + help='Enable Dash dashboard visualization for real-time trading') args = parser.parse_args() @@ -2876,7 +2926,7 @@ async def main(): if args.mode == 'train': # Fetch initial data for training - await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000) + await env.fetch_initial_data(exchange, args.symbol,args.timeframe, 1000) # Create agent with consistent parameters # Note: Using STATE_SIZE and action_size=4 for consistency @@ -2942,7 +2992,8 @@ async def main(): symbol=args.symbol, timeframe=args.timeframe, demo=demo_mode, - leverage=args.leverage + leverage=args.leverage, + use_dashboard=args.dashboard ) else: logger.info("Using CCXT for real-time data") @@ -3048,7 +3099,7 @@ def create_candlestick_figure(data, trade_signals, window_size=100, title=""): logger.error(f"Error creating chart: {str(e)}") return None -async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50): +async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50, use_dashboard=False): """Run the trading bot in live mode using Binance WebSocket for real-time data Args: @@ -3058,6 +3109,7 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe=" timeframe: The candlestick timeframe (e.g., "1m") demo: Whether to run in demo mode (paper trading) leverage: The leverage to use for trading + use_dashboard: Whether to display the real-time dashboard Returns: None @@ -3075,9 +3127,25 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe=" # Initialize TensorBoard for monitoring if not hasattr(agent, 'writer') or agent.writer is None: from torch.utils.tensorboard import SummaryWriter - current_time = datetime.now().strftime("%Y%m%d_%H%M%S") + current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") agent.writer = SummaryWriter(f'runs/live_ws_{symbol.replace("/", "_")}_{current_time}') + # Initialize Dash dashboard if enabled + dashboard = None + if use_dashboard: + try: + dashboard = TradingDashboard(symbol) + dashboard_started = dashboard.start() # Start the dashboard in a separate thread + if dashboard_started: + logger.info(f"Trading dashboard enabled at http://localhost:8060") + else: + logger.warning("Failed to start trading dashboard, continuing without visualization") + dashboard = None + except Exception as e: + logger.error(f"Error initializing dashboard: {e}") + logger.error(traceback.format_exc()) + dashboard = None + # Track performance metrics trades_count = 0 winning_trades = 0 @@ -3088,7 +3156,7 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe=" # Create directory for trade logs os.makedirs('trade_logs', exist_ok=True) - current_time = datetime.now().strftime("%Y%m%d_%H%M%S") + current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") trade_log_path = f'trade_logs/trades_ws_{current_time}.csv' with open(trade_log_path, 'w') as f: f.write("timestamp,action,price,position_size,balance,pnl\n") @@ -3109,6 +3177,10 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe=" # Reset environment with historical data env.reset() + # Update dashboard with initial data if enabled + if dashboard: + dashboard.update_data(env=env, candles=env.data, trade_signals=env.trade_signals) + # Initialize futures trading if not in demo mode exchange = None if not demo: @@ -3144,7 +3216,7 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe=" total_profit += env.last_trade_profit # Log trade details - current_time = datetime.now().isoformat() + current_time = datetime.datetime.now().isoformat() action_name = "HOLD" if getattr(env, 'last_action', 0) == 0 else "BUY" if getattr(env, 'last_action', 0) == 1 else "SELL" if getattr(env, 'last_action', 0) == 2 else "CLOSE" with open(trade_log_path, 'a') as f: f.write(f"{current_time},{action_name},{env.current_price},{env.position_size},{env.balance},{getattr(env, 'last_trade_profit', 0)}\n") @@ -3184,6 +3256,10 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe=" """ agent.writer.add_text('Performance', performance_text, step_counter) + # Update the dashboard with latest data if enabled + if dashboard: + dashboard.update_data(env=env, candles=env.data, trade_signals=env.trade_signals) + prev_position = env.position # Sleep for a short time to prevent CPU hogging @@ -3231,6 +3307,252 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe=" if 'exchange' in locals() and exchange: await exchange.close() +def ensure_pytorch_compatibility(): + """Check and fix common PyTorch compatibility issues""" + try: + import torch.serialization + import pickle + + # Register safe pickles to handle the numpy scalar warning + if hasattr(torch.serialization, 'add_safe_globals'): + torch.serialization.add_safe_globals([('numpy._core.multiarray.scalar', np.ndarray)]) + torch.serialization.add_safe_globals([('numpy.core.multiarray.scalar', np.ndarray)]) + torch.serialization.add_safe_globals(['numpy._core.multiarray.scalar']) + torch.serialization.add_safe_globals(['numpy.core.multiarray.scalar']) + + logger.info("PyTorch safe globals registered for compatibility") + else: + logger.warning("PyTorch serialization module doesn't have add_safe_globals method") + + except Exception as e: + logger.warning(f"PyTorch compatibility check failed: {e}") + + +class TradingDashboard: + """Dashboard for visualizing trading activity with Dash""" + + def __init__(self, symbol="ETH/USDT"): + self.symbol = symbol + self.env = None + self.candles = [] + self.trade_signals = [] + + # Create Dash app + self.app = dash.Dash(__name__, suppress_callback_exceptions=True) + + # Create basic layout + self.app.layout = html.Div([ + # Store components for data + html.Div(id='candle-store', style={'display': 'none'}), + html.Div(id='signal-store', style={'display': 'none'}), + + # Header + html.H1(f"Trading Dashboard - {symbol}", style={'textAlign': 'center'}), + + # Main content + html.Div([ + # Chart + html.Div([ + dcc.Graph(id='candlestick-chart', style={'height': '70vh'}), + dcc.Interval(id='interval-component', interval=5*1000, n_intervals=0) + ], style={'width': '70%', 'display': 'inline-block'}), + + # Trading info + html.Div([ + html.Div([ + html.H3("Account Info"), + html.Div(id='account-info') + ]), + html.Div([ + html.H3("Recent Trades"), + html.Div(id='recent-trades') + ]) + ], style={'width': '30%', 'display': 'inline-block', 'verticalAlign': 'top'}) + ]) + ]) + + # Setup callbacks + self._setup_callbacks() + + # Thread for running the server + self.thread = None + self.is_running = False + + def _setup_callbacks(self): + @self.app.callback( + Output('candlestick-chart', 'figure'), + [Input('interval-component', 'n_intervals'), + Input('candle-store', 'children'), + Input('signal-store', 'children')] + ) + def update_chart(n, candles_json, signals_json): + # Parse JSON data + candles = json.loads(candles_json) if candles_json else [] + signals = json.loads(signals_json) if signals_json else [] + + # Create figure with subplots + fig = make_subplots(rows=2, cols=1, shared_xaxes=True, + vertical_spacing=0.1, row_heights=[0.7, 0.3]) + + if candles: + # Convert to dataframe + df = pd.DataFrame(candles[-100:]) # Show last 100 candles + df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') + + # Add candlestick trace + fig.add_trace( + go.Candlestick( + x=df['timestamp'], + open=df['open'], + high=df['high'], + low=df['low'], + close=df['close'], + name='Price' + ), + row=1, col=1 + ) + + # Add volume trace + fig.add_trace( + go.Bar( + x=df['timestamp'], + y=df['volume'], + name='Volume' + ), + row=2, col=1 + ) + + # Add trade signals + for signal in signals: + if signal['timestamp'] >= df['timestamp'].iloc[0].timestamp() * 1000: + signal_time = pd.to_datetime(signal['timestamp'], unit='ms') + marker_color = 'green' if signal['type'] == 'buy' else 'red' if signal['type'] == 'sell' else 'orange' + marker_symbol = 'triangle-up' if signal['type'] == 'buy' else 'triangle-down' if signal['type'] == 'sell' else 'circle' + + # Add marker for signal + fig.add_trace( + go.Scatter( + x=[signal_time], + y=[signal['price']], + mode='markers', + marker=dict( + color=marker_color, + size=12, + symbol=marker_symbol + ), + name=signal['type'].capitalize(), + showlegend=False + ), + row=1, col=1 + ) + + # Update layout + fig.update_layout( + title=f'{self.symbol} Trading Chart', + xaxis_rangeslider_visible=False, + template='plotly_dark' + ) + + return fig + + @self.app.callback( + [Output('account-info', 'children'), + Output('recent-trades', 'children')], + [Input('interval-component', 'n_intervals')] + ) + def update_account_info(n): + if not self.env: + return "No data available", "No trades available" + + # Account info + account_info = html.Div([ + html.P(f"Balance: ${self.env.balance:.2f}"), + html.P(f"PnL: ${self.env.total_pnl:.2f}", + style={'color': 'green' if self.env.total_pnl > 0 else 'red' if self.env.total_pnl < 0 else 'white'}), + html.P(f"Position: {self.env.position.upper()}") + ]) + + # Recent trades + if hasattr(self.env, 'trades') and self.env.trades: + # Get last 5 trades + recent_trades = [] + for trade in reversed(self.env.trades[-5:]): + trade_card = html.Div([ + html.P(f"{trade['action'].upper()} at ${trade['price']:.2f}"), + html.P(f"PnL: ${trade['pnl']:.2f}", + style={'color': 'green' if trade['pnl'] > 0 else 'red' if trade['pnl'] < 0 else 'white'}) + ], style={'border': '1px solid #ddd', 'padding': '10px', 'margin-bottom': '5px'}) + recent_trades.append(trade_card) + else: + recent_trades = [html.P("No trades yet")] + + return account_info, recent_trades + + def update_data(self, env=None, candles=None, trade_signals=None): + """Update dashboard data""" + if env: + self.env = env + + if candles: + self.candles = candles + + if trade_signals: + self.trade_signals = trade_signals + + # Update store components + if hasattr(self.app, 'layout'): + self.app.layout.children[0].children = json.dumps(self.candles) + self.app.layout.children[1].children = json.dumps(self.trade_signals) + + def start(self, host='localhost', port=8060): + """Start the dashboard server in a separate thread""" + if not self.is_running: + # First check if the port is already in use + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + port_available = False + + # Try the initial port and a few alternatives if needed + for attempt_port in range(port, port + 10): + try: + sock.bind((host, attempt_port)) + port_available = True + port = attempt_port + break + except socket.error: + logger.warning(f"Port {attempt_port} is already in use") + sock.close() + + if not port_available: + logger.error("Could not find an available port for dashboard") + return False + + # Create and start the thread + self.thread = Thread(target=self._run_server, args=(host, port)) + self.thread.daemon = True # This ensures the thread will exit when the main program does + self.thread.start() + self.is_running = True + logger.info(f"Trading dashboard started at http://{host}:{port}") + + # Verify the thread actually started + if not self.thread.is_alive(): + logger.error("Dashboard thread failed to start") + return False + + # Wait a short time to let the server initialize + time.sleep(1.0) + return True + return False + + def _run_server(self, host, port): + """Run the Dash server""" + try: + logger.info(f"Starting Dash server on {host}:{port}") + self.app.run_server(debug=False, host=host, port=port, use_reloader=False, threaded=True) + except Exception as e: + logger.error(f"Error running dashboard server: {e}") + self.is_running = False + + if __name__ == "__main__": try: asyncio.run(main())