diff --git a/main.py b/main.py index b5c40fc..77af50b 100644 --- a/main.py +++ b/main.py @@ -29,7 +29,15 @@ from PIL import Image import matplotlib.pyplot as mpf import matplotlib.gridspec as gridspec import datetime +from realtime import BinanceWebSocket, BinanceHistoricalData from datetime import datetime as dt +# Add Dash-related imports +import dash +from dash import html, dcc, callback_context +from dash.dependencies import Input, Output, State +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from threading import Thread # Configure logging logging.basicConfig( @@ -2644,9 +2652,173 @@ async def fetch_ohlcv_data(exchange, symbol, timeframe, limit): return data except Exception as e: - logger.error(f"Error fetching OHLCV data: {e}") + logger.error(f"Failed to fetch OHLCV data: {e}") return [] +async def initialize_websocket_data_stream(symbol="ETH/USDT", timeframe="1m"): + """Initialize a WebSocket connection for real-time trading data + + Args: + symbol: Trading pair symbol (e.g., "ETH/USDT") + timeframe: Timeframe for candle aggregation (e.g., "1m") + + Returns: + Tuple of (websocket, candle_data) where websocket is the BinanceWebSocket instance + and candle_data is a dict to track ongoing candle formation + """ + try: + # Initialize historical data handler to get initial data + historical_data = BinanceHistoricalData() + + # Convert timeframe to seconds for historical data + if timeframe == "1m": + interval_seconds = 60 + elif timeframe == "5m": + interval_seconds = 300 + elif timeframe == "15m": + interval_seconds = 900 + elif timeframe == "1h": + interval_seconds = 3600 + else: + interval_seconds = 60 # Default to 1m + + # Fetch initial historical data + initial_data = historical_data.get_historical_candles( + symbol=symbol, + interval_seconds=interval_seconds, + limit=1000 # Get 1000 candles for good history + ) + + # Convert pandas DataFrame to list of dictionaries for our environment + initial_candles = [] + if not initial_data.empty: + for _, row in initial_data.iterrows(): + candle = { + 'timestamp': int(row['timestamp'].timestamp() * 1000), + 'open': float(row['open']), + 'high': float(row['high']), + 'low': float(row['low']), + 'close': float(row['close']), + 'volume': float(row['volume']) + } + initial_candles.append(candle) + + logger.info(f"Loaded {len(initial_candles)} historical candles") + else: + logger.warning("No historical data fetched") + + # Initialize WebSocket for real-time data + binance_ws = BinanceWebSocket(symbol.replace('/', '')) + await binance_ws.connect() + + # Track the current candle data + current_minute = None + current_candle = None + + logger.info(f"WebSocket for {symbol} initialized successfully") + return binance_ws, initial_candles + + except Exception as e: + logger.error(f"Failed to initialize WebSocket data stream: {e}") + logger.error(traceback.format_exc()) + return None, [] + +async def process_websocket_ticks(websocket, env, agent=None, demo=True, timeframe="1m"): + """Process real-time ticks from WebSocket and aggregate them into candles + + Args: + websocket: BinanceWebSocket instance + env: TradingEnvironment instance + agent: Agent instance (optional, for live trading) + demo: Whether to run in demo mode + timeframe: Timeframe for candle aggregation + """ + # Initialize variables for candle aggregation + current_candle = None + current_minute = None + trades_count = 0 + step_counter = 0 + + try: + logger.info("Starting WebSocket tick processing...") + + while websocket.running: + # Get the next tick from WebSocket + tick = await websocket.receive() + + if tick is None: + # No data received, wait and try again + await asyncio.sleep(0.1) + continue + + # Extract data from tick + timestamp = tick.get('timestamp') + price = tick.get('price') + volume = tick.get('volume') + + if timestamp is None or price is None: + logger.warning(f"Invalid tick data received: {tick}") + continue + + # Convert timestamp to datetime + tick_time = datetime.fromtimestamp(timestamp / 1000) + + # For 1-minute candles, track the minute + if timeframe == "1m": + tick_minute = tick_time.replace(second=0, microsecond=0) + + # If this is a new minute, close the current candle and start a new one + if current_minute is None or tick_minute > current_minute: + # If there was a previous candle, add it to the environment + if current_candle is not None: + # Add the candle to the environment + env.add_data(current_candle) + + # Process trading decisions if agent is provided + if agent is not None: + state = env.get_state() + action = agent.select_action(state, training=False) + + # Execute action in environment + next_state, reward, done, info = env.step(action) + + # Log trading activity + action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE" + logger.info(f"Step {step_counter}: Action {action_name}, Price: ${price:.2f}, Balance: ${env.balance:.2f}") + step_counter += 1 + + # Start a new candle + current_minute = tick_minute + current_candle = { + 'timestamp': int(current_minute.timestamp() * 1000), + 'open': price, + 'high': price, + 'low': price, + 'close': price, + 'volume': volume + } + logger.debug(f"Started new candle at {current_minute}") + else: + # Update the current candle + current_candle['high'] = max(current_candle['high'], price) + current_candle['low'] = min(current_candle['low'], price) + current_candle['close'] = price + current_candle['volume'] += volume + + # For other timeframes, implement similar logic + # ... + + except asyncio.CancelledError: + logger.info("WebSocket processing canceled") + except Exception as e: + logger.error(f"Error in WebSocket tick processing: {e}") + logger.error(traceback.format_exc()) + finally: + # Make sure to close the WebSocket + if websocket: + await websocket.close() + logger.info("WebSocket connection closed") + # Add this near the top of the file, after imports def ensure_pytorch_compatibility(): """Ensure compatibility with PyTorch 2.6+ for model loading""" @@ -2682,6 +2854,8 @@ async def main(): help='Leverage for futures trading') parser.add_argument('--model', type=str, default=None, help='Path to model file for evaluation or live trading') + parser.add_argument('--use-websocket', action='store_true', + help='Use Binance WebSocket for real-time data instead of CCXT (for live mode)') args = parser.parse_args() @@ -2760,15 +2934,27 @@ async def main(): logger.info(f"Starting live trading for {args.symbol} on {args.timeframe} timeframe") logger.info(f"Demo mode: {demo_mode}, Leverage: {args.leverage}x") - await live_trading( - agent=agent, - env=env, - exchange=exchange, - symbol=args.symbol, - timeframe=args.timeframe, - demo=demo_mode, - leverage=args.leverage - ) + if args.use_websocket: + logger.info("Using Binance WebSocket for real-time data") + await live_trading_with_websocket( + agent=agent, + env=env, + symbol=args.symbol, + timeframe=args.timeframe, + demo=demo_mode, + leverage=args.leverage + ) + else: + logger.info("Using CCXT for real-time data") + await live_trading( + agent=agent, + env=env, + exchange=exchange, + symbol=args.symbol, + timeframe=args.timeframe, + demo=demo_mode, + leverage=args.leverage + ) except Exception as e: logger.error(f"Error in main function: {e}") @@ -2862,6 +3048,189 @@ def create_candlestick_figure(data, trade_signals, window_size=100, title=""): logger.error(f"Error creating chart: {str(e)}") return None +async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50): + """Run the trading bot in live mode using Binance WebSocket for real-time data + + Args: + agent: The trading agent to use for decision making + env: The trading environment + symbol: The trading pair symbol (e.g., "ETH/USDT") + timeframe: The candlestick timeframe (e.g., "1m") + demo: Whether to run in demo mode (paper trading) + leverage: The leverage to use for trading + + Returns: + None + """ + logger.info(f"Starting live trading with WebSocket for {symbol} on {timeframe} timeframe") + logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}") + + # If not demo mode, confirm with user before starting live trading + if not demo: + confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ") + if confirmation != "CONFIRM": + logger.info("Live trading canceled by user") + return + + # Initialize TensorBoard for monitoring + if not hasattr(agent, 'writer') or agent.writer is None: + from torch.utils.tensorboard import SummaryWriter + current_time = datetime.now().strftime("%Y%m%d_%H%M%S") + agent.writer = SummaryWriter(f'runs/live_ws_{symbol.replace("/", "_")}_{current_time}') + + # Track performance metrics + trades_count = 0 + winning_trades = 0 + total_profit = 0 + max_drawdown = 0 + peak_balance = env.balance + step_counter = 0 + + # Create directory for trade logs + os.makedirs('trade_logs', exist_ok=True) + current_time = datetime.now().strftime("%Y%m%d_%H%M%S") + trade_log_path = f'trade_logs/trades_ws_{current_time}.csv' + with open(trade_log_path, 'w') as f: + f.write("timestamp,action,price,position_size,balance,pnl\n") + + try: + # Initialize WebSocket connection and get historical data + websocket, initial_candles = await initialize_websocket_data_stream(symbol, timeframe) + + if websocket is None or not initial_candles: + logger.error("Failed to initialize WebSocket data stream") + return + + # Load initial historical data into the environment + logger.info(f"Loading {len(initial_candles)} initial candles into environment") + for candle in initial_candles: + env.add_data(candle) + + # Reset environment with historical data + env.reset() + + # Initialize futures trading if not in demo mode + exchange = None + if not demo: + # Import ccxt for exchange initialization + import ccxt.async_support as ccxt_async + + # Initialize exchange for order execution + exchange = await initialize_exchange() + if exchange: + try: + await env.initialize_futures(exchange) + logger.info(f"Futures trading initialized with {leverage}x leverage") + except Exception as e: + logger.error(f"Failed to initialize futures trading: {str(e)}") + logger.info("Falling back to demo mode for safety") + demo = True + + # Start WebSocket processing in the background + websocket_task = asyncio.create_task( + process_websocket_ticks(websocket, env, agent, demo, timeframe) + ) + + # Main tracking loop + prev_position = 'flat' + while True: + try: + # Check if position has changed + if env.position != prev_position: + trades_count += 1 + if hasattr(env, 'last_trade_profit') and env.last_trade_profit > 0: + winning_trades += 1 + if hasattr(env, 'last_trade_profit'): + total_profit += env.last_trade_profit + + # Log trade details + current_time = datetime.now().isoformat() + action_name = "HOLD" if getattr(env, 'last_action', 0) == 0 else "BUY" if getattr(env, 'last_action', 0) == 1 else "SELL" if getattr(env, 'last_action', 0) == 2 else "CLOSE" + with open(trade_log_path, 'a') as f: + f.write(f"{current_time},{action_name},{env.current_price},{env.position_size},{env.balance},{getattr(env, 'last_trade_profit', 0)}\n") + + logger.info(f"Trade executed: {action_name} at ${env.current_price:.2f}, PnL: ${getattr(env, 'last_trade_profit', 0):.2f}") + + # Update performance metrics + if env.balance > peak_balance: + peak_balance = env.balance + current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0 + if current_drawdown > max_drawdown: + max_drawdown = current_drawdown + + # Update TensorBoard metrics + step_counter += 1 + if step_counter % 10 == 0: # Update every 10 steps + agent.writer.add_scalar('Live/Balance', env.balance, step_counter) + agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter) + agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter) + + # Update chart visualization + if step_counter % 30 == 0 or env.position != prev_position: + agent.add_chart_to_tensorboard(env, step_counter) + + # Log performance summary + if trades_count > 0: + win_rate = (winning_trades / trades_count) * 100 + agent.writer.add_scalar('Live/WinRate', win_rate, step_counter) + + performance_text = f""" + **Live Trading Performance** + Balance: ${env.balance:.2f} + Total PnL: ${env.total_pnl:.2f} + Trades: {trades_count} + Win Rate: {win_rate:.1f}% + Max Drawdown: {max_drawdown*100:.1f}% + """ + agent.writer.add_text('Performance', performance_text, step_counter) + + prev_position = env.position + + # Sleep for a short time to prevent CPU hogging + await asyncio.sleep(1) + + except Exception as e: + logger.error(f"Error in live trading monitor loop: {str(e)}") + logger.error(traceback.format_exc()) + await asyncio.sleep(10) # Wait longer after an error + + except KeyboardInterrupt: + logger.info("Live trading stopped by user") + + # Cancel the WebSocket task + if 'websocket_task' in locals() and not websocket_task.done(): + websocket_task.cancel() + try: + await websocket_task + except asyncio.CancelledError: + pass + + # Close the exchange connection if it exists + if exchange: + await exchange.close() + + # Final performance report + if trades_count > 0: + win_rate = (winning_trades / trades_count) * 100 + logger.info(f"Trading session summary:") + logger.info(f"Total trades: {trades_count}") + logger.info(f"Win rate: {win_rate:.1f}%") + logger.info(f"Final balance: ${env.balance:.2f}") + logger.info(f"Total profit: ${total_profit:.2f}") + + except Exception as e: + logger.error(f"Critical error in live trading: {str(e)}") + logger.error(traceback.format_exc()) + + finally: + # Make sure to close WebSocket + if 'websocket' in locals() and websocket: + await websocket.close() + + # Close the exchange connection if it exists + if 'exchange' in locals() and exchange: + await exchange.close() + if __name__ == "__main__": try: asyncio.run(main())